modules.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import torch
  2. import torch.nn as nn
  3. from typing import List
  4. # --------------------- Basic modules ---------------------
  5. class BasicConv(nn.Module):
  6. def __init__(self,
  7. in_dim, # in channels
  8. out_dim, # out channels
  9. kernel_size=1, # kernel size
  10. padding=0, # padding
  11. stride=1, # padding
  12. dilation=1, # dilation
  13. ):
  14. super(BasicConv, self).__init__()
  15. self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
  16. self.norm = nn.BatchNorm2d(out_dim)
  17. self.act = nn.LeakyReLU(0.1, inplace=True)
  18. def forward(self, x):
  19. return self.act(self.norm(self.conv(x)))
  20. # --------------------- ResNet modules ---------------------
  21. def conv3x3(in_planes, out_planes, stride=1):
  22. """3x3 convolution with padding"""
  23. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  24. padding=1, bias=False)
  25. def conv1x1(in_planes, out_planes, stride=1):
  26. """1x1 convolution"""
  27. return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
  28. class BasicBlock(nn.Module):
  29. expansion = 1
  30. def __init__(self, inplanes, planes, stride=1, downsample=None):
  31. super(BasicBlock, self).__init__()
  32. self.conv1 = conv3x3(inplanes, planes, stride)
  33. self.bn1 = nn.BatchNorm2d(planes)
  34. self.relu = nn.ReLU(inplace=True)
  35. self.conv2 = conv3x3(planes, planes)
  36. self.bn2 = nn.BatchNorm2d(planes)
  37. self.downsample = downsample
  38. self.stride = stride
  39. def forward(self, x):
  40. identity = x
  41. out = self.conv1(x)
  42. out = self.bn1(out)
  43. out = self.relu(out)
  44. out = self.conv2(out)
  45. out = self.bn2(out)
  46. if self.downsample is not None:
  47. identity = self.downsample(x)
  48. out += identity
  49. out = self.relu(out)
  50. return out
  51. class Bottleneck(nn.Module):
  52. expansion = 4
  53. def __init__(self, inplanes, planes, stride=1, downsample=None):
  54. super(Bottleneck, self).__init__()
  55. self.conv1 = conv1x1(inplanes, planes)
  56. self.bn1 = nn.BatchNorm2d(planes)
  57. self.conv2 = conv3x3(planes, planes, stride)
  58. self.bn2 = nn.BatchNorm2d(planes)
  59. self.conv3 = conv1x1(planes, planes * self.expansion)
  60. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  61. self.relu = nn.ReLU(inplace=True)
  62. self.downsample = downsample
  63. self.stride = stride
  64. def forward(self, x):
  65. identity = x
  66. out = self.conv1(x)
  67. out = self.bn1(out)
  68. out = self.relu(out)
  69. out = self.conv2(out)
  70. out = self.bn2(out)
  71. out = self.relu(out)
  72. out = self.conv3(out)
  73. out = self.bn3(out)
  74. if self.downsample is not None:
  75. identity = self.downsample(x)
  76. out += identity
  77. out = self.relu(out)
  78. return out