modules.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
from typing import List, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
  5. # --------------------- Basic modules ---------------------
  6. class ConvModule(nn.Module):
  7. def __init__(self,
  8. in_dim,
  9. out_dim,
  10. kernel_size=1,
  11. stride=1,
  12. groups=1,
  13. use_act=True,
  14. ):
  15. super(ConvModule, self).__init__()
  16. self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=groups, bias=False)
  17. self.norm = nn.BatchNorm2d(out_dim)
  18. self.act = nn.SiLU(inplace=True) if use_act else nn.Identity()
  19. def forward(self, x):
  20. return self.act(self.norm(self.conv(x)))
  21. class YoloBottleneck(nn.Module):
  22. def __init__(self,
  23. in_dim :int,
  24. out_dim :int,
  25. kernel_size :List = [1, 3],
  26. expansion :float = 0.5,
  27. shortcut :bool = False,
  28. ):
  29. super(YoloBottleneck, self).__init__()
  30. inter_dim = int(out_dim * expansion)
  31. # ----------------- Network setting -----------------
  32. self.conv_layer1 = ConvModule(in_dim, inter_dim, kernel_size=kernel_size[0], stride=1)
  33. self.conv_layer2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], stride=1)
  34. self.shortcut = shortcut and in_dim == out_dim
  35. def forward(self, x):
  36. h = self.conv_layer2(self.conv_layer1(x))
  37. return x + h if self.shortcut else h
  38. class CIBBlock(nn.Module):
  39. def __init__(self,
  40. in_dim :int,
  41. out_dim :int,
  42. shortcut :bool = False,
  43. ) -> None:
  44. super(CIBBlock, self).__init__()
  45. # ----------------- Network setting -----------------
  46. self.cv1 = ConvModule(in_dim, in_dim, kernel_size=3, groups=in_dim)
  47. self.cv2 = ConvModule(in_dim, in_dim * 2, kernel_size=1)
  48. self.cv3 = ConvModule(in_dim * 2, in_dim * 2, kernel_size=3, groups=in_dim * 2)
  49. self.cv4 = ConvModule(in_dim * 2, out_dim, kernel_size=1)
  50. self.cv5 = ConvModule(out_dim, out_dim, kernel_size=3, groups=out_dim)
  51. self.shortcut = shortcut and in_dim == out_dim
  52. def forward(self, x):
  53. h = self.cv5(self.cv4(self.cv3(self.cv2(self.cv1(x)))))
  54. return x + h if self.shortcut else h
  55. # --------------------- Yolov10 modules ---------------------
  56. class C2fBlock(nn.Module):
  57. def __init__(self,
  58. in_dim: int,
  59. out_dim: int,
  60. expansion : float = 0.5,
  61. num_blocks : int = 1,
  62. shortcut: bool = False,
  63. use_cib: bool = False,
  64. ):
  65. super(C2fBlock, self).__init__()
  66. inter_dim = round(out_dim * expansion)
  67. self.input_proj = ConvModule(in_dim, inter_dim * 2, kernel_size=1)
  68. self.output_proj = ConvModule((2 + num_blocks) * inter_dim, out_dim, kernel_size=1)
  69. if use_cib:
  70. self.blocks = nn.ModuleList([
  71. CIBBlock(in_dim = inter_dim,
  72. out_dim = inter_dim,
  73. shortcut = shortcut,
  74. ) for _ in range(num_blocks)])
  75. else:
  76. self.blocks = nn.ModuleList([
  77. YoloBottleneck(in_dim = inter_dim,
  78. out_dim = inter_dim,
  79. kernel_size = [3, 3],
  80. expansion = 1.0,
  81. shortcut = shortcut,
  82. ) for _ in range(num_blocks)])
  83. def forward(self, x):
  84. # Input proj
  85. x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
  86. out = list([x1, x2])
  87. # Bottlenecl
  88. out.extend(m(out[-1]) for m in self.blocks)
  89. # Output proj
  90. out = self.output_proj(torch.cat(out, dim=1))
  91. return out
  92. class SCDown(nn.Module):
  93. def __init__(self, in_dim, out_dim, kernel_size: int = 3, stride: int = 2):
  94. super().__init__()
  95. self.cv1 = ConvModule(in_dim, out_dim, kernel_size=1)
  96. self.cv2 = ConvModule(out_dim, out_dim, kernel_size=kernel_size, stride=stride, groups=out_dim, use_act=False)
  97. def forward(self, x):
  98. return self.cv2(self.cv1(x))
  99. class Attention(nn.Module):
  100. def __init__(self, dim, num_heads=8, attn_ratio=0.5):
  101. super().__init__()
  102. self.num_heads = num_heads # number of the attention heads
  103. self.head_dim = dim // num_heads # per head dim of v
  104. self.key_dim = int(self.head_dim * attn_ratio) # per head dim of qk
  105. self.scale = self.key_dim**-0.5
  106. qkv_dims = dim + self.key_dim * num_heads * 2 # total dims of qkv
  107. self.qkv = ConvModule(dim, qkv_dims, kernel_size=1, use_act=False) # qkv projection
  108. self.proj = ConvModule(dim, dim, kernel_size=1, use_act=False) # output projection
  109. self.pe = ConvModule(dim, dim, kernel_size=3, groups=dim, use_act=False) # position embedding conv
  110. def forward(self, x):
  111. bs, c, h, w = x.shape
  112. seq_len = h * w
  113. qkv = self.qkv(x)
  114. # q, k -> [bs, nh, c_kdh, hw]; v -> [bs, nh, c_vh, hw]
  115. q, k, v = qkv.view(bs, self.num_heads, self.key_dim * 2 + self.head_dim, seq_len).split(
  116. [self.key_dim, self.key_dim, self.head_dim], dim=2
  117. )
  118. # [bs, nh, hw(q), c_kdh] x [bs, nh, c_kdh, hw(k)] -> [bs, nh, hw(q), hw(k)]
  119. attn = (q.transpose(-2, -1) @ k) * self.scale
  120. attn = attn.softmax(dim=-1)
  121. # [bs, nh, c_vh, hw(v)] x [bs, nh, hw(k), hw(q)] -> [bs, nh, c_vh, hw]
  122. x = (v @ attn.transpose(-2, -1)).view(bs, c, h, w) + self.pe(v.reshape(bs, c, h, w))
  123. x = self.proj(x)
  124. return x
  125. class PSABlock(nn.Module):
  126. def __init__(self, in_dim, out_dim, expansion=0.5):
  127. super().__init__()
  128. assert(in_dim == out_dim)
  129. self.inter_dim = int(in_dim * expansion)
  130. self.cv1 = ConvModule(in_dim, 2 * self.inter_dim, kernel_size=1)
  131. self.cv2 = ConvModule(2 * self.inter_dim, in_dim, kernel_size=1)
  132. self.attn = Attention(self.inter_dim, attn_ratio=0.5, num_heads=self.inter_dim // 64)
  133. self.ffn = nn.Sequential(
  134. ConvModule(self.inter_dim, self.inter_dim * 2, kernel_size=1),
  135. ConvModule(self.inter_dim * 2, self.inter_dim, kernel_size=1, use_act=False)
  136. )
  137. def forward(self, x):
  138. a, b = self.cv1(x).split((self.inter_dim, self.inter_dim), dim=1)
  139. b = b + self.attn(b)
  140. b = b + self.ffn(b)
  141. return self.cv2(torch.cat((a, b), 1))
  142. class SPPF(nn.Module):
  143. """
  144. This code referenced to https://github.com/ultralytics/yolov5
  145. """
  146. def __init__(self, in_dim, out_dim):
  147. super().__init__()
  148. ## ----------- Basic Parameters -----------
  149. inter_dim = in_dim // 2
  150. self.out_dim = out_dim
  151. ## ----------- Network Parameters -----------
  152. self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1, stride=1)
  153. self.cv2 = ConvModule(inter_dim * 4, out_dim, kernel_size=1, stride=1)
  154. self.m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
  155. # Initialize all layers
  156. self.init_weights()
  157. def init_weights(self):
  158. """Initialize the parameters."""
  159. for m in self.modules():
  160. if isinstance(m, torch.nn.Conv2d):
  161. m.reset_parameters()
  162. def forward(self, x):
  163. x = self.cv1(x)
  164. y1 = self.m(x)
  165. y2 = self.m(y1)
  166. return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
  167. class DflLayer(nn.Module):
  168. def __init__(self, reg_max=16):
  169. """Initialize a convolutional layer with a given number of input channels."""
  170. super().__init__()
  171. self.reg_max = reg_max
  172. proj_init = torch.arange(reg_max, dtype=torch.float)
  173. self.proj_weight = nn.Parameter(proj_init.view([1, reg_max, 1, 1]), requires_grad=False)
  174. def forward(self, pred_reg, anchor, stride):
  175. bs, hw = pred_reg.shape[:2]
  176. # [bs, hw, 4*rm] -> [bs, 4*rm, hw] -> [bs, 4, rm, hw]
  177. pred_reg = pred_reg.permute(0, 2, 1).reshape(bs, 4, -1, hw)
  178. # [bs, 4, rm, hw] -> [bs, rm, 4, hw]
  179. pred_reg = pred_reg.permute(0, 2, 1, 3).contiguous()
  180. # [bs, rm, 4, hw] -> [bs, 1, 4, hw]
  181. delta_pred = F.conv2d(F.softmax(pred_reg, dim=1), self.proj_weight)
  182. # [bs, 1, 4, hw] -> [bs, 4, hw] -> [bs, hw, 4]
  183. delta_pred = delta_pred.view(bs, 4, hw).permute(0, 2, 1).contiguous()
  184. delta_pred *= stride
  185. # Decode bbox: tlbr -> xyxy
  186. x1y1_pred = anchor - delta_pred[..., :2]
  187. x2y2_pred = anchor + delta_pred[..., 2:]
  188. box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
  189. return box_pred