从tf.keras模型的中间层开始输入的方法

目的

想要分割一个模型,直接使用 tf.keras.Model(inputs=……, outputs=……)方法只能从模型本身设定的输入作为新模型的输入。

如果想要从中间层开始输入,封装一个新的模型需要比较麻烦的操作。

实现原理

手动设定一个输入,然后函数式的调用每一个layer,最后调用 tf.keras.Model(inputs=……, outputs=……)封装模型即可

例子

我使用的是tensorflow2.4。

以不带SE模块的mobilenetv3为例,我们需要保存每一层的输出和他们的名称,以解决res模块的add有两个输入的问题。

这里,我们从第一个add模块之后开始输入,具体代码如下:

import tensorflow as tf
# 原模型
mobilenetv3_model = tf.keras.applications.MobileNetV3Large(include_top=False, alpha=1.0, minimalistic=True, input_shape=None, weights=weights)
converter = tf.lite.TFLiteConverter.from_keras_model(mobilenetv3_model)
tflite_model = converter.convert()
open("./org_model.tflite", "wb").write(tflite_model)

# 新模型
test_model = split_model_by_layer_name(mobilenetv3_model, 'expanded_conv_1/expand')
converter = tf.lite.TFLiteConverter.from_keras_model(test_model)
tflite_model = converter.convert()
open("./new_model.tflite", "wb").write(tflite_model)

def split_model_by_layer_name(model, name):
    layers = model.layers
    layers_output_saver = dict()
    input = tf.keras.layers.Input(shape=(224, 224, model.get_layer(name).input.shape[-1]))
    output = None
    start = False
    i = 0
    
    for layer in layers:
        if layer.name == name:
            start = True
        if not start:
            continue
        layer_input = []
        if isinstance(layer.input, list) and len(layer.input)>1:  # 对于有多输入的节点
            for input_temp in layer.input:
                layer_input.append(layers_output_saver.get(input_temp.name))
        else:
            if i == 0:
                layer_input = input
            else:
                layer_input = output
        output = layer(layer_input)
        layers_output_saver[layer.output.name] = output
        i += 1

    new_model = tf.keras.Model(inputs=input, outputs=output)
    return new_model

可以看到保存下来的模型与预想是相同的。

结论

实现原理是非常简单的,只要按模型的结构再调用一次原模型的layer即可。不同的模型实现方法会不一样,掌握原理最重要。

猜你喜欢

转载自blog.csdn.net/qq_19313495/article/details/113885848