# modules.py
  1. import torch
  2. import torch.nn as nn
# --------------------- Basic modules ---------------------
  4. class ConvModule(nn.Module):
  5. def __init__(self,
  6. in_dim: int, # in channels
  7. out_dim: int, # out channels
  8. kernel_size: int = 1, # kernel size
  9. stride:int = 1, # padding
  10. ):
  11. super(ConvModule, self).__init__()
  12. convs = []
  13. convs.append(nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size//2, stride=stride, bias=False))
  14. convs.append(nn.BatchNorm2d(out_dim))
  15. convs.append(nn.SiLU(inplace=True))
  16. self.convs = nn.Sequential(*convs)
  17. def forward(self, x):
  18. return self.convs(x)
  19. class Bottleneck(nn.Module):
  20. def __init__(self,
  21. in_dim: int,
  22. out_dim: int,
  23. expand_ratio: float = 0.5,
  24. shortcut: bool = False,
  25. ):
  26. super(Bottleneck, self).__init__()
  27. inter_dim = int(out_dim * expand_ratio) # hidden channels
  28. self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
  29. self.cv2 = ConvModule(inter_dim, out_dim, kernel_size=3, stride=1)
  30. self.shortcut = shortcut and in_dim == out_dim
  31. def forward(self, x):
  32. h = self.cv2(self.cv1(x))
  33. return x + h if self.shortcut else h
  34. class CSPBlock(nn.Module):
  35. def __init__(self,
  36. in_dim: int,
  37. out_dim: int,
  38. expand_ratio: float = 0.5,
  39. num_blocks: int = 1,
  40. shortcut: bool = False,
  41. ):
  42. super(CSPBlock, self).__init__()
  43. inter_dim = int(out_dim * expand_ratio)
  44. self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
  45. self.cv2 = ConvModule(in_dim, inter_dim, kernel_size=1)
  46. self.cv3 = ConvModule(2 * inter_dim, out_dim, kernel_size=1)
  47. self.m = nn.Sequential(*[
  48. Bottleneck(inter_dim, inter_dim, expand_ratio=1.0, shortcut=shortcut)
  49. for _ in range(num_blocks)
  50. ])
  51. def forward(self, x):
  52. x1 = self.cv1(x)
  53. x2 = self.cv2(x)
  54. x3 = self.m(x1)
  55. out = self.cv3(torch.cat([x3, x2], dim=1))
  56. return out