Saving and loading Dataset/DataLoader objects with dill

Original article: https://mp.weixin.qq.com/s/TStrHMbgDjPIsXaPK6-VSQ

1 Introduction to dill

Both pickle and dill can serialize most Python objects and data types. pickle, however, cannot serialize lambda functions and several other kinds of objects, whereas dill can.

dill is used the same way as pickle: it provides the same dump/dumps/load/loads functions.
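A quick way to see the difference, as a minimal sketch: the standard pickle module rejects a lambda, while dill round-trips it with the same dumps/loads calls. The exception from pickle is caught broadly because its exact type and message can differ across Python versions.

```python
import pickle

import dill

# A lambda has no importable name, so pickle cannot serialize it by reference.
squared = lambda x: x**2

try:
    pickle.dumps(squared)
    pickle_ok = True
except (pickle.PicklingError, AttributeError, TypeError):
    pickle_ok = False

# dill exposes the same dumps/loads calls and serializes the lambda by value.
restored = dill.loads(dill.dumps(squared))

print("pickle can serialize the lambda:", pickle_ok)  # False
print("dill round trip result:", restored(3))         # 9
```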

Features of dill:

1) It can pickle the following standard types:

  • none, type, bool, int, long, float, complex, str, unicode,
  • tuple, list, dict, file, buffer, builtin,
  • both old and new style classes,
  • instances of old and new style classes,
  • set, frozenset, array, functions, exceptions

2) It can also pickle some less common types:

  • functions with yields, nested functions, lambdas,
  • cell, method, unboundmethod, module, code, methodwrapper,
  • dictproxy, methoddescriptor, getsetdescriptor, memberdescriptor,
  • wrapperdescriptor, xrange, slice,
  • notimplemented, ellipsis, quit
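As a small illustration of the second list (a sketch; the factory functions make_multiplier and make_counter are hypothetical examples), dill can round-trip nested functions, including a nested generator function, together with the values they close over:

```python
import dill


def make_multiplier(k):
    # multiply is a nested function that closes over k (stored in a cell)
    def multiply(x):
        return x * k
    return multiply


def make_counter(n):
    # count is a nested generator function, i.e. a "function with yields"
    def count():
        for i in range(n):
            yield i
    return count


# Round-trip both nested functions through dill; the closed-over values
# are serialized along with the functions.
triple = dill.loads(dill.dumps(make_multiplier(3)))
counter = dill.loads(dill.dumps(make_counter(4)))

print(triple(5))        # 15
print(list(counter()))  # [0, 1, 2, 3]
```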

3) The following types cannot currently be pickled:

  • frame, generator, traceback
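By contrast, a generator object (as opposed to the generator function that creates it) cannot be dumped, because it carries a live execution frame. The sketch below shows the failure and one workaround, which is an assumption on our part rather than a dill feature: materialize the remaining items and save those instead.

```python
import dill

# A generator object holds a live frame, which dill cannot serialize.
gen = (i * i for i in range(5))
next(gen)  # advance once so the generator is mid-iteration

try:
    dill.dumps(gen)
    generator_pickled = True
except Exception as exc:  # the exact exception type may vary by version
    generator_pickled = False
    print(type(exc).__name__, exc)

# Workaround (assumption): save the items the generator has not yet produced.
remaining = dill.loads(dill.dumps(list(gen)))
print(remaining)  # [1, 4, 9, 16]
```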

4) Other capabilities of dill:

  • save and load python interpreter sessions
  • save and extract the source code from functions and classes
  • interactively diagnose pickling errors
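For diagnosing pickling errors, dill ships a dill.detect module. The helper below is not that module's API; it is a hand-rolled, hypothetical sketch of the same idea: dump each attribute of an object separately to find the one that breaks serialization. The Job class is likewise a made-up example.

```python
import dill


def find_unpicklable_attrs(obj):
    """Hypothetical helper: return the names of obj's attributes that dill
    fails to serialize, to narrow down why dumping the whole object fails."""
    bad = []
    for name, value in vars(obj).items():
        try:
            dill.dumps(value)
        except Exception:
            bad.append(name)
    return bad


class Job:
    """Made-up example object with one attribute dill cannot handle."""

    def __init__(self):
        self.name = "preprocess"
        self.transform = lambda x: x + 1                    # fine for dill
        self.stream = (s.upper() for s in ["a", "b"])      # generator: fails


print(find_unpicklable_attrs(Job()))  # ['stream']
```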

The following sections demonstrate several ways to use dill.

2 Saving anonymous (lambda) functions

# !pip install dill
import dill

# Round-trip a lambda through dill: serialize to bytes, restore, then call it
squared = lambda x: x**2
print(dill.loads(dill.dumps(squared))(3))

# 9

3 Viewing source code

# Extract the source code of a function
import dill.source
print(dill.source.getsource(squared))

In IPython this raises OSError: could not extract source code. Passing the function's code object, obtained with dill.detect.code, works instead:

code = dill.source.getsource(dill.detect.code(squared))
print(code)

# squared = lambda x: x**2
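getsource reads code from the file in which an object is defined, so it works on both functions and classes that live in a .py file. A minimal sketch, using the standard-library json module purely as an arbitrary example of such a file:

```python
import json

import dill.source

# Both objects are defined in .py files of the json package, so their
# source text can be read back from disk.
func_src = dill.source.getsource(json.dumps)        # a function
class_src = dill.source.getsource(json.JSONDecoder)  # a class

print(func_src.splitlines()[0])
print(class_src.splitlines()[0])
```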

4 Saving Dataset and DataLoader objects

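Saving these objects speeds up later runs: the preprocessed dataset can be loaded from disk instead of being rebuilt each time.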

# Build the objects to save: a TensorDataset and a DataLoader
from torch.utils.data import TensorDataset, DataLoader
import torch
from sklearn.datasets import make_classification

data, target = make_classification()
# data.shape, target.shape
# ((100, 20), (100,))

batch_size=10
dataset = TensorDataset(torch.from_numpy(data), torch.from_numpy(target))
dataloader = DataLoader(dataset, shuffle=False, drop_last=True, batch_size=batch_size)

Saving the data. A first attempt passes the file path directly:

dill.dump(dataset, './dataset_save.pkl')
dill.dump(dataloader, './dataloader_save.pkl')

This raises TypeError: file must have a 'write' attribute, because dill.dump, like pickle.dump, expects a file object opened for writing, not a path.

Open each file in binary write mode ('wb') and pass the file object instead:

with open('./dataset_save.pkl','wb') as f:
    dill.dump(dataset, f)

with open('./dataloader_save.pkl','wb') as f:
    dill.dump(dataloader, f)

Loading the data (binary read mode, 'rb'):

with open('./dataset_save.pkl','rb') as f:
    dataset_save = dill.load(f)

with open('./dataloader_save.pkl','rb') as f:
    dataloader_save = dill.load(f)
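The open/dump and open/load pattern above can be wrapped in two small helpers. The names save_with_dill and load_with_dill are hypothetical; this is a sketch, not part of dill's API.

```python
import tempfile
from pathlib import Path

import dill


def save_with_dill(obj, path):
    """Hypothetical helper: serialize obj to path (binary mode) with dill."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)  # create missing folders
    with path.open("wb") as f:
        dill.dump(obj, f)


def load_with_dill(path):
    """Hypothetical helper: load an object previously saved with dill."""
    with Path(path).open("rb") as f:
        return dill.load(f)


# Usage: round-trip an object that contains a lambda through a file.
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "cache" / "obj.pkl"
    save_with_dill({"scale": 2, "fn": lambda x: x * 2}, target)
    restored = load_with_dill(target)

print(restored["scale"], restored["fn"](21))  # 2 42
```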

Comparing the first batch of the original and the restored DataLoader:

x, y = next(iter(dataloader))
x_save, y_save = next(iter(dataloader_save))
torch.equal(x, x_save), torch.equal(y, y_save)
# (True, True)
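A common reason to use dill rather than pickle for datasets is that a dataset may store a lambda, for example as a per-sample transform. The sketch below avoids a torch dependency by using a hypothetical TransformDataset class with the same __len__/__getitem__ interface; the assumption is that a torch Dataset subclass holding a lambda behaves the same way, since in both cases the lambda is pickled as an instance attribute.

```python
import pickle

import dill


class TransformDataset:
    """Hypothetical stand-in for a torch Dataset: it stores raw samples
    plus a per-sample transform, here a lambda."""

    def __init__(self, samples, transform):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.transform(self.samples[idx])


toy_dataset = TransformDataset([1, 2, 3], transform=lambda x: x * 10)

# The standard pickle module fails on the lambda attribute...
try:
    pickle.dumps(toy_dataset)
    pickle_ok = True
except (pickle.PicklingError, AttributeError, TypeError):
    pickle_ok = False

# ...while dill serializes the instance together with its lambda.
restored = dill.loads(dill.dumps(toy_dataset))
items = [restored[i] for i in range(len(restored))]
print(pickle_ok, items)  # False [10, 20, 30]
```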


Reposted from blog.csdn.net/mengjizhiyou/article/details/127251705