读取Cifar10数据集中的数据
file = 'data\cifar-10-batches-py\data_batch_1'
def unpickle(file):
fo = open(file, 'rb')
dict = pickle.load(fo, encoding='latin1')
fo.close()
return dict
# 第几张图片
line_number = 0
# 显示测试集图片
dict = unpickle(file)
data = dict.get("data")
label = dict.get("labels")
# 由于在cifar10的data中,图片数据存储为了一个一维数组的形式
# 但其本身是一个rgb三通道的32x32图片,因此我们将其整型如下,并存入r,g,b中
image_m = np.reshape(data[line_number], (3, 32, 32))
image_label = label[line_number]
r = image_m[0, :, :]
g = image_m[1, :, :]
b = image_m[2, :, :]
# 通过将r,g,b合并来输出原图
img32 = np.array(cv.merge([r, g, b]))
plt.figure()
plt.imshow(img32)
plt.show()
制作想要写入的数据
在这里我只做了一个图片 img32_compress,并且也是一个三通道的图片,并且将每个通道转换为一维数组并存入temp,因为在cifar10中,图片是一维存储的
temp_r = np.reshape(img32_compress[:, :, 0], (1024, )).tolist()
temp_g = np.reshape(img32_compress[:, :, 1], (1024, )).tolist()
temp_b = np.reshape(img32_compress[:, :, 2], (1024, )).tolist()
然后我们将这些一维数组存放到想要修改的图片的对应位置,args.line_number是指cifar10数据集字典中data的行数,在data中,每一行代表一张图片,第几行就是第几张图片
dict.get("data")[args.line_number,0:1024] = temp_r
dict.get("data")[args.line_number,1024:2048] = temp_g
dict.get("data")[args.line_number,2048:3072] = temp_b
最后在保存成二进制文件时,np.array这个type是不被支持的,因此我们需要使用 .tolist() 把data从 np.array 改为 list
最后使用pickle.dump将我们制作的新图片dict写入原本的文件f1
dict['data'] = dict['data'].tolist()
f1 = open(file, 'wb+')
pickle.dump(dict, f1)
# f1.write(json.dumps(dict).encode())
f1.close()