# 淘先锋技术网
#
# 首页 1 2 3 4 5 6 7

# Copyright 2008-2021 Jacky Zong. All rights reserved.
# 殿下,
#为你战死是我至高无上的荣耀! -《离思五首》

import tensorflow as tf

from model import layers

__all__ = ['conv2d_block']


def conv2d_block(
    inputs,
    n_channels,
    kernel_size=(3, 3),
    strides=(2, 2),
    mode='SAME',
    use_batch_norm=True,
    activation='relu',
    is_training=True,
    data_format='NHWC',
    conv2d_hparams=None,
    batch_norm_hparams=None,
    name='conv2d',
    cardinality=1,
):
    """Build a convolution -> (optional) batch norm -> (optional) activation block.

    All ops and variables are created inside ``tf.variable_scope(name)``.

    Args:
        inputs: Tensor fed to the convolution (presumably 4-D and laid out
            according to ``data_format`` -- TODO confirm against callers).
        n_channels: Number of output channels of the convolution.
        kernel_size: Kernel size forwarded to ``layers.conv2d``. Only used
            when ``cardinality == 1``; the other branch uses a fixed 3x3
            filter.
        strides: Strides forwarded to the convolution in both branches.
        mode: Padding mode forwarded to ``layers.conv2d``. Only used when
            ``cardinality == 1``; the other branch always pads with
            ``'SAME'``.
        use_batch_norm: When true, ``layers.batch_norm`` is applied after the
            convolution and the ``layers.conv2d`` call is made with
            ``use_bias=False``.
        activation: ``'relu'``, ``'tanh'``, ``'linear'`` or ``None``. The last
            two leave the tensor unchanged.
        is_training: Forwarded as ``trainable`` to the convolution variables
            and as ``is_training`` to ``layers.batch_norm``.
        data_format: Data format string forwarded to every layer.
        conv2d_hparams: ``tf.contrib.training.HParams`` providing
            ``kernel_initializer`` and ``bias_initializer``. Always required,
            even though only the ``cardinality == 1`` branch reads it.
        batch_norm_hparams: ``tf.contrib.training.HParams`` providing
            ``decay``, ``epsilon``, ``scale``, ``center`` and
            ``param_initializers``. Required only when ``use_batch_norm`` is
            true.
        name: Variable scope name; also used as the prefix of the filter
            variable name in the ``cardinality != 1`` branch.
        cardinality: ``1`` selects ``layers.conv2d``. Any other value creates
            a filter whose input-channel dimension is
            ``n_channels // cardinality`` and applies it with
            ``tf.nn.conv2d``.

    Returns:
        The tensor produced by the last applied stage (convolution, batch
        norm or activation).

    Raises:
        ValueError: If ``conv2d_hparams`` is not an ``HParams`` instance, or
            if ``use_batch_norm`` is true and ``batch_norm_hparams`` is not an
            ``HParams`` instance.
        KeyError: If ``activation`` is not one of the supported values. Note
            that this is raised only after the convolution (and batch norm)
            ops have already been created.
    """

    if not isinstance(conv2d_hparams, tf.contrib.training.HParams):
        raise ValueError("The paramater `conv2d_hparams` is not of type `HParams`")

    # Batch-norm hyper-parameters are only mandatory when batch norm is used.
    if use_batch_norm and not isinstance(batch_norm_hparams, tf.contrib.training.HParams):
        # The message names the argument that actually failed validation.
        raise ValueError("The paramater `batch_norm_hparams` is not of type `HParams`")

    with tf.variable_scope(name):
        if cardinality == 1:
            net = layers.conv2d(
                inputs,
                n_channels=n_channels,
                kernel_size=kernel_size,
                strides=strides,
                padding=mode,
                data_format=data_format,
                # No bias when batch norm follows the convolution.
                use_bias=not use_batch_norm,
                trainable=is_training,
                kernel_initializer=conv2d_hparams.kernel_initializer,
                bias_initializer=conv2d_hparams.bias_initializer)
        else:
            # NOTE(review): this branch hard-codes a 3x3 kernel and 'SAME'
            # padding, ignoring `kernel_size`, `mode` and
            # `conv2d_hparams.kernel_initializer`, and never adds a bias even
            # when `use_batch_norm` is False. Left unchanged because altering
            # it would change variable shapes/initial values for existing
            # callers and checkpoints -- confirm whether this is intended.
            # NOTE(review): the filter depth assumes the input has
            # `n_channels` channels; whether `tf.nn.conv2d` performs a grouped
            # convolution for this filter shape depends on the TensorFlow
            # version/device -- TODO confirm.
            group_filter = tf.get_variable(
                name=name + 'group_filter',
                shape=[3, 3, n_channels // cardinality, n_channels],
                trainable=is_training,
                dtype=tf.float32)
            net = tf.nn.conv2d(
                inputs,
                group_filter,
                strides=strides,
                padding='SAME',
                data_format=data_format)

        if use_batch_norm:
            net = layers.batch_norm(
                net,
                decay=batch_norm_hparams.decay,
                epsilon=batch_norm_hparams.epsilon,
                scale=batch_norm_hparams.scale,
                center=batch_norm_hparams.center,
                is_training=is_training,
                data_format=data_format,
                param_initializers=batch_norm_hparams.param_initializers
            )

        if activation == 'relu':
            net = layers.relu(net, name='relu')

        elif activation == 'tanh':
            net = layers.tanh(net, name='tanh')

        # 'linear' and None both mean "no activation"; anything else is an error.
        elif activation != 'linear' and activation is not None:
            raise KeyError('Invalid activation type: `%s`' % activation)

        return net