最近,我在训练TransUNet网络时遇到了一些挑战,特别是在处理高版本Labelme标注的数据集和三通道输入时,数据集转换失败。为了记录这一过程并分享解决方案,我将详细介绍我在处理高版本标注和TransUNet三通道数据集时的经验和方法。
首先检查数据集是否完整,验证图片与标签是否一一对应。注意:下面的脚本会直接删除没有配对的图片或 JSON 文件,运行前请先备份数据集。
import os
import json
def check_and_cleanup(image_dir, json_dir):
    """Check that PNG images and labelme JSON files pair up one-to-one and
    delete every file that has no partner.

    Files are paired by base name (file name without extension). Extensions
    are compared case-insensitively, so ``a.PNG`` pairs with ``a.json``
    instead of causing the valid ``a.json`` to be deleted.

    WARNING: unmatched files are removed from disk with ``os.remove``;
    back up the dataset before running this.

    Args:
        image_dir: directory containing the ``.png`` images.
        json_dir: directory containing the labelme ``.json`` files.
    """

    def _by_stem(directory, ext):
        # Map base name -> actual file name, so deletion uses the on-disk
        # spelling of the extension (matters on case-sensitive filesystems).
        return {
            os.path.splitext(name)[0]: name
            for name in os.listdir(directory)
            if os.path.splitext(name)[1].lower() == ext
        }

    images = _by_stem(image_dir, ".png")
    jsons = _by_stem(json_dir, ".json")

    missing_images = jsons.keys() - images.keys()  # JSON with no image
    missing_jsons = images.keys() - jsons.keys()  # image with no JSON

    if not missing_images and not missing_jsons:
        print("所有PNG图片和JSON标签一一对应。")
        return

    if missing_images:
        print("缺少对应的PNG图片文件,删除对应的JSON文件:")
        for stem in sorted(missing_images):
            json_file_path = os.path.join(json_dir, jsons[stem])
            if os.path.exists(json_file_path):
                os.remove(json_file_path)
                print(f"已删除 JSON 文件: {json_file_path}")
    if missing_jsons:
        print("缺少对应的JSON文件,删除对应的PNG文件:")
        for stem in sorted(missing_jsons):
            png_file_path = os.path.join(image_dir, images[stem])
            if os.path.exists(png_file_path):
                os.remove(png_file_path)
                print(f"已删除 PNG 文件: {png_file_path}")
def main():
    """Script entry point: run the image/label consistency check."""
    # Point these at your own PNG image folder and labelme JSON folder.
    image_folder = r'F:\earn\TransUNet-main\images'
    label_folder = r'F:\earn\TransUNet-main\json'
    check_and_cleanup(image_dir=image_folder, json_dir=label_folder)


if __name__ == "__main__":
    main()
然后将 JSON 文件转换为 mask 掩码图:只需修改 JSON 输入路径、图片和掩码的输出路径,以及与自己数据集对应的类别标签列表。
from __future__ import division

import base64
import io
import json
import math
import os
import os.path as osp
import uuid

import cv2
import numpy as np
import PIL.Image
import PIL.Image
import PIL.ImageDraw
from PIL import Image
def img_b64_to_arr(img_b64):
    """Decode a base64-encoded image string into a NumPy array.

    The array has whatever layout PIL yields for the encoded file (e.g. a
    2-D array for single-band images, H x W x C for multi-band ones).
    """
    raw_bytes = base64.b64decode(img_b64)
    return np.asarray(Image.open(io.BytesIO(raw_bytes)))
def shape_to_mask(img_shape, points, shape_type=None, line_width=10, point_size=5):
    """Rasterize a single labelme shape into a boolean mask.

    Args:
        img_shape: image shape; only the first two entries (height, width)
            are used.
        points: sequence of (x, y) coordinates from the labelme JSON.
        shape_type: "circle", "rectangle", "line", "linestrip", "point";
            anything else (including None) is treated as a polygon.
        line_width: stroke width for "line" / "linestrip".
        point_size: radius drawn for "point".

    Returns:
        bool ndarray of shape ``img_shape[:2]``, True inside the shape.

    Raises:
        AssertionError: if the number of points does not fit ``shape_type``.
    """
    mask = PIL.Image.fromarray(np.zeros(img_shape[:2], dtype=np.uint8))
    draw = PIL.ImageDraw.Draw(mask)
    xy = [tuple(point) for point in points]
    if shape_type == "circle":
        assert len(xy) == 2, "Shape of shape_type=circle must have 2 points"
        # First point is the centre, second lies on the circle; their
        # distance is the radius. (`math` was previously never imported,
        # so this branch raised NameError.)
        (cx, cy), (px, py) = xy
        d = math.hypot(cx - px, cy - py)
        draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
    elif shape_type == "rectangle":
        assert len(xy) == 2, "Shape of shape_type=rectangle must have 2 points"
        draw.rectangle(xy, outline=1, fill=1)
    elif shape_type == "line":
        assert len(xy) == 2, "Shape of shape_type=line must have 2 points"
        draw.line(xy=xy, fill=1, width=line_width)
    elif shape_type == "linestrip":
        draw.line(xy=xy, fill=1, width=line_width)
    elif shape_type == "point":
        assert len(xy) == 1, "Shape of shape_type=point must have 1 points"
        cx, cy = xy[0]
        r = point_size
        draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline=1, fill=1)
    else:
        assert len(xy) > 2, "Polygon must have points more than 2"
        draw.polygon(xy=xy, outline=1, fill=1)
    return np.array(mask, dtype=bool)
def shapes_to_label(img_shape, shapes, label_name_to_value, type="class"):
    """Paint a list of labelme shapes into an integer label map.

    Shapes are drawn in list order, so later shapes overwrite earlier ones.

    Args:
        img_shape: image shape; only (height, width) = img_shape[:2] is used.
        shapes: labelme shape dicts with "points", "label" and optionally
            "shape_type".
        label_name_to_value: class name -> integer id written to the map.
        type: "class" returns the class map only; "instance" also returns
            an instance map, where the class name is the part of the label
            before the first "-" and each distinct full label gets an id
            (1, 2, ...) in order of first appearance.

    Returns:
        int32 class map, or a (class map, instance map) tuple for "instance".
    """
    assert type in ["class", "instance"]
    want_instances = type == "instance"
    size = img_shape[:2]

    cls = np.zeros(size, dtype=np.int32)
    ins = np.zeros(size, dtype=np.int32) if want_instances else None
    instance_ids = {"_background_": 0}

    for shape in shapes:
        points = shape["points"]
        label = shape["label"]
        if want_instances:
            cls_name = label.split("-")[0]
            # New labels get the next free id; repeated labels reuse theirs.
            ins_id = instance_ids.setdefault(label, len(instance_ids))
        else:
            cls_name = label
        cls_id = label_name_to_value[cls_name]

        region = shape_to_mask(size, points, shape.get("shape_type", None))
        cls[region] = cls_id
        if want_instances:
            ins[region] = ins_id

    return (cls, ins) if want_instances else cls
def label_colormap(N=256):
    """Build the PASCAL-VOC-style colormap used for label PNG palettes.

    Bit ``3*level + channel`` of each class index is placed at bit
    ``7 - level`` of the corresponding R/G/B byte.

    Args:
        N: number of colours (class indices 0 .. N-1).

    Returns:
        float32 array of shape (N, 3) with values in [0, 1].
    """
    codes = np.arange(N)
    rgb = np.zeros((N, 3), dtype=np.int64)
    for level in range(8):
        for channel in range(3):
            bit = (codes >> (3 * level + channel)) & 1
            rgb[:, channel] |= bit << (7 - level)
    return rgb.astype(np.float32) / 255
def lblsave(filename, lbl):
    """Save an integer label map as a palette ("P" mode) PNG.

    A ".png" suffix is appended when ``filename`` lacks it. Values must lie
    in [-1, 254]; the palette comes from ``label_colormap(255)``.

    Raises:
        ValueError: if the label values fall outside that range.
    """
    if osp.splitext(filename)[1] != ".png":
        filename += ".png"
    fits_in_png = lbl.min() >= -1 and lbl.max() < 255
    if not fits_in_png:
        raise ValueError(
            "[%s] Cannot save the pixel-wise class label as PNG. "
            "Please consider using the .npy format." % filename
        )
    image = PIL.Image.fromarray(lbl.astype(np.uint8), mode="P")
    palette = (label_colormap(255) * 255).astype(np.uint8).flatten()
    image.putpalette(palette)
    image.save(filename)
if __name__ == "__main__":
    # ---- configuration: edit these for your own dataset ----
    # Raw strings: in plain literals "\e", "\T", "\d", "\J", "\j" are invalid
    # escape sequences (SyntaxWarning on Python 3.12+).
    json_dir = r"F:\earn\TransUNet-main\json"  # labelme JSON files
    jpgs_path = r"F:\earn\TransUNet-main\data4\JPG"  # output: RGB images
    pngs_path = r"F:\earn\TransUNet-main\data4\PNG"  # output: label masks
    # Labels of your own dataset; position in this list == mask pixel value.
    classes = ["_background_", "LM", "IL", "EC", "LungM", "OtherMc", "BM"]

    os.makedirs(jpgs_path, exist_ok=True)
    os.makedirs(pngs_path, exist_ok=True)

    for json_name in sorted(os.listdir(json_dir)):
        path = os.path.join(json_dir, json_name)
        if not (os.path.isfile(path) and json_name.lower().endswith(".json")):
            continue
        # splitext keeps dotted names such as "case.01.json" -> "case.01",
        # whereas split(".")[0] would truncate them to "case".
        stem = osp.splitext(json_name)[0]

        with open(path, encoding="utf-8") as f:
            data = json.load(f)

        # The JSON may carry "imageData": null (or lack the key entirely);
        # then load the picture referenced by "imagePath" instead.
        imageData = data.get("imageData")
        if not imageData:
            imagePath = os.path.join(os.path.dirname(path), data["imagePath"])
            with open(imagePath, "rb") as f:
                imageData = base64.b64encode(f.read()).decode("utf-8")
        img = img_b64_to_arr(imageData)

        # Per-file label ids in order of first appearance; background is 0.
        label_name_to_value = {"_background_": 0}
        for shape in data["shapes"]:
            label_name_to_value.setdefault(shape["label"], len(label_name_to_value))
        label_names = sorted(label_name_to_value, key=label_name_to_value.get)

        unknown = [name for name in label_names if name not in classes]
        if unknown:
            raise ValueError(
                "%s: labels %s are not listed in `classes`; add them to the list."
                % (path, unknown)
            )

        lbl = shapes_to_label(img.shape, data["shapes"], label_name_to_value)
        # Remap per-file ids to global class indices with a lookup table, so
        # the same class gets the same pixel value in every mask.
        lut = np.array([classes.index(name) for name in label_names])
        new = lut[lbl]
        print("label_names: ", label_names)
        print("index_all: ", lut.tolist())

        # RGBA / grayscale PNGs cannot be written as JPEG and are not 3-channel
        # network input; normalize every image to RGB.
        PIL.Image.fromarray(img).convert("RGB").save(
            osp.join(jpgs_path, stem + ".jpg")
        )
        lblsave(osp.join(pngs_path, stem + ".png"), new)
        print("Saved " + stem + ".jpg and " + stem + ".png")
这样就解决了三通道问题和高版本 labelme 标注格式的问题。该方法不仅适用于 TransUNet,也适用于其他因高版本标注格式而无法完成数据集转换的语义分割模型。
如果您对 TransUNet 模型的改进和深度学习技术感兴趣,欢迎关注我的公众号 "AI代码 Insights"。在这里,我会定期分享最新的人工智能技术、深度学习算法和实践经验,与大家共同探讨 AI 领域的前沿动态。如需完整的实现代码,也可以通过公众号联系我。