basic.py

import torch
import torch.nn as nn


def get_activation(act_type=None):
    # Build an activation layer by name; None falls back to Identity.
    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    elif act_type == 'gelu':
        return nn.GELU()
    elif act_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError


def get_norm(norm_type, dim):
    # Build a normalization layer by name; 'GN' uses 32 groups, so `dim`
    # must be divisible by 32. None falls back to Identity.
    if norm_type == 'BN':
        return nn.BatchNorm2d(dim)
    elif norm_type == 'GN':
        return nn.GroupNorm(num_groups=32, num_channels=dim)
    elif norm_type is None:
        return nn.Identity()
    else:
        raise NotImplementedError
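
# Example usage (illustrative sketch, not part of the original file; layer
# sizes are assumed): composing a Conv-Norm-Act block from these factories.
#   block = nn.Sequential(
#       nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
#       get_norm('BN', 64),
#       get_activation('silu'),
#   )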


# ----------------- MLP modules -----------------
class MLP(nn.Module):
    # Simple multi-layer perceptron: a stack of Linear layers with ReLU
    # between the hidden layers and no activation on the final layer.
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
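
# Example usage (illustrative sketch, not part of the original file;
# the dimensions below are assumed, not taken from the source):
#   head = MLP(in_dim=256, hidden_dim=256, out_dim=4, num_layers=3)
#   out = head(torch.randn(2, 100, 256))   # -> (2, 100, 4)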


class FFN(nn.Module):
    # Transformer-style feed-forward block: Linear -> activation -> dropout
    # -> Linear, then the result is added back to the input (with dropout)
    # and normalized with LayerNorm.
    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
        super().__init__()
        self.fpn_dim = round(d_model * mlp_ratio)
        self.linear1 = nn.Linear(d_model, self.fpn_dim)
        self.activation = get_activation(act_type)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(self.fpn_dim, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm(src)
        return src
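
# Example usage (illustrative sketch, not part of the original file;
# batch and sequence sizes are assumed):
#   ffn = FFN(d_model=256, mlp_ratio=4.0, dropout=0.1, act_type='relu')
#   out = ffn(torch.randn(2, 100, 256))    # same shape as the input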


# ----------------- Basic CNN Ops -----------------
class FrozenBatchNorm2d(torch.nn.Module):
    # BatchNorm2d with frozen affine parameters and running statistics.
    # Everything is registered as a buffer, so nothing is updated by the
    # optimizer or by forward passes.
    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Drop the `num_batches_tracked` entry found in regular BatchNorm2d
        # checkpoints, since this module does not keep that buffer.
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]
        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
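

# Quick sanity check (illustrative sketch, not part of the original file):
# with its default buffers, FrozenBatchNorm2d should match a regular
# nn.BatchNorm2d with default parameters running in eval() mode.
if __name__ == "__main__":
    x = torch.randn(2, 64, 32, 32)
    frozen_bn = FrozenBatchNorm2d(64)
    ref_bn = nn.BatchNorm2d(64).eval()
    with torch.no_grad():
        assert torch.allclose(frozen_bn(x), ref_bn(x), atol=1e-5)
    print("FrozenBatchNorm2d sanity check passed.")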