# modules.py
from typing import List, Sequence

import torch
import torch.nn as nn
# --------------------- Basic modules ---------------------
  5. class ConvModule(nn.Module):
  6. def __init__(self,
  7. in_dim, # in channels
  8. out_dim, # out channels
  9. kernel_size=1, # kernel size
  10. padding=0, # padding
  11. stride=1, # padding
  12. dilation=1, # dilation
  13. ):
  14. super(ConvModule, self).__init__()
  15. self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
  16. self.norm = nn.BatchNorm2d(out_dim)
  17. self.act = nn.SiLU(inplace=True)
  18. def forward(self, x):
  19. return self.act(self.norm(self.conv(x)))
# --------------------- Bottleneck / CSP modules ---------------------
  21. class YoloBottleneck(nn.Module):
  22. def __init__(self,
  23. in_dim :int,
  24. out_dim :int,
  25. kernel_size :List = [1, 3],
  26. expansion :float = 0.5,
  27. shortcut :bool = False,
  28. ) -> None:
  29. super(YoloBottleneck, self).__init__()
  30. inter_dim = int(out_dim * expansion)
  31. # ----------------- Network setting -----------------
  32. self.conv_layer1 = ConvModule(in_dim, inter_dim, kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1)
  33. self.conv_layer2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1)
  34. self.shortcut = shortcut and in_dim == out_dim
  35. def forward(self, x):
  36. h = self.conv_layer2(self.conv_layer1(x))
  37. return x + h if self.shortcut else h
  38. class CSPBlock(nn.Module):
  39. def __init__(self,
  40. in_dim,
  41. out_dim,
  42. num_blocks :int = 1,
  43. expansion :float = 0.5,
  44. shortcut :bool = False,
  45. ):
  46. super(CSPBlock, self).__init__()
  47. # ---------- Basic parameters ----------
  48. self.num_blocks = num_blocks
  49. self.expansion = expansion
  50. self.shortcut = shortcut
  51. inter_dim = round(out_dim * expansion)
  52. # ---------- Model parameters ----------
  53. self.conv_layer_1 = ConvModule(in_dim, inter_dim, kernel_size=1)
  54. self.conv_layer_2 = ConvModule(in_dim, inter_dim, kernel_size=1)
  55. self.conv_layer_3 = ConvModule(inter_dim * 2, out_dim, kernel_size=1)
  56. self.module = nn.Sequential(*[
  57. YoloBottleneck(inter_dim,
  58. inter_dim,
  59. kernel_size = [1, 3],
  60. expansion = 1.0,
  61. shortcut = shortcut,
  62. ) for _ in range(num_blocks)])
  63. def forward(self, x):
  64. x1 = self.conv_layer_1(x)
  65. x2 = self.module(self.conv_layer_2(x))
  66. out = self.conv_layer_3(torch.cat([x1, x2], dim=1))
  67. return out