# modules.py
from typing import List, Sequence

import torch
import torch.nn as nn
  4. # --------------------- Basic modules ---------------------
  5. class ConvModule(nn.Module):
  6. def __init__(self,
  7. in_dim, # in channels
  8. out_dim, # out channels
  9. kernel_size=1, # kernel size
  10. padding=0, # padding
  11. stride=1, # padding
  12. dilation=1, # dilation
  13. ):
  14. super(ConvModule, self).__init__()
  15. self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
  16. self.norm = nn.BatchNorm2d(out_dim)
  17. self.act = nn.SiLU(inplace=True)
  18. def forward(self, x):
  19. return self.act(self.norm(self.conv(x)))
  20. class YoloBottleneck(nn.Module):
  21. def __init__(self,
  22. in_dim :int,
  23. out_dim :int,
  24. kernel_size :List = [1, 3],
  25. expansion :float = 0.5,
  26. shortcut :bool = False,
  27. ) -> None:
  28. super(YoloBottleneck, self).__init__()
  29. inter_dim = int(out_dim * expansion)
  30. # ----------------- Network setting -----------------
  31. self.conv_layer1 = ConvModule(in_dim, inter_dim, kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1)
  32. self.conv_layer2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1)
  33. self.shortcut = shortcut and in_dim == out_dim
  34. def forward(self, x):
  35. h = self.conv_layer2(self.conv_layer1(x))
  36. return x + h if self.shortcut else h
  37. class CSPBlock(nn.Module):
  38. def __init__(self,
  39. in_dim,
  40. out_dim,
  41. num_blocks :int = 1,
  42. expansion :float = 0.5,
  43. shortcut :bool = False,
  44. ):
  45. super(CSPBlock, self).__init__()
  46. # ---------- Basic parameters ----------
  47. self.num_blocks = num_blocks
  48. self.expansion = expansion
  49. self.shortcut = shortcut
  50. inter_dim = round(out_dim * expansion)
  51. # ---------- Model parameters ----------
  52. self.conv_layer_1 = ConvModule(in_dim, inter_dim, kernel_size=1)
  53. self.conv_layer_2 = ConvModule(in_dim, inter_dim, kernel_size=1)
  54. self.conv_layer_3 = ConvModule(inter_dim * 2, out_dim, kernel_size=1)
  55. self.module = nn.Sequential(*[
  56. YoloBottleneck(inter_dim,
  57. inter_dim,
  58. kernel_size = [1, 3],
  59. expansion = 1.0,
  60. shortcut = shortcut,
  61. ) for _ in range(num_blocks)])
  62. def forward(self, x):
  63. x1 = self.conv_layer_1(x)
  64. x2 = self.module(self.conv_layer_2(x))
  65. out = self.conv_layer_3(torch.cat([x1, x2], dim=1))
  66. return out