1.论文阅读 Non-local Neural Networks
1.问题来源
CNN中的 convolution单元每次只关注邻域 kernel size 的区域,就算后期感受野越来越大,终究还是局部区域的运算,这样就忽略了全局其他片区(比如很远的像素)对当前区域的贡献。
2.主要思想
non-local block 要做的是捕获这种 long-range 关系：对于 2D 图像，就是图像中任意像素与当前像素之间的关系权值；对于 3D 视频，就是所有帧中的所有像素与当前帧像素之间的关系权值。
3.具体实现
- 关系函数
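$f(x_i, x_j)$ 为关系函数，用于对位置 $i$ 与位置 $j$ 两者之间的关系进行建模。通用的 non-local 操作为 $y_i = \frac{1}{\mathcal{C}(x)} \sum_{\forall j} f(x_i, x_j)\, g(x_j)$，其中 $g(x_j)$ 是位置 $j$ 的特征表示（实现中为 1×1 卷积），$\mathcal{C}(x)$ 为归一化因子。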
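文章给出 4 种关系函数：
- Gaussian：$f(x_i, x_j) = e^{x_i^T x_j}$，$\mathcal{C}(x) = \sum_{\forall j} f(x_i, x_j)$；
- Embedded Gaussian：$f(x_i, x_j) = e^{\theta(x_i)^T \phi(x_j)}$，归一化后等价于对 $\theta(x_i)^T \phi(x_j)$ 沿 $j$ 做 softmax（下方代码采用的就是这种形式）；
- Dot product：$f(x_i, x_j) = \theta(x_i)^T \phi(x_j)$，$\mathcal{C}(x) = N$（位置总数）；
- Concatenation：$f(x_i, x_j) = \mathrm{ReLU}\big(w_f^T [\theta(x_i), \phi(x_j)]\big)$，$\mathcal{C}(x) = N$。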
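- 具体步骤（以形状为 $(b, c, t, h, w)$ 的 3D 输入 $x$ 为例，与下方代码对应）：
  1. 用三个 1×1×1 卷积 $\theta$、$\phi$、$g$ 把通道数从 $c$ 降到 $c'$（默认 $c' = \lfloor c/2 \rfloor$）。若 `sub_sample=True`，$\phi$ 和 $g$ 的输出再在 $h, w$ 上做 2×2 最大池化（时间维不池化），以减少计算量。
  2. 展平空间位置：$\theta(x)$ 变为 $(b, N, c')$，$\phi(x)$ 变为 $(b, c', M)$，$g(x)$ 变为 $(b, M, c')$。其中 $N = thw$，$M$ 为池化后的位置数（不池化时 $M = N$）。
  3. 计算 $\theta(x)\,\phi(x)$，得到形状为 $(b, N, M)$ 的关系矩阵，并沿最后一维做 softmax（即 embedded Gaussian 形式）。
  4. 用关系矩阵乘以 $g(x)$ 得到形状为 $(b, N, c')$ 的 $y$，再 reshape 回 $(b, c', t, h, w)$。
  5. 用 1×1×1 卷积 $W$（默认后接 BN）把通道恢复为 $c$，并与输入做残差相加：$z = W(y) + x$。BN（不用 BN 时则是 $W$ 本身）的权重和偏置初始化为 0，因此模块初始时是恒等映射，可以直接插入预训练网络而不改变其原有输出。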
总结:
Non-local 可以捕获视频中空间距离较远、时间跨度较长的相关物体之间的关联，即使输入的时间序列较长也能保持良好的泛化性，且该模块易于嵌入到其他网络结构中。在 Kinetics 数据集上，在 3D 网络中加入 non-local 模块可以提升 1.6 个点。
代码
参考链接:
https://github.com/yangyang12315/Action-Recognition-Module/blob/master/Non_Local.py
https://github.com/mit-han-lab/temporal-shift-module/blob/master/ops/non_local.py
(non_local.py)
import torch
from torch import nn
from torch.nn import functional as F
class _NonLocalBlockND(nn.Module):
    """Embedded-Gaussian non-local block for 1D, 2D or 3D feature maps.

    For an input ``x`` of shape ``(b, c, *spatial)`` the block computes::

        affinity = softmax(theta(x)^T @ phi(x))   # pairwise position weights
        y        = affinity @ g(x)
        z        = W(y) + x                       # residual connection

    ``theta``, ``phi`` and ``g`` are 1x1 convolutions that map ``c`` to
    ``inter_channels`` channels. ``W`` maps back to ``c``.

    :param in_channels: channel count ``c`` of the input
    :param inter_channels: channel count of the embedding. When None it is
        ``in_channels // 2``, or 1 if that would be 0.
    :param dimension: number of axes after the channel axis (1, 2 or 3)
    :param sub_sample: if True, ``g`` and ``phi`` are followed by a max-pool
        (kernel 2 for 1D, 2x2 for 2D, 1x2x2 for 3D, so time is not pooled).
        This shrinks the number of key/value positions.
    :param bn_layer: if True, ``W`` is a conv followed by BatchNorm. The
        BatchNorm (otherwise the conv itself) starts with zero weight and bias,
        so a freshly built block returns its input unchanged.
    """

    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super().__init__()
        assert dimension in [1, 2, 3]
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        if inter_channels is None:
            # Halve the channel count, but never end up with zero channels.
            inter_channels = in_channels // 2 or 1
        self.inter_channels = inter_channels

        # (conv class, pool class, pool kernel, batch-norm class) per dimension.
        # Unknown dimensions fall back to the 1D layers; that is only reachable
        # when assertions are disabled.
        layer_table = {
            1: (nn.Conv1d, nn.MaxPool1d, 2, nn.BatchNorm1d),
            2: (nn.Conv2d, nn.MaxPool2d, (2, 2), nn.BatchNorm2d),
            3: (nn.Conv3d, nn.MaxPool3d, (1, 2, 2), nn.BatchNorm3d),
        }
        conv_nd, pool_nd, pool_kernel, bn_nd = layer_table.get(dimension, layer_table[1])
        # A single pooling module, shared by g and phi below (it has no parameters).
        pool = pool_nd(kernel_size=pool_kernel)

        def pointwise(c_in, c_out):
            # 1x1(x1) convolution: mixes channels, leaves the spatial size intact.
            return conv_nd(in_channels=c_in, out_channels=c_out,
                           kernel_size=1, stride=1, padding=0)

        # Creation order g -> W -> theta -> phi is deliberate: it fixes the
        # order in which random weight initialisation consumes the RNG and the
        # order of entries in state_dict().
        self.g = pointwise(self.in_channels, self.inter_channels)
        if bn_layer:
            self.W = nn.Sequential(
                pointwise(self.inter_channels, self.in_channels),
                bn_nd(self.in_channels),
            )
            last_layer = self.W[1]
        else:
            self.W = pointwise(self.inter_channels, self.in_channels)
            last_layer = self.W
        # Zero-init the last layer of W so that W(y) == 0 and z == x at start.
        nn.init.constant_(last_layer.weight, 0)
        nn.init.constant_(last_layer.bias, 0)
        self.theta = pointwise(self.in_channels, self.inter_channels)
        self.phi = pointwise(self.in_channels, self.inter_channels)

        if sub_sample:
            self.g = nn.Sequential(self.g, pool)
            self.phi = nn.Sequential(self.phi, pool)

    def forward(self, x):
        '''
        :param x: (b, c, *spatial), e.g. (b, c, t, h, w) for the 3D block
        :return: tensor with the same shape as x
        '''
        batch = x.size(0)
        ch = self.inter_channels
        # N = number of input positions; M = N, or fewer when sub_sample pooled phi/g.
        queries = self.theta(x).view(batch, ch, -1).permute(0, 2, 1)  # (b, N, ch)
        keys = self.phi(x).view(batch, ch, -1)                        # (b, ch, M)
        values = self.g(x).view(batch, ch, -1).permute(0, 2, 1)       # (b, M, ch)
        # Every query row is normalised over the M key positions.
        affinity = F.softmax(torch.matmul(queries, keys), dim=-1)     # (b, N, M)
        y = torch.matmul(affinity, values)                            # (b, N, ch)
        y = y.permute(0, 2, 1).contiguous().view(batch, ch, *x.size()[2:])
        return self.W(y) + x
class NONLocalBlock1D(_NonLocalBlockND):
    """Non-local block for sequence features shaped (b, c, length).

    Specialisation of _NonLocalBlockND with ``dimension=1``. When
    ``sub_sample`` is True, keys/values are max-pooled by 2 along the length.
    """

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )
class NONLocalBlock2D(_NonLocalBlockND):
    """Non-local block for image features shaped (b, c, H, W).

    Specialisation of _NonLocalBlockND with ``dimension=2``. When
    ``sub_sample`` is True, keys/values are max-pooled 2x2 over (H, W).
    """

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )
# NONLocalBlock3D subclasses _NonLocalBlockND; the arguments it forwards to the
# parent constructor (notably dimension=3) are what configure the block.
class NONLocalBlock3D(_NonLocalBlockND):
    """Non-local block for video features shaped (b, c, t, h, w).

    Specialisation of _NonLocalBlockND with ``dimension=3``. When
    ``sub_sample`` is True, keys/values are max-pooled 2x2 over (h, w) only;
    the time axis is never pooled.
    """

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )
class NL3DWrapper(nn.Module):
    """Run a 2D residual block, then a 3D non-local block across its frames.

    The wrapped ``block`` sees the frames of each clip stacked along the batch
    axis, i.e. ``(n * n_segment, c, h, w)``. Its output is regrouped to
    ``(n, c, n_segment, h, w)`` so that the non-local block can relate
    positions across all frames of a clip, and is then flattened back.

    :param block: residual block to wrap. The non-local block's channel count
        is read from the block's last BatchNorm: ``bn3`` when present
        (bottleneck blocks), otherwise ``bn2``. The ``bn2`` case assumes
        BasicBlock-style blocks whose last BatchNorm is ``bn2``, as in
        torchvision's BasicBlock.
    :param n_segment: number of frames per clip. The batch size seen by
        ``forward`` must be a multiple of it.
    """

    def __init__(self, block, n_segment):
        super(NL3DWrapper, self).__init__()
        self.block = block
        # The non-local block must match the wrapped block's OUTPUT channels,
        # which differ per ResNet stage. Falling back to bn2 avoids an
        # AttributeError on blocks that have no bn3.
        last_bn = block.bn3 if hasattr(block, 'bn3') else block.bn2
        self.nl = NONLocalBlock3D(last_bn.num_features)
        self.n_segment = n_segment

    def forward(self, x):
        '''
        :param x: (n * n_segment, c, h, w) input of the wrapped block
        :return: the block's output, shape (n * n_segment, c', h', w'), after
            the non-local block has mixed information across the n_segment frames
        '''
        # Non-local is applied to the block's OUTPUT. This differs from TSM,
        # where the temporal shift is applied first and the result is then fed
        # into the block.
        x = self.block(x)
        nt, c, h, w = x.size()
        # (n*t, c, h, w) -> (n, t, c, h, w) -> (n, c, t, h, w)
        x = x.view(nt // self.n_segment, self.n_segment, c, h, w).transpose(1, 2)
        x = self.nl(x)
        # (n, c, t, h, w) -> (n, t, c, h, w) -> (n*t, c, h, w)
        x = x.transpose(1, 2).contiguous().view(nt, c, h, w)
        return x
# Non-local blocks are added only to some stages (layer2, layer3) and only to
# some residual blocks of the ResNet.
def make_non_local(net, n_segment):
    """Insert non-local blocks into ``net.layer2`` and ``net.layer3`` in place.

    Every other residual block is wrapped in NL3DWrapper: blocks 0 and 2 of
    layer2 and blocks 0, 2 and 4 of layer3 (five non-local blocks in total).
    All remaining blocks are kept unchanged. Earlier code listed the blocks
    by hand and silently dropped any block past the 4th (layer2) / 6th
    (layer3), which truncated deeper networks such as ResNet-101/152.
    For ResNet-50 the result is identical to that earlier code.

    :param net: a torchvision ResNet or an archs.small_resnet.ResNet
    :param n_segment: frames per clip, forwarded to every NL3DWrapper
    :raises NotImplementedError: for any other network type
    """
    import torchvision
    import archs

    if not (isinstance(net, torchvision.models.ResNet)
            or isinstance(net, archs.small_resnet.ResNet)):
        raise NotImplementedError

    def wrap_blocks(layer, wrap_indices):
        # Rebuild the stage keeping ALL its blocks and wrapping the chosen ones.
        # Wrappers are created in index order, as before.
        return nn.Sequential(*[
            NL3DWrapper(blk, n_segment) if idx in wrap_indices else blk
            for idx, blk in enumerate(layer)
        ])

    net.layer2 = wrap_blocks(net.layer2, (0, 2))
    net.layer3 = wrap_blocks(net.layer3, (0, 2, 4))
if __name__ == '__main__':
    # Smoke test: each block should print the same shape as its input.
    # Plain tensors are used directly; torch.autograd.Variable is deprecated.
    sub_sample = True
    bn_layer = True

    # 1D input: (batch, channels, length)
    img = torch.zeros(2, 3, 20)
    net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())

    # 2D input: (batch, channels, h, w)
    img = torch.zeros(2, 3, 20, 20)
    net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())

    # 3D input: (batch, channels, t, h, w)
    img = torch.randn(2, 3, 10, 20, 20)
    net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())

    # Usage example with ResNet-50. make_non_local is defined in this module,
    # so it is called directly instead of being re-imported from `non_local`.
    import torchvision

    # Frames sampled per video (num_frame / each video). 8 is only an example
    # value; set it to match the data. It was previously undefined (NameError).
    num_segment = 8
    ll_net = getattr(torchvision.models, 'resnet50')(True)  # positional True: pretrained
    make_non_local(ll_net, num_segment)