tf.map_fn( )的用法

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u012193416/article/details/86565633
map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
           swap_memory=False, infer_shape=True, name=None)

其中 fn 是一个可调用函数,可以使用 lambda 来表示,elems 是需要处理的 tensors, tf 将会从第一维开始展开,进行 map 操作,dtype 表示 fn 函数的输出类型,如果 fn 返回的类型和 elems 中的不同,那么就必须显示指定为和 fn 返回类型相同的类型。

可以看出 map_fn 是一个反复将可调用函数fn应用于 elems 元素序列的一个高阶函数。

有很多用处

在处理图片时,是(batch_size,height,width,depth),batch_size是一次处理的多少,一个 batch 内同样对图片进行处理,对视频进行卷积操作时,视频输入是(batch_size,frames,height,width,depth),其中多了个 frames 帧数,肯定是不能对视频进行卷积的,视频的每个切片产生后,我们同样是对每一帧进行卷积,所以采用map_fn 函数,对每个切片应用卷积操作,每个batch 之间没有关联,可以并行快速的处理。

tf.map_fn(fn=lambda x:tf.nn.conv2d(x,kernel,stride,padding='same'),elems=batch,dtype=tf.float32)

猜你喜欢

转载自blog.csdn.net/u012193416/article/details/86565633