# modules.py
from typing import List, Sequence

import torch
import torch.nn as nn
# --------------------- Basic modules ---------------------
  5. class ConvModule(nn.Module):
  6. def __init__(self,
  7. in_dim, # in channels
  8. out_dim, # out channels
  9. kernel_size=1, # kernel size
  10. padding=0, # padding
  11. stride=1, # padding
  12. dilation=1, # dilation
  13. ):
  14. super(ConvModule, self).__init__()
  15. self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
  16. self.norm = nn.BatchNorm2d(out_dim)
  17. self.act = nn.SiLU(inplace=True)
  18. def forward(self, x):
  19. return self.act(self.norm(self.conv(x)))
  20. class YoloBottleneck(nn.Module):
  21. def __init__(self,
  22. in_dim :int,
  23. out_dim :int,
  24. kernel_size :List = [1, 3],
  25. expansion :float = 0.5,
  26. shortcut :bool = False,
  27. ) -> None:
  28. super(YoloBottleneck, self).__init__()
  29. inter_dim = int(out_dim * expansion)
  30. # ----------------- Network setting -----------------
  31. self.conv_layer1 = ConvModule(in_dim, inter_dim, kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1)
  32. self.conv_layer2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1)
  33. self.shortcut = shortcut and in_dim == out_dim
  34. def forward(self, x):
  35. h = self.conv_layer2(self.conv_layer1(x))
  36. return x + h if self.shortcut else h
  37. class ResBlock(nn.Module):
  38. def __init__(self,
  39. in_dim,
  40. out_dim,
  41. num_blocks :int = 1,
  42. expansion :float = 0.5,
  43. shortcut :bool = False,
  44. ):
  45. super(ResBlock, self).__init__()
  46. # ---------- Basic parameters ----------
  47. self.num_blocks = num_blocks
  48. self.expansion = expansion
  49. self.shortcut = shortcut
  50. # ---------- Model parameters ----------
  51. module = []
  52. for i in range(num_blocks):
  53. if i == 0:
  54. module.append(YoloBottleneck(in_dim = in_dim,
  55. out_dim = out_dim,
  56. kernel_size = [1, 3],
  57. expansion = expansion,
  58. shortcut = shortcut,
  59. ))
  60. else:
  61. module.append(YoloBottleneck(in_dim = out_dim,
  62. out_dim = out_dim,
  63. kernel_size = [1, 3],
  64. expansion = expansion,
  65. shortcut = shortcut,
  66. ))
  67. self.module = nn.Sequential(*module)
  68. def forward(self, x):
  69. out = self.module(x)
  70. return out