modules.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List


# ----------------- CNN modules -----------------
class ConvModule(nn.Module):
    """Basic conv block: Conv2d -> BatchNorm2d -> SiLU (optional)."""
    def __init__(self,
                 in_dim,                # in channels
                 out_dim,               # out channels
                 kernel_size=1,         # kernel size
                 stride=1,              # stride
                 groups=1,              # groups
                 use_act: bool = True,
                 ):
        super(ConvModule, self).__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size, padding=kernel_size // 2, stride=stride, groups=groups, bias=False)
        self.norm = nn.BatchNorm2d(out_dim)
        self.act = nn.SiLU(inplace=True) if use_act else nn.Identity()

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))
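

# Usage sketch (added; the dummy 1x3x64x64 input is an assumption, not from the original file):
# with stride=1, ConvModule pads by kernel_size // 2 and keeps the spatial size.
if __name__ == "__main__":
    _conv = ConvModule(3, 16, kernel_size=3, stride=1)
    print(_conv(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 16, 64, 64])
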
class Bottleneck(nn.Module):
    """Two stacked convolutions with an optional residual connection."""
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 kernel_size: List = [3, 3],
                 shortcut: bool = False,
                 expansion: float = 0.5,
                 ) -> None:
        super(Bottleneck, self).__init__()
        # ----------------- Network setting -----------------
        inter_dim = int(out_dim * expansion)
        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=kernel_size[0], stride=1)
        self.cv2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], stride=1)
        self.shortcut = shortcut and in_dim == out_dim

    def forward(self, x):
        h = self.cv2(self.cv1(x))
        return x + h if self.shortcut else h
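

# Usage sketch (added; the 16-channel dummy input is an assumption): with shortcut=True and
# in_dim == out_dim, the block adds a residual connection around the two convolutions.
if __name__ == "__main__":
    _b = Bottleneck(16, 16, kernel_size=[3, 3], shortcut=True)
    print(_b(torch.randn(1, 16, 32, 32)).shape)  # torch.Size([1, 16, 32, 32])
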
class C3kBlock(nn.Module):
    """CSP-style block: two 1x1 branches, one running stacked Bottlenecks, fused by a 1x1 conv."""
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 num_blocks: int = 1,
                 shortcut: bool = True,
                 expansion: float = 0.5,
                 ):
        super().__init__()
        inter_dim = int(out_dim * expansion)  # hidden channels
        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
        self.cv2 = ConvModule(in_dim, inter_dim, kernel_size=1)
        self.cv3 = ConvModule(2 * inter_dim, out_dim, kernel_size=1)
        self.m = nn.Sequential(*[
            Bottleneck(in_dim      = inter_dim,
                       out_dim     = inter_dim,
                       kernel_size = [3, 3],
                       shortcut    = shortcut,
                       expansion   = 1.0,
                       ) for _ in range(num_blocks)])

    def forward(self, x):
        return self.cv3(torch.cat([self.m(self.cv1(x)), self.cv2(x)], dim=1))
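

# Usage sketch (added; channel counts are assumptions): half the channels go through the
# Bottleneck stack, the other half bypass it, and cv3 fuses the concatenation back to out_dim.
if __name__ == "__main__":
    _c3k = C3kBlock(64, 64, num_blocks=2, shortcut=True)
    print(_c3k(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])
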
class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast: three cascaded max-pools concatenated with the input."""
    def __init__(self, in_dim, out_dim, spp_pooling_size: int = 5, neck_expand_ratio: float = 0.5):
        super().__init__()
        ## ----------- Basic Parameters -----------
        inter_dim = round(in_dim * neck_expand_ratio)
        self.out_dim = out_dim
        ## ----------- Network Parameters -----------
        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1, stride=1)
        self.cv2 = ConvModule(inter_dim * 4, out_dim, kernel_size=1, stride=1)
        self.m = nn.MaxPool2d(kernel_size=spp_pooling_size, stride=1, padding=spp_pooling_size // 2)

    def forward(self, x):
        x = self.cv1(x)
        y1 = self.m(x)
        y2 = self.m(y1)
        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
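

# Usage sketch (added; 256 channels and the 20x20 resolution are assumptions): the cascaded
# 5x5 max-pools keep the spatial size, so only the channel count changes.
if __name__ == "__main__":
    _sppf = SPPF(256, 256, spp_pooling_size=5)
    print(_sppf(torch.randn(1, 256, 20, 20)).shape)  # torch.Size([1, 256, 20, 20])
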
# ----------------- Attention modules -----------------
class Attention(nn.Module):
    """Multi-head self-attention over spatial positions with a depthwise-conv positional encoding."""
    def __init__(self, dim, num_heads=8, attn_ratio=0.5):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.key_dim = int(self.head_dim * attn_ratio)
        self.scale = self.key_dim ** -0.5
        nh_kd = self.key_dim * num_heads
        h = dim + nh_kd * 2
        self.qkv = ConvModule(dim, h, kernel_size=1, use_act=False)
        self.proj = ConvModule(dim, dim, kernel_size=1, use_act=False)
        self.pe = ConvModule(dim, dim, kernel_size=3, groups=dim, use_act=False)

    def forward(self, x):
        bs, c, h, w = x.shape
        seq_len = h * w
        qkv = self.qkv(x)
        q, k, v = qkv.view(bs, self.num_heads, self.key_dim * 2 + self.head_dim, seq_len).split(
            [self.key_dim, self.key_dim, self.head_dim], dim=2
        )
        attn = (q.transpose(-2, -1) @ k) * self.scale
        attn = attn.softmax(dim=-1)
        x = (v @ attn.transpose(-2, -1)).view(bs, c, h, w) + self.pe(v.reshape(bs, c, h, w))
        x = self.proj(x)
        return x
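

# Usage sketch (added; dim=64 and the 16x16 feature map are assumptions): dim must be divisible
# by num_heads; the output keeps the input shape.
if __name__ == "__main__":
    _attn = Attention(dim=64, num_heads=8)
    print(_attn(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 64, 16, 16])
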
class PSABlock(nn.Module):
    """Attention + pointwise feed-forward block with optional residual connections."""
    def __init__(self, in_dim, attn_ratio=0.5, num_heads=4, shortcut=True):
        super().__init__()
        self.attn = Attention(in_dim, attn_ratio=attn_ratio, num_heads=num_heads)
        self.ffn = nn.Sequential(ConvModule(in_dim, in_dim * 2, kernel_size=1),
                                 ConvModule(in_dim * 2, in_dim, kernel_size=1, use_act=False))
        self.add = shortcut

    def forward(self, x):
        x = x + self.attn(x) if self.add else self.attn(x)
        x = x + self.ffn(x) if self.add else self.ffn(x)
        return x
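

# Usage sketch (added; dim=64 and num_heads=4 are assumptions): attention and FFN are both
# residual when shortcut=True, so the shape is unchanged.
if __name__ == "__main__":
    _psa = PSABlock(64, num_heads=4)
    print(_psa(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 64, 16, 16])
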
class C2PSA(nn.Module):
    """CSP-style wrapper that applies stacked PSABlocks to half of the split features."""
    def __init__(self, in_dim, out_dim, num_blocks=1, expansion=0.5):
        super().__init__()
        assert in_dim == out_dim
        inter_dim = int(in_dim * expansion)
        self.cv1 = ConvModule(in_dim, 2 * inter_dim, kernel_size=1)
        self.cv2 = ConvModule(2 * inter_dim, in_dim, kernel_size=1)
        self.m = nn.Sequential(*[
            PSABlock(in_dim     = inter_dim,
                     attn_ratio = 0.5,
                     num_heads  = inter_dim // 64,
                     ) for _ in range(num_blocks)])

    def forward(self, x):
        x1, x2 = torch.chunk(self.cv1(x), chunks=2, dim=1)
        x2 = self.m(x2)
        return self.cv2(torch.cat([x1, x2], dim=1))
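

# Usage sketch (added; in_dim=128 is an assumption): num_heads is derived as inter_dim // 64,
# so in_dim * expansion should be at least 64 (here inter_dim=64 gives 1 head).
if __name__ == "__main__":
    _c2psa = C2PSA(128, 128, num_blocks=1)
    print(_c2psa(torch.randn(1, 128, 16, 16)).shape)  # torch.Size([1, 128, 16, 16])
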
# ----------------- YOLO11 components -----------------
class C3k2fBlock(nn.Module):
    """C3k2-style block: split the input, run stacked C3k (or Bottleneck) stages, concatenate, project."""
    def __init__(self, in_dim, out_dim, num_blocks=1, use_c3k=True, expansion=0.5, shortcut=True):
        super().__init__()
        inter_dim = int(out_dim * expansion)  # hidden channels
        self.cv1 = ConvModule(in_dim, 2 * inter_dim, kernel_size=1)
        self.cv2 = ConvModule((2 + num_blocks) * inter_dim, out_dim, kernel_size=1)
        if use_c3k:
            self.m = nn.ModuleList(
                C3kBlock(inter_dim, inter_dim, 2, shortcut)
                for _ in range(num_blocks)
            )
        else:
            self.m = nn.ModuleList(
                Bottleneck(inter_dim, inter_dim, [3, 3], shortcut, expansion=0.5)
                for _ in range(num_blocks)
            )

    def _forward_impl(self, x):
        # Input proj
        x1, x2 = torch.chunk(self.cv1(x), 2, dim=1)
        out = list([x1, x2])
        # Bottleneck stages, each consuming the previous stage's output
        out.extend(m(out[-1]) for m in self.m)
        # Output proj
        out = self.cv2(torch.cat(out, dim=1))
        return out

    def forward(self, x):
        return self._forward_impl(x)
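

# Usage sketch (added; 64 channels are an assumption): with use_c3k=True each stage is a
# 2-block C3kBlock; all intermediate stages are concatenated before the output projection.
if __name__ == "__main__":
    _c3k2 = C3k2fBlock(64, 64, num_blocks=2, use_c3k=True)
    print(_c3k2(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])
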
class DflLayer(nn.Module):
    def __init__(self, reg_max=16):
        """Distribution Focal Loss (DFL) decoder: projects per-side distributions onto expected distances."""
        super().__init__()
        self.reg_max = reg_max
        proj_init = torch.arange(reg_max, dtype=torch.float)
        self.proj_weight = nn.Parameter(proj_init.view([1, reg_max, 1, 1]), requires_grad=False)

    def forward(self, pred_reg, anchor, stride):
        bs, hw = pred_reg.shape[:2]
        # [bs, hw, 4*rm] -> [bs, 4*rm, hw] -> [bs, 4, rm, hw]
        pred_reg = pred_reg.permute(0, 2, 1).reshape(bs, 4, -1, hw)
        # [bs, 4, rm, hw] -> [bs, rm, 4, hw]
        pred_reg = pred_reg.permute(0, 2, 1, 3).contiguous()
        # Expected value of each side's distribution: [bs, rm, 4, hw] -> [bs, 1, 4, hw]
        delta_pred = F.conv2d(F.softmax(pred_reg, dim=1), self.proj_weight)
        # [bs, 1, 4, hw] -> [bs, 4, hw] -> [bs, hw, 4]
        delta_pred = delta_pred.view(bs, 4, hw).permute(0, 2, 1).contiguous()
        delta_pred *= stride
        # Decode bbox: ltrb distances from the anchor point -> xyxy
        x1y1_pred = anchor - delta_pred[..., :2]
        x2y2_pred = anchor + delta_pred[..., 2:]
        box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
        return box_pred
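

# Usage sketch (added; the 8x8 grid, stride=8, and cell-center anchors are assumptions):
# pred_reg is [bs, hw, 4*reg_max]; the layer returns decoded xyxy boxes of shape [bs, hw, 4].
if __name__ == "__main__":
    _dfl = DflLayer(reg_max=16)
    _ys, _xs = torch.meshgrid(torch.arange(8), torch.arange(8), indexing="ij")
    _anchor = (torch.stack([_xs, _ys], dim=-1).reshape(-1, 2).float() + 0.5) * 8  # [64, 2]
    _pred_reg = torch.randn(1, 8 * 8, 4 * 16)
    print(_dfl(_pred_reg, _anchor, stride=8).shape)  # torch.Size([1, 64, 4])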