basic.py

import torch
import torch.nn as nn


# ----------------- CNN modules -----------------
def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
    return conv


def get_activation(act_type=None):
    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    elif act_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError(f'Unknown activation type: {act_type}')


def get_norm(norm_type, dim):
    if norm_type == 'BN':
        return nn.BatchNorm2d(dim)
    elif norm_type == 'GN':
        return nn.GroupNorm(num_groups=32, num_channels=dim)
    elif norm_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError(f'Unknown norm type: {norm_type}')
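
# Illustrative sketch, not part of the original file: the factories above
# return plain nn.Module instances, so they compose directly into
# nn.Sequential. The helper name and the argument values are assumptions.
def _demo_factories():
    act = get_activation('silu')    # nn.SiLU(inplace=True)
    norm = get_norm('GN', 64)       # nn.GroupNorm(32, 64)
    conv = get_conv2d(3, 64, k=3, p=1, s=1, d=1, g=1)
    assert isinstance(act, nn.SiLU)
    assert isinstance(norm, nn.GroupNorm)
    assert isinstance(conv, nn.Conv2d)
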
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
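
# Illustrative sketch, not part of the original file: at stride 1, conv3x3
# preserves spatial size because padding equals dilation (effective kernel
# 2*dilation + 1, padding dilation). The tensor shape is an assumption.
def _demo_resnet_convs():
    x = torch.randn(1, 16, 32, 32)
    assert conv3x3(16, 32)(x).shape == (1, 32, 32, 32)
    assert conv3x3(16, 32, dilation=2)(x).shape == (1, 32, 32, 32)
    assert conv1x1(16, 32)(x).shape == (1, 32, 32, 32)
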
class FrozenBatchNorm2d(torch.nn.Module):
    """BatchNorm2d with fixed batch statistics and affine parameters."""

    def __init__(self, n):
        super().__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # drop the tracking counter saved by nn.BatchNorm2d,
        # which has no counterpart here
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
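
# Illustrative sketch, not part of the original file: FrozenBatchNorm2d
# should reproduce an eval-mode nn.BatchNorm2d with the same statistics
# (both use eps = 1e-5); loading the state dict also exercises the
# num_batches_tracked handling above. Names and shapes are assumptions.
def _check_frozen_bn():
    bn = nn.BatchNorm2d(8).eval()
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    fbn = FrozenBatchNorm2d(8)
    fbn.load_state_dict(bn.state_dict())
    x = torch.randn(2, 8, 16, 16)
    assert torch.allclose(bn(x), fbn(x), atol=1e-5)
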
class Conv(nn.Module):
    def __init__(self,
                 c1,                       # in channels
                 c2,                       # out channels
                 k=1,                      # kernel size
                 p=0,                      # padding
                 s=1,                      # stride
                 d=1,                      # dilation
                 act_type: str = 'lrelu',  # activation
                 norm_type: str = 'BN',    # normalization
                 depthwise: bool = False):
        super().__init__()
        convs = []
        # the conv bias is redundant when a normalization layer follows
        add_bias = not norm_type
        if depthwise:
            # depthwise conv
            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
            if norm_type:
                convs.append(get_norm(norm_type, c1))
            if act_type:
                convs.append(get_activation(act_type))
            # pointwise conv
            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
            if norm_type:
                convs.append(get_norm(norm_type, c2))
            if act_type:
                convs.append(get_activation(act_type))
        else:
            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
            if norm_type:
                convs.append(get_norm(norm_type, c2))
            if act_type:
                convs.append(get_activation(act_type))

        self.convs = nn.Sequential(*convs)

    def forward(self, x):
        return self.convs(x)
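
# Illustrative usage sketch, not part of the original file: with k=3, p=1,
# s=2 the block halves the spatial size, and the pointwise conv maps c1 to
# c2 channels. The shapes below are assumptions.
def _demo_conv_block():
    m = Conv(c1=32, c2=64, k=3, p=1, s=2,
             act_type='silu', norm_type='BN', depthwise=True)
    x = torch.randn(1, 32, 64, 64)
    assert m(x).shape == (1, 64, 32, 32)
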
# ----------------- Transformer modules -----------------