FCN结构in Speech的Keras简单实现

本文通过Keras简单实现了一种FCN结构in Speech。Github有其他的实现代码,本文仅是通过自己的理解,对参考文献[1]中的网络进行搭建。若对其它代码有兴趣,请移步Github[2] (不清楚是否为论文作者创作的源代码)。
参考文献:

[1] Z. Ouyang, H. Yu, W. Zhu and B. Champagne, “A Fully Convolutional Neural Network for Complex Spectrogram Processing in Speech Enhancement,” ICASSP 2019 - 2019 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Brighton, United Kingdom, 2019, pp. 5756-5760, doi: 10.1109/ICASSP.2019.8683423.

[2] https://github.com/phpstorm1/SE-FCN

上代码

    # FCN (Ouyang et al., ICASSP 2019): a stack of dilated residual Conv2D
    # blocks with summed skip connections, followed by "1-D" convolutions
    # applied to a single sliced frame of the feature map.
    def FCN_ICASSP(self, freq_bins=129, epochs=100, batch_size=96,
                   validation_split=0.1):
        """Build, train and save the FCN speech-enhancement model.

        Args:
            freq_bins: number of STFT frequency bins of the input
                spectrogram (must match the STFT used to produce
                ``self.train_input``; the original snippet left this as
                an undefined name ``x`` "to be set as needed").
            epochs: number of training epochs.
            batch_size: training batch size.
            validation_split: fraction of the training data held out for
                validation (replaces the manual split the original blog
                omitted with "代码省略").

        Side effects: trains the model, prints per-epoch/total wall-clock
        times and saves the weights to 'FCN_ICASSP_model.h5'.
        """
        frame_number_one_sample = self.frame_number_one_sample

        # NOTE(review): assumed shape (samples, freq_bins,
        # frame_number_one_sample, 1) — confirm against the data pipeline.
        train_input_logstft = self.train_input
        train_output_target = self.train_output

        # Callback recording per-epoch wall-clock times and the total
        # training time.  (The original used mutable default args
        # ``logs={}``; Keras passes ``logs`` explicitly anyway.)
        class TimeHistory(keras.callbacks.Callback):
            def on_train_begin(self, logs=None):
                self.times = []
                self.totaltime = time.time()

            def on_train_end(self, logs=None):
                self.totaltime = time.time() - self.totaltime

            def on_epoch_begin(self, epoch, logs=None):
                self.epoch_time_start = time.time()

            def on_epoch_end(self, epoch, logs=None):
                self.times.append(time.time() - self.epoch_time_start)

        # Input: (frequency bins, frames, 1 channel).
        net_input = Input(shape=(freq_bins, frame_number_one_sample, 1))

        # One residual block: a dilated (5, 3) Conv2D (Conv2D emulating a
        # Conv1D along frequency; a plain Conv1D would also work) plus a
        # 1x1 residual projection.  Returns (block output, skip tensor).
        def _res_block(block_in, dilation):
            res = layers.Conv2D(48, (1, 1), strides=(1, 1),
                                dilation_rate=(1, 1))(block_in)
            conv = layers.Conv2D(48, (5, 3), strides=(1, 1),
                                 dilation_rate=(dilation, 1),
                                 padding='same')(block_in)
            act = layers.Activation('relu')(conv)
            skip = layers.Conv2D(48, (1, 1), strides=(1, 1),
                                 dilation_rate=(1, 1))(act)
            return layers.add([skip, res]), skip

        # Six blocks with exponentially growing frequency-axis dilation,
        # exactly as in the original hand-unrolled code.
        block_out = net_input
        skips = []
        for dilation in (1, 2, 4, 8, 16, 32):
            block_out, skip = _res_block(block_out, dilation)
            skips.append(skip)

        # Sum of all six skip tensors.
        skip_connection = layers.add(skips)

        # Slice frame index 6 out of the feature map (drops the frame
        # axis), then restore a singleton frame axis with Reshape.
        # (Renamed from ``slice`` to avoid shadowing the builtin.)
        def _take_frame(feature_map, index):
            return feature_map[:, :, index, :]

        slice_layers = layers.Lambda(_take_frame,
                                     output_shape=(freq_bins, 1, 48),
                                     arguments={'index': 6})(skip_connection)
        reshape_layers2 = layers.Reshape((freq_bins, 1, 48))(slice_layers)

        # "Conv1D" output head: (3, 1) kernels act only along frequency.
        x7_1d_skip = layers.Conv2D(96, (3, 1), strides=(1, 1),
                                   dilation_rate=(1, 1),
                                   padding='same')(reshape_layers2)
        x7_act = layers.Activation('relu')(x7_1d_skip)

        x8_1d_skip = layers.Conv2D(1, (3, 1), activation='sigmoid',
                                   strides=(1, 1), dilation_rate=(1, 1),
                                   padding='same')(x7_act)

        model = Model(net_input, x8_1d_skip)

        model.compile(optimizer='adam',
                      loss='binary_crossentropy')

        model.summary()

        time_callback = TimeHistory()
        # The original referenced undefined ``partial_*``/``val_*`` names
        # ("validation split code omitted"); Keras's validation_split
        # performs the same hold-out without the manual slicing.
        model.fit(train_input_logstft,
                  train_output_target,
                  epochs=epochs,
                  batch_size=batch_size,
                  callbacks=[time_callback],
                  validation_split=validation_split)
        print(time_callback.times)
        print(time_callback.totaltime)

        model.save('FCN_ICASSP_model.h5')

        print('model have train all ready')

猜你喜欢

转载自blog.csdn.net/qq_40550384/article/details/106729506
FCN