【TensorFlow系列】【九】利用tf.py_func自定义算子

本文讲述如下问题:

1. 如何定义不指定shape、可接收list(变长)数据的placeholder?

2.如何将普通python函数包装成TensorFlow算子,加入到NN网络中?

具体见代码:

import tensorflow as tf
import numpy as np

def gen_tfrecords(path=r"D:\my.tfrecords", num_boxes=2, scale=15):
    """Write one tf.train.Example with a scale and variable-length box lists.

    The record holds:
      * 'scale': an int64 list containing the single value *scale*;
      * 'xmin'/'ymin': float lists [0.0, 1.0, ..., num_boxes - 1];
      * 'xmax'/'ymax': float lists [500.0, 501.0, ..., num_boxes + 499].

    Args:
        path: destination TFRecord file (defaults to the original demo path).
        num_boxes: number of boxes, i.e. the length of each coordinate list.
        scale: value stored under the 'scale' key.
    """
    # Box i spans [i, i + 500] on both axes; coordinates are stored as floats.
    mins = [float(i) for i in range(num_boxes)]
    maxs = [float(i + 500) for i in range(num_boxes)]

    def _float_feature(values):
        # Variable-length data is stored as a list-valued feature.
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    features = {
        'scale': tf.train.Feature(int64_list=tf.train.Int64List(value=[scale])),
        'xmin': _float_feature(mins),
        'xmax': _float_feature(maxs),
        'ymin': _float_feature(mins),
        'ymax': _float_feature(maxs),
    }
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    with tf.python_io.TFRecordWriter(path) as tf_writer:
        tf_writer.write(tf_example.SerializeToString())
# Write the sample TFRecord file up front so the dataset built later in the
# script (which reads the same path) has a record to parse.
gen_tfrecords()
def parse_tf(example_proto):
    """Parse one serialized Example into (scale, xmin, xmax, ymin, ymax).

    'scale' is parsed as a fixed-length scalar int64. Each box coordinate is a
    variable-length float32 list, so it is parsed with VarLenFeature and comes
    back as a sparse tensor.
    """
    box_keys = ('xmin', 'xmax', 'ymin', 'ymax')
    # Fixed-length field first, then the variable-length (list) fields.
    spec = {'scale': tf.FixedLenFeature(shape=[], dtype=tf.int64)}
    spec.update((key, tf.VarLenFeature(tf.float32)) for key in box_keys)
    parsed = tf.parse_single_example(serialized=example_proto, features=spec)
    return (parsed['scale'],) + tuple(parsed[key] for key in box_keys)

def scale_image(scale):
    """Return (width, height) of a 10x10 base image with both sides scaled.

    Works element-wise when *scale* is an ndarray (as py_func passes it).
    """
    return 10 * scale, 10 * scale
def scale_image2(scale):
    """Return [width, height] of the scaled 10x10 image as a single ndarray.

    The pair is packed into one ndarray rather than returned as a Python
    list, so py_func treats it as one output matching a single Tout dtype.
    """
    return np.array([10 * scale, 10 * scale])
def image_s(scale):
    """Return the area of a 10x10 base image after scaling both sides."""
    side = 10 * scale
    return side * side

def calc_image_s(xmin, xmax, ymin, ymax):
    """Return the area of every box as an ndarray.

    Box i spans [xmin[i], xmax[i]] x [ymin[i], ymax[i]]. Inputs may be lists
    or ndarrays (py_func passes ndarrays). The computation is vectorized, so
    the result keeps the input dtype even when there are no boxes: float32
    inputs always give a float32 result, as the caller's Tout=tf.float32
    requires.

    Raises:
        ValueError: if the four coordinate arrays differ in shape.
    """
    xmin, xmax, ymin, ymax = (np.asarray(a) for a in (xmin, xmax, ymin, ymax))
    # Reject mismatched inputs explicitly instead of silently truncating or
    # failing with an IndexError partway through.
    if not (xmin.shape == xmax.shape == ymin.shape == ymax.shape):
        raise ValueError(
            "xmin, xmax, ymin and ymax must have the same shape, got "
            "%s, %s, %s, %s" % (xmin.shape, xmax.shape, ymin.shape, ymax.shape))
    return (xmax - xmin) * (ymax - ymin)

# Placeholders feeding the py_func ops below; scale_p is fed the parsed
# int64 'scale' value.
scale_p = tf.placeholder(dtype=tf.int64)
# Omitting `shape` leaves the placeholder's shape unconstrained (its dtype is
# still fixed to float32), so it can be fed a variable-length list / 1-D array
# of box coordinates.
x_min_p = tf.placeholder(dtype=tf.float32)
x_max_p = tf.placeholder(dtype=tf.float32)
y_min_p = tf.placeholder(dtype=tf.float32)
y_max_p = tf.placeholder(dtype=tf.float32)

# tf.py_func wraps an ordinary Python function as a TensorFlow op and returns
# tensor(s). Tout declares the dtype(s) of the function's return value(s):
# pass a single dtype for one return value, or a list of dtypes for several.
nw,nh = tf.py_func(scale_image,inp=[scale_p],Tout=[tf.int64,tf.int64])
nw_nh = tf.py_func(scale_image2,inp=[scale_p],Tout=tf.int64)
s = tf.py_func(image_s,inp=[scale_p],Tout=tf.int64)
ss= tf.py_func(calc_image_s,inp=[x_min_p,x_max_p,y_min_p,y_max_p],Tout=tf.float32)

# A py_func output is an ordinary tensor and can feed regular TF ops.
two = tf.constant(value=2,dtype=tf.float32)
s2 = tf.multiply(ss,two)

# Read back the record written by gen_tfrecords(): batches of one example,
# a single pass over the file.
dataset = tf.data.TFRecordDataset(r"D:\my.tfrecords")
dataset = dataset.map(parse_tf).batch(1).repeat(1)

iterator = dataset.make_one_shot_iterator()

next_element = iterator.get_next()
with tf.Session() as session:
    # Pull the parsed record into Python, then feed it back in through the
    # placeholders. The VarLen box fields come back as sparse values, so their
    # dense `.values` arrays are what gets fed.
    scale, xmin, xmax, ymin, ymax = session.run(fetches=next_element)
    w,h = session.run(fetches=[nw,nh],feed_dict={scale_p:scale})
    print(w,h)

    w_h = session.run(fetches=[nw_nh], feed_dict={scale_p: scale})
    print(w_h)

    s1 = session.run(fetches=[s], feed_dict={scale_p: scale})
    print(s1)

    s1 = session.run(fetches=[ss], feed_dict={x_min_p:xmin.values,x_max_p:xmax.values,y_min_p:ymin.values,y_max_p:ymax.values})
    print(s1)

    # Same box areas, doubled by the regular TF multiply op.
    s22 = session.run(fetches=[s2], feed_dict={x_min_p: xmin.values, x_max_p: xmax.values, y_min_p: ymin.values, y_max_p: ymax.values})
    print(s22)

结果如下:

[150] [150]
[array([[150],
       [150]], dtype=int64)]
[array([22500], dtype=int64)]
[array([250000., 250000.], dtype=float32)]
[array([500000., 500000.], dtype=float32)]

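定义新的op时会用到该方法。据官方文档介绍,tf.py_func 包裹的Python函数体不会被序列化到GraphDef中,因此模型保存后无法在其他环境中恢复这个op;同时该op必须与调用它的Python程序运行在同一进程中,在分布式环境下只能放在该Python程序所在的服务器上。但对于辅助op来说基本够用了,例如Faster R-CNN中anchor的生成与RPN训练时label的生成。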

tf.py_func 要求被包裹的函数以ndarray作为输入,并返回一个ndarray(有多个返回值时,返回由ndarray组成的list)。

猜你喜欢

转载自my.oschina.net/u/3800567/blog/1794223