Input tensors to a Model must come from `keras.layers.Input`

Keras 遇到Input tensors to a Model must come from keras.layers.Input或者Output tensors to a Model must be the output of a TensorFlow Layer

1.Input tensors to a Model must come from keras.layers.Input

问题描述:
在搭建神经网络模型时,如下:

def SE_Inception_resnet_v2(input_x=None,input_shape=(256, 256, 3)):

    if K.backend() != 'tensorflow':
        raise RuntimeError('The Deeplabv3+ model is only available with '
                           'the TensorFlow backend.')
    if input_x is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_x):
            img_input = Input(tensor=input_x,shape=input_shape)
        else:
            img_input = input_x
            print('ftygjdcgh')
     x = Stem(img_input, scope='stem')
       。。。。。。
       。。。。。。
       。。。。。。
       
     model = Model(inputs=input_x, outputs=x)

原因是model = Model(inputs=input_x, outputs=x)中的inputs不符合keras的输入要求,应当进行转换如下:

def SE_Inception_resnet_v2(input_x=None,input_shape=(256, 256, 3)):

    if K.backend() != 'tensorflow':
        raise RuntimeError('The Deeplabv3+ model is only available with '
                           'the TensorFlow backend.')
    if input_x is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_x):
            img_input = Input(tensor=input_x,shape=input_shape)
        else:
            img_input = input_x
            print('ftygjdcgh')
     x = Stem(img_input, scope='stem')
       。。。。。。
       。。。。。。
       。。。。。。
    if input_x is not None:
        inputs = get_source_inputs(input_x)
    else:
        inputs = img_input

    model = Model(inputs=inputs, outputs=x)

需要导入包:from keras.engine.topology import get_source_inputs即可解决。

2.Output tensors to a Model must be the output of a TensorFlow Layer

这个是由于网络模型中包含有由tensorflow中API组成的网络层,如tf.reshape、tf.concatenate等等。
解决方法是用Lamda层将函数包装成keras层
例如:

def swish(x):
    return (tf.nn.swish(x))
      
 x=Conv2D(self.in_ch, (1, 1),strides=1, padding='same',use_bias=False, name=self.scope + '_pointwise',kernel_regularizer=weight_decay)(x)
 x=Lambda(swish)(x)

猜你喜欢

转载自blog.csdn.net/weixin_45582028/article/details/118251017