modules.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
from typing import List, Sequence

import torch
import torch.nn as nn
  4. # --------------------- Basic modules ---------------------
  5. class ConvModule(nn.Module):
  6. def __init__(self,
  7. in_dim, # in channels
  8. out_dim, # out channels
  9. kernel_size=1, # kernel size
  10. padding=0, # padding
  11. stride=1, # padding
  12. dilation=1, # dilation
  13. ):
  14. super(ConvModule, self).__init__()
  15. self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
  16. self.norm = nn.BatchNorm2d(out_dim)
  17. self.act = nn.SiLU(inplace=True)
  18. def forward(self, x):
  19. return self.act(self.norm(self.conv(x)))
  20. class YoloBottleneck(nn.Module):
  21. def __init__(self,
  22. in_dim :int,
  23. out_dim :int,
  24. kernel_size :List = [1, 3],
  25. expansion :float = 0.5,
  26. shortcut :bool = False,
  27. ):
  28. super(YoloBottleneck, self).__init__()
  29. inter_dim = int(out_dim * expansion)
  30. # ----------------- Network setting -----------------
  31. self.conv_layer1 = ConvModule(in_dim, inter_dim, kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1)
  32. self.conv_layer2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1)
  33. self.shortcut = shortcut and in_dim == out_dim
  34. def forward(self, x):
  35. h = self.conv_layer2(self.conv_layer1(x))
  36. return x + h if self.shortcut else h
  37. class C2fBlock(nn.Module):
  38. def __init__(self,
  39. in_dim: int,
  40. out_dim: int,
  41. expansion : float = 0.5,
  42. num_blocks : int = 1,
  43. shortcut : bool = False,
  44. ):
  45. super(C2fBlock, self).__init__()
  46. inter_dim = round(out_dim * expansion)
  47. self.input_proj = ConvModule(in_dim, inter_dim * 2, kernel_size=1)
  48. self.output_proj = ConvModule((2 + num_blocks) * inter_dim, out_dim, kernel_size=1)
  49. self.module = nn.ModuleList([
  50. YoloBottleneck(in_dim = inter_dim,
  51. out_dim = inter_dim,
  52. kernel_size = [3, 3],
  53. expansion = 1.0,
  54. shortcut = shortcut,
  55. ) for _ in range(num_blocks)])
  56. def forward(self, x):
  57. # Input proj
  58. x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
  59. out = list([x1, x2])
  60. # Bottlenecl
  61. out.extend(m(out[-1]) for m in self.module)
  62. # Output proj
  63. out = self.output_proj(torch.cat(out, dim=1))
  64. return out