# basic.py

import math

import torch
import torch.nn as nn


# ----------------- Custom NormLayer Ops -----------------
class FrozenBatchNorm2d(torch.nn.Module):
    """BatchNorm2d with frozen affine parameters and running statistics (stored as buffers)."""
    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # drop the tracker key that plain BatchNorm2d checkpoints carry
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale

        return x * scale + bias
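
# Usage sketch (illustrative, not part of the original file): FrozenBatchNorm2d applies a
# fixed per-channel affine transform y = x * scale + bias, so with matching buffers it
# reproduces nn.BatchNorm2d (eps=1e-5) in eval mode while never updating statistics.
#
#   frozen_bn = FrozenBatchNorm2d(64)
#   x = torch.randn(2, 64, 32, 32)    # N, C, H, W
#   y = frozen_bn(x)                  # same shape as x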

class LayerNorm2D(nn.Module):
    """LayerNorm for NCHW feature maps: normalize over the channel dim in NHWC layout."""
    def __init__(self, normalized_shape, norm_layer=nn.LayerNorm):
        super().__init__()
        self.ln = norm_layer(normalized_shape) if norm_layer is not None else nn.Identity()

    def forward(self, x):
        """
        x: (N, C, H, W)
        """
        x = x.permute(0, 2, 3, 1)
        x = self.ln(x)
        x = x.permute(0, 3, 1, 2)

        return x
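
# Usage sketch (illustrative, not part of the original file): LayerNorm2D permutes an NCHW
# feature map to NHWC, applies nn.LayerNorm over the channels, then permutes back.
#
#   ln = LayerNorm2D(256)
#   x = torch.randn(2, 256, 20, 20)   # N, C, H, W
#   y = ln(x)                         # still (2, 256, 20, 20)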

# ----------------- Basic CNN Ops -----------------
def get_conv2d(c1, c2, k, p, s, g, bias=False):
    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, groups=g, bias=bias)

    return conv

def get_activation(act_type=None):
    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    elif act_type == 'gelu':
        return nn.GELU()
    elif act_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError

def get_norm(norm_type, dim):
    if norm_type == 'BN':
        return nn.BatchNorm2d(dim)
    elif norm_type == 'GN':
        return nn.GroupNorm(num_groups=32, num_channels=dim)
    elif norm_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError

class BasicConv(nn.Module):
    def __init__(self,
                 in_dim,                    # in channels
                 out_dim,                   # out channels
                 kernel_size=1,             # kernel size
                 padding=0,                 # padding
                 stride=1,                  # stride
                 act_type: str = 'lrelu',   # activation
                 norm_type: str = 'BN',     # normalization
                 ):
        super(BasicConv, self).__init__()
        add_bias = False if norm_type else True
        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
        self.norm = get_norm(norm_type, out_dim)
        self.act = get_activation(act_type)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))
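
# Usage sketch (illustrative, not part of the original file): BasicConv chains the factory
# helpers above into a Conv-Norm-Activation block; the conv bias is only enabled when no
# normalization layer is used.
#
#   conv = BasicConv(in_dim=64, out_dim=128, kernel_size=3, padding=1, stride=2,
#                    act_type='silu', norm_type='BN')
#   y = conv(torch.randn(2, 64, 64, 64))   # -> (2, 128, 32, 32)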

class UpSampleWrapper(nn.Module):
    """Upsample the last feature map to a given stride.

    `upsample_factor` is expected to be a power of two; each 2x stage is a
    ConvTranspose2d that halves the channel count. The final channel width
    is exposed as `self.out_dim`.
    """
    def __init__(self, in_dim, upsample_factor):
        super(UpSampleWrapper, self).__init__()
        # ---------- Basic parameters ----------
        self.upsample_factor = upsample_factor

        # ---------- Network parameters ----------
        if upsample_factor == 1:
            self.upsample = nn.Identity()
            self.out_dim = in_dim
        else:
            scale = int(math.log2(upsample_factor))
            dim = in_dim
            layers = []
            for _ in range(scale - 1):
                layers += [
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                    LayerNorm2D(dim // 2),
                    nn.GELU()
                ]
                dim = dim // 2
            layers += [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
            dim = dim // 2
            self.upsample = nn.Sequential(*layers)
            self.out_dim = dim

    def forward(self, x):
        x = self.upsample(x)

        return x
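
# Usage sketch (illustrative, not part of the original file): with a power-of-two factor,
# every 2x ConvTranspose2d stage doubles the spatial size and halves the channels.
#
#   up = UpSampleWrapper(in_dim=256, upsample_factor=4)
#   y = up(torch.randn(2, 256, 16, 16))   # -> (2, 64, 64, 64); up.out_dim == 64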

# ----------------- MLP modules -----------------
class MLP(nn.Module):
    """Simple multi-layer perceptron: Linear layers with ReLU between all but the last."""
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)

        return x
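
# Usage sketch (illustrative, not part of the original file): a 3-layer MLP as it might be
# used for a box-regression head; shapes below follow from the layer construction above.
#
#   head = MLP(in_dim=256, hidden_dim=256, out_dim=4, num_layers=3)
#   boxes = head(torch.randn(2, 100, 256))   # -> (2, 100, 4)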

class FFN(nn.Module):
    """Transformer-style feed-forward block with a residual connection and post-LayerNorm."""
    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
        super().__init__()
        self.fpn_dim = round(d_model * mlp_ratio)
        self.linear1 = nn.Linear(d_model, self.fpn_dim)
        self.activation = get_activation(act_type)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(self.fpn_dim, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm(src)

        return src
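
# Usage sketch (illustrative, not part of the original file): FFN is applied tokenwise,
# expanding d_model by mlp_ratio inside the block and returning the input shape.
#
#   ffn = FFN(d_model=256, mlp_ratio=4.0, dropout=0.1, act_type='relu')
#   tokens = torch.randn(2, 400, 256)   # N, num_tokens, d_model
#   out = ffn(tokens)                   # -> (2, 400, 256)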