淘先锋技术网

首页 1 2 3 4 5 6 7

PyTorch 中卷积的 padding='same'

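最近在用 PyTorch 做一个项目，项目中涉及到卷积部分。我平时较常用的框架是 TensorFlow 和 Keras，在 Keras 的卷积层中经常会使用参数 padding='same'，即使用“same”的填充方式；但在使用 PyTorch 时，我发现 PyTorch 没有这种填充方式。自己摸索了一段时间 PyTorch 的框架后，下面给出用 PyTorch 实现的 conv2d 中 padding='same' 的机制，后期会对代码进行详解。（注：PyTorch 1.9 及以后版本的 nn.Conv2d 已原生支持 padding='same'，但仅适用于 stride=1 的情况；需要带步长的 “same” 卷积时仍可使用下面的实现。）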

# Modified conv2d function that emulates TensorFlow's padding='same'.
# Code adapted from @famssa in 'https://github.com/pytorch/pytorch/issues/3867'
# and from the TensorFlow source code.

import torch.utils.data
from torch.nn import functional as F
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.functional import pad
from torch.nn.modules import Module
from torch.nn.modules.utils import _single, _pair, _triple

class _ConvNd(Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNd, self).__init__()
        if in_channels % groups != :
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != :
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if transposed:
            self.weight = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv =  / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != :
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

class Conv2d(_ConvNd):
    """2-D convolution layer whose forward pass uses TensorFlow-style 'same' padding.

    Constructor arguments mirror ``torch.nn.Conv2d``. Scalar ``kernel_size``,
    ``stride``, ``padding`` and ``dilation`` are expanded to ``(h, w)`` pairs.

    Note:
        ``padding`` is stored (and appears in ``repr``), but ``forward``
        delegates to ``conv2d_same_padding``, which does not use it. That
        function computes the padding needed for 'same' output itself.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)

    def forward(self, input):
        """Convolve ``input`` (expected shape (N, C, H, W)) with 'same' padding."""
        return conv2d_same_padding(input, self.weight, self.bias, self.stride,
                                   self.padding, self.dilation, self.groups)

# Custom conv2d, because PyTorch (before 1.9, and still for stride > 1) doesn't have a "padding='same'" option.

def _same_padding_amount(in_size, kernel_size, stride, dilation):
    """Return the total padding one spatial dimension needs for 'same' output.

    'Same' output (TensorFlow semantics) has length ``ceil(in_size / stride)``.
    The result is never negative.
    """
    effective_kernel = (kernel_size - 1) * dilation + 1
    out_size = (in_size + stride - 1) // stride  # ceil(in_size / stride)
    return max(0, (out_size - 1) * stride + effective_kernel - in_size)


def conv2d_same_padding(input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
    """2-D convolution with TensorFlow-style ``padding='same'``.

    The output spatial size is ``(ceil(H / sH), ceil(W / sW))``. When the
    padding a dimension needs is odd, the extra row/column goes on the
    bottom/right, as TensorFlow does.

    Args:
        input: tensor of shape (N, C_in, H, W).
        weight: tensor of shape (C_out, C_in // groups, kH, kW).
        bias: optional tensor of shape (C_out,).
        stride: int or (sH, sW) pair.
        padding: ignored; kept only so existing callers (e.g. ``Conv2d.forward``)
            keep working. The padding is derived from the input size, kernel
            size, stride and dilation.
        dilation: int or (dH, dW) pair.
        groups: forwarded to ``F.conv2d``.

    Returns:
        The ``F.conv2d`` result, of shape (N, C_out, ceil(H / sH), ceil(W / sW)).
    """
    stride = _pair(stride)
    dilation = _pair(dilation)
    # Height and width are padded independently, each from its own sizes.
    padding_rows = _same_padding_amount(input.size(2), weight.size(2), stride[0], dilation[0])
    padding_cols = _same_padding_amount(input.size(3), weight.size(3), stride[1], dilation[1])
    # F.conv2d pads symmetrically, so add the odd leftover pixel on the
    # bottom/right by hand first; pad()'s list is (left, right, top, bottom).
    rows_odd = padding_rows % 2
    cols_odd = padding_cols % 2
    if rows_odd or cols_odd:
        input = pad(input, [0, cols_odd, 0, rows_odd])
    return F.conv2d(input, weight, bias, stride,
                    padding=(padding_rows // 2, padding_cols // 2),
                    dilation=dilation, groups=groups)