modules.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import torch
  2. import torch.nn as nn
  3. # --------------------- Basic modules ---------------------
  4. class ConvModule(nn.Module):
  5. def __init__(self,
  6. in_dim, # in channels
  7. out_dim, # out channels
  8. kernel_size=1, # kernel size
  9. padding=0, # padding
  10. stride=1, # padding
  11. dilation=1, # dilation
  12. ):
  13. super(ConvModule, self).__init__()
  14. self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
  15. self.norm = nn.BatchNorm2d(out_dim)
  16. self.act = nn.LeakyReLU(0.1, inplace=True)
  17. def forward(self, x):
  18. return self.act(self.norm(self.conv(x)))
  19. # --------------------- ResNet modules ---------------------
  20. def conv3x3(in_planes, out_planes, stride=1):
  21. """3x3 convolution with padding"""
  22. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  23. padding=1, bias=False)
  24. def conv1x1(in_planes, out_planes, stride=1):
  25. """1x1 convolution"""
  26. return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
  27. class BasicBlock(nn.Module):
  28. expansion = 1
  29. def __init__(self, inplanes, planes, stride=1, downsample=None):
  30. super(BasicBlock, self).__init__()
  31. self.conv1 = conv3x3(inplanes, planes, stride)
  32. self.bn1 = nn.BatchNorm2d(planes)
  33. self.relu = nn.ReLU(inplace=True)
  34. self.conv2 = conv3x3(planes, planes)
  35. self.bn2 = nn.BatchNorm2d(planes)
  36. self.downsample = downsample
  37. self.stride = stride
  38. def forward(self, x):
  39. identity = x
  40. out = self.conv1(x)
  41. out = self.bn1(out)
  42. out = self.relu(out)
  43. out = self.conv2(out)
  44. out = self.bn2(out)
  45. if self.downsample is not None:
  46. identity = self.downsample(x)
  47. out += identity
  48. out = self.relu(out)
  49. return out
  50. class Bottleneck(nn.Module):
  51. expansion = 4
  52. def __init__(self, inplanes, planes, stride=1, downsample=None):
  53. super(Bottleneck, self).__init__()
  54. self.conv1 = conv1x1(inplanes, planes)
  55. self.bn1 = nn.BatchNorm2d(planes)
  56. self.conv2 = conv3x3(planes, planes, stride)
  57. self.bn2 = nn.BatchNorm2d(planes)
  58. self.conv3 = conv1x1(planes, planes * self.expansion)
  59. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  60. self.relu = nn.ReLU(inplace=True)
  61. self.downsample = downsample
  62. self.stride = stride
  63. def forward(self, x):
  64. identity = x
  65. out = self.conv1(x)
  66. out = self.bn1(out)
  67. out = self.relu(out)
  68. out = self.conv2(out)
  69. out = self.bn2(out)
  70. out = self.relu(out)
  71. out = self.conv3(out)
  72. out = self.bn3(out)
  73. if self.downsample is not None:
  74. identity = self.downsample(x)
  75. out += identity
  76. out = self.relu(out)
  77. return out