淘先锋技术网

首页 1 2 3 4 5 6 7

Keras自定义或者重写层,需要实现三个方法:

  • build(input_shape)这里主要是是定义权重,通过self.build=True设置哪些参数参与训练,通常通过super([Layer],self).build()调用父类的build函数完成
  • call(x)编写层的功能逻辑的地方,通常只需要关注传入的第一个参数:输入张量,除非你希望你的层支持masking,这层就是输入张量到输出张量的计算过程。
  • compute_output_shape(input_shape),如果你的层更改了输入张量的形状,这层定义输出张量的维度,这让Keras能自动推断各层的形状

问题:

  • 1.初看到自定义层都会对buildinput_shape参数产生疑问,实际上,我们在输入层会指定输入的维度,在每一层也会返回输出的维度,Keras也会根据计算图自动推断。
  • 2.重写layer的时候是否需要考虑batchsize?
    Keras的layer是一个Tensor到Tensor的映射,默认batch_size是保持不变,所以我们在Reshape变换维度时也不用传入batch_size维度

参考 keras 自定义层

最后举一个conditional layer normalization的例子
基于Conditional Layer Normalization的条件文本生成

# 自定义层需要实现三个方法
class LayerNormalization(Layer):
    """(Conditional) Layer Normalization
    hidden_*系列参数仅为有条件输入时(conditional=True)使用
    hidden_units 降维的维度,用于输入的条件矩阵过大,先降维再变换
    hidden_activation 一般采用线性激活
    """
    def __init__(
        self,
        center=True,
        scale=True,
        epsilon=None,
        conditional=False,
        hidden_units=None,
        hidden_activation='linear',
        hidden_initializer='glorot_uniform',
        **kwargs
    ):
        super(LayerNormalization, self).__init__(**kwargs)
        self.center = center
        self.scale = scale
        self.conditional = conditional
        self.hidden_units = hidden_units
        self.hidden_activation = activations.get(hidden_activation)
        self.hidden_initializer = initializers.get(hidden_initializer)
        self.epsilon = epsilon or 1e-12

    def build(self, input_shape):
        super(LayerNormalization, self).build(input_shape)  # self.built=True

        if self.conditional:
            shape = (input_shape[0][-1],)
        else:
            shape = (input_shape[-1],)

        if self.center:
            self.beta = self.add_weight(
                shape=shape, initializer='zeros', name='beta'
            )
        if self.scale:
            self.gamma = self.add_weight(
                shape=shape, initializer='ones', name='gamma'
            )

        if self.conditional:

            if self.hidden_units is not None:
                # 用于降维
                self.hidden_dense = Dense(
                    units=self.hidden_units,
                    activation=self.hidden_activation,
                    use_bias=False,
                    kernel_initializer=self.hidden_initializer
                )

            if self.center:
                self.beta_dense = Dense(
                    units=shape[0], use_bias=False, kernel_initializer='zeros'
                )
            if self.scale:
                self.gamma_dense = Dense(
                    units=shape[0], use_bias=False, kernel_initializer='zeros'
                )

    def call(self, inputs):
        """如果是条件Layer Norm,则默认以list为输入,第二个是condition
        """
        if self.conditional:
            inputs, cond = inputs
            # 用于降维
            if self.hidden_units is not None:
                cond = self.hidden_dense(cond)
            # 扩充维度保证与inputs维度相同
            for _ in range(K.ndim(inputs) - K.ndim(cond)):
                cond = K.expand_dims(cond, 1)
            if self.center:
                beta = self.beta_dense(cond) + self.beta
            if self.scale:
                gamma = self.gamma_dense(cond) + self.gamma
        else:
            if self.center:
                beta = self.beta
            if self.scale:
                gamma = self.gamma

        outputs = inputs
        if self.center:
            # layer normalization 取一个batch,一列的yi'yang
            mean = K.mean(outputs, axis=-1, keepdims=True)
            outputs = outputs - mean
        if self.scale:
            variance = K.mean(K.square(outputs), axis=-1, keepdims=True)
            std = K.sqrt(variance + self.epsilon)
            outputs = outputs / std
            outputs = outputs * gamma
        if self.center:
            outputs = outputs + beta

        return outputs
    # input_shape是一个list 定义输出维度
    def compute_output_shape(self, input_shape):
        if self.conditional:
            return input_shape[0]
        else:
            return input_shape
    # 融合当前类和父类的config
    def get_config(self):
        config = {
            'center': self.center,
            'scale': self.scale,
            'epsilon': self.epsilon,
            'conditional': self.conditional,
            'hidden_units': self.hidden_units,
            'hidden_activation': activations.serialize(self.hidden_activation),
            'hidden_initializer':
                initializers.serialize(self.hidden_initializer),
        }
        base_config = super(LayerNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))