keras自定义网络层_(源码解读)

版权声明:转载请务必注明出处并注明“武汉AI算法研习” https://blog.csdn.net/qq_36931982/article/details/90369187

keras是基于Tensorflow等的一个神经网络的上层框架,通过Keras我们可以简单的构造出自己的神经网络,同时Keras针对主流的算法框架、激活函数和优化函数等进行自己的实现,某些方面只需要我们进行简单的调用,Keras的出现大大简化了网络构建的成本。

Keras自定义网络层需要一下步骤:

1、继承一个Layer

keras顶级Layer类定义在engine包的base_layer.py文件,其中的class Layer(object)类定义了基本的关于Layer的方法和变量:这定义的方法变量众多,随着Keras版本的更新越来越满足我们特殊的需求。

class Layer(object):
#抽象类主要方法总结

  #为layer增加权重
  def add_weight()
  
  #对输入数据和本Layer定义的数据进行效验,如输入数据不符合定义规定报错
  def assert_input_compatibility(self, inputs):

  #训练过程中前向和后向传播的主要逻辑实现
  def call()

  #计算本层的输出形状
  def compute_output_shape()

  #计算输出的掩膜
  def compute_mask()

  #构造layer的权重
  def build()

  #检索给定节点上的层的输入\出形状
  def get_input_shape_at()
  def get_output_shape_at()

  #检索给定节点上的层的输入/出张量
  def get_input_at():
  def get_output_at():

  #检索给定节点上的层的输入/出掩模张量。
  def get_input_mask_at(self, node_index):
  def get_output_mask_at(self, node_index):

以Keras里面全连接层为例,Dense继承自Layer类,主要实现了build()、call()、compute_out_shaper()和get_config()方法。

class Dense(Layer):
  def build(self, input_shape):
  def call(self, inputs):
  def compute_output_shape(self, input_shape):
  def get_config(self):
  ...

2、重写其中的方法

keras自定义网络layer,主要根据自己网络的需要,对其父类的部分方法进行重写,当然如果有特殊的需要也可以对其父类集合Tensorflow进行改写。

build(input_shape):

通过build()方法定义自己layer的权重,此方法最后必须实现变量self.build = True。实现过程中我们可以对权重值进行约束和初始化或者正则化,分别调用self.kernel_constraint、self.kernel_initializer和self.kernel_regularizer可以实现。

call(x):

通过call()方法进行功能逻辑实现,是该层的计算逻辑或计算图。显然,这个层的核心应该是一段符号式的输入张量到输出张量的计算过程。

class BinaryConv2D(Conv2D):

    #定义构造方法
    def __init__(self, filters, kernel_lr_multiplier='Glorot', bias_lr_multiplier=None, H=1., **kwargs):
        super(BinaryConv2D, self).__init__(filters, **kwargs)
        self.H = H
        self.kernel_lr_multiplier = kernel_lr_multiplier
        self.bias_lr_multiplier = bias_lr_multiplier
        
    #这是你定义权重的地方。input_shape=(28,28,1)
    def build(self, input_shape):
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1 
        if input_shape[channel_axis] is None:
                raise ValueError('The channel dimension of the inputs '
                                 'should be defined. Found `None`.')

        input_dim = input_shape[channel_axis] #获取channel
        kernel_shape = self.kernel_size + (input_dim, self.filters) #*(3,3,1,128)
            
        base = self.kernel_size[0] * self.kernel_size[1] #9
        if self.H == 'Glorot':
            nb_input = int(input_dim * base)
            nb_output = int(self.filters * base)
            self.H = np.float32(np.sqrt(1.5 / (nb_input + nb_output)))
            #print('Glorot H: {}'.format(self.H))
            
        if self.kernel_lr_multiplier == 'Glorot':
            nb_input = int(input_dim * base) #9
            nb_output = int(self.filters * base) #1152
            self.kernel_lr_multiplier = np.float32(1. / np.sqrt(1.5/ (nb_input + nb_output)))
            #print('Glorot learning rate multiplier: {}'.format(self.lr_multiplier))

        self.kernel_constraint = Clip(-self.H, self.H) #对主权重矩阵进行约束
        self.kernel_initializer = initializers.RandomUniform(-self.H, self.H)
        self.kernel = self.add_weight(shape=kernel_shape,
                                 initializer=self.kernel_initializer,
                                 name='kernel',
                                 regularizer=self.kernel_regularizer,
                                 constraint=self.kernel_constraint)

        if self.use_bias:
            self.lr_multipliers = [self.kernel_lr_multiplier, self.bias_lr_multiplier]
            self.bias = self.add_weight((self.output_dim,),
                                     initializer=self.bias_initializers,
                                     name='bias',
                                     regularizer=self.bias_regularizer,
                                     constraint=self.bias_constraint)

        else:
            self.lr_multipliers = [self.kernel_lr_multiplier]
            self.bias = None

        # Set input spec.
        self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
        self.built = True

    #这里是编写层的功能逻辑的地方。
    def call(self, inputs):
        binary_kernel = binarize(self.kernel, H=self.H) 
        outputs = K.conv2d(
            inputs,
            binary_kernel,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)

        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

      def get_config(self):
        config = {'H': self.H,
                  'kernel_lr_multiplier': self.kernel_lr_multiplier,
                  'bias_lr_multiplier': self.bias_lr_multiplier}
        base_config = super(BinaryConv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
  

猜你喜欢

转载自blog.csdn.net/qq_36931982/article/details/90369187