import torch
import torch.nn as nn
import torch
from torchsummary import summary
# from tensorboardX import summary
class CSDN_Tem(nn.Module):
    """Depthwise-separable 2-D convolution.

    A per-channel (depthwise) convolution followed by a 1x1 (pointwise)
    convolution that mixes channels and maps ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Side length of the depthwise kernel. Defaults to 3,
            the value this module originally hard-coded. Padding is
            ``kernel_size // 2``, so an odd kernel with ``stride=1`` keeps
            the spatial size unchanged.
        stride: Stride of the depthwise convolution. Defaults to 1.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        kernel_size: int = 3,
        stride: int = 1,
    ) -> None:
        super().__init__()
        # groups=in_ch gives every input channel its own filter (depthwise);
        # out_channels == in_ch, so this stage does not change channel count.
        self.depth_conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=in_ch,
        )
        # 1x1 convolution over all channels (groups=1): mixes channel
        # information and projects to out_ch without touching H/W.
        self.point_conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the depthwise convolution, then the pointwise convolution.

        Args:
            input: Tensor accepted by ``nn.Conv2d`` with ``in_ch`` channels,
                e.g. shape (N, in_ch, H, W).

        Returns:
            Tensor with ``out_ch`` channels.
        """
        # NOTE: ``input`` shadows the builtin, but the name is kept so
        # keyword calls like ``module(input=x)`` keep working.
        out = self.depth_conv(input)
        return self.point_conv(out)
# Demo: build a 16 -> 64 channel depthwise-separable block and show its
# depthwise stage.
# NOTE(review): this runs at import time; consider moving the demo (including
# the summary() call that follows) under an `if __name__ == "__main__":` guard.
conv = CSDN_Tem(in_ch=16, out_ch=64)
print(conv.depth_conv)
print(summary(conv,(16,1