A Simple Implementation of a 1D Residual Network in Keras

Copyright notice: this is an original post by the author and may not be reproduced without permission. https://blog.csdn.net/baidu_36161077/article/details/81388332

For the PyTorch version of this implementation, see here.
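The network below treats a length-`sequence_len` signal as a `[sequence_len, 1, 1]` tensor and uses `Conv2D` with `[k, 1]` kernels to emulate 1D convolutions: an initial 5×1 convolution is followed by three pre-activation bottleneck residual blocks (30, 40, and 50 filters), then a `Flatten` and a `Dense` output layer, and the model is compiled with the Adam optimizer and MSE loss.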

import numpy as np
from keras import backend as K
from keras.models import Model
from keras.layers import Input, Conv2D, BatchNormalization, Activation, Flatten, Dense, Add
from keras.callbacks import ModelCheckpoint



def residual_block(filters, x, stride=1):
    # Pre-activation bottleneck block:
    # BN -> ReLU -> 1x1 conv (filters/4) -> BN -> ReLU -> 3x1 conv (filters/4) -> BN -> ReLU -> 1x1 conv (filters)
    residual = x
    out = BatchNormalization()(x)
    out1 = Activation('relu')(out)
    out = Conv2D(filters=int(filters / 4), kernel_size=[1, 1], strides=[1, 1], padding='same')(out1)
    out = BatchNormalization()(out)
    out = Activation('relu')(out)
    out = Conv2D(filters=int(filters / 4), kernel_size=[3, 1], strides=[stride, 1], padding='same')(out)
    out = BatchNormalization()(out)
    out = Activation('relu')(out)
    out = Conv2D(filters=filters, kernel_size=[1, 1], strides=[1, 1], padding='same')(out)
    # Project the shortcut when the channel count or the stride changes,
    # so both branches have matching shapes at the addition.
    if K.int_shape(x)[-1] != filters or stride != 1:
        residual = Conv2D(filters=filters, kernel_size=[3, 1], strides=[stride, 1], padding='same')(out1)
    out = Add()([residual, out])
    return out


sequence_len = 100  # length of the input sequences; adjust to match your data

x = Input(shape=[sequence_len, 1, 1])  # a 1D sequence stored as a [length, 1, 1] "image"
conv1 = Conv2D(filters=30, kernel_size=[5, 1], strides=[1, 1], padding='same')(x)
bn = BatchNormalization()(conv1)
out = Activation('relu')(bn)
residual_block1 = residual_block(filters=30, x=out)
residual_block2 = residual_block(filters=40, x=residual_block1)
residual_block3 = residual_block(filters=50, x=residual_block2)
out = Flatten()(residual_block3)
out = Dense(units=sequence_len)(out)  # regress one value per input position
model = Model(x, out)
model.compile(optimizer='adam', loss='mse')
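
A quick way to sanity-check the model is to fit it on random data. Below is a minimal training sketch; the dataset size, batch size, epoch count, and checkpoint filename are placeholder assumptions, not part of the original post.

# Minimal training sketch on random dummy data; num_samples, batch_size,
# epochs, and 'best_model.h5' are placeholders -- substitute your own data.
num_samples = 1000  # hypothetical dataset size
X = np.random.randn(num_samples, sequence_len, 1, 1).astype('float32')
y = np.random.randn(num_samples, sequence_len).astype('float32')

# Save the weights that achieve the lowest validation loss.
checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True)
model.fit(X, y, batch_size=32, epochs=10, validation_split=0.1, callbacks=[checkpoint])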

