Tensorpack.MultiProcessPrefetchData改进,实现高效的数据流水线

参考代码:https://github.com/tensorpack/tensorpack/blob/master/tensorpack/dataflow/parallel.py(目前最新版本已经更名为MultiProcessRunner,在最早的版本叫做MultiProcessPrefetchData)

  Tensorpack的数据流水线有多个,其中一个比较好实现的是MultiProcessRunner这个类,思路很简单,利用multiprocess.Queue队列,启动若干线程向队列push元素,然后在__iter__方法中从队列中拿元素.这个MultiProcessRunner处理方式,也在cifar10_resnet中被使用:https://github.com/tensorpack/tensorpack/blob/master/examples/ResNet/cifar10-resnet.py。但是假若图片的尺寸变大(cifar10尺寸是32*32,一个batch取128,图像总的数据量为:128*32*32*3*4bytes=1.5M,imagenet数据大小为224,batch取32的话,数据量为:32*224*224*3*4bytes=18.375M)受限于python多进程队列的实现方式(pipeline),取数据会变得非常慢,从0.02ms变为大概20几ms(通常数据越大网络运行一次的时间越大,相比之下获取数据的时间就基本上可以忽略不计).Tensorpack中imagenet的例子使用的MultiProcessRunnerZMQ,使用ZMQ替换Queue实现跨进程传输.

  其实只需要在MultiProcessRunner的基础上稍加改进,就能实现同样0.2ms的数据流水线功能,参考代码:https://github.com/WeiTang114/FMQ,即fast MultiProcess.Queue,原理也很简单:利用python的Queue模块,这个队列不同于multiprocess.Queue,属于本地队列,不用跨进程传输。因此即使是取很大的数据,时间也很短.所以可以再开一个队列和线程,不断的从进程队列中拿元素然后再放入本地队列中,__iter__中直接从本地队列拿元素.

猜你喜欢

转载自www.cnblogs.com/deepllz/p/11433252.html