convnet.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule
  5. except:
  6. from modules import ConvModule
  7. # Convolutional Network
  8. class ConvNet(nn.Module):
  9. def __init__(self,
  10. img_size :int = 224,
  11. in_dim :int = 3,
  12. hidden_dim :int = 16,
  13. num_classes :int = 10,
  14. act_type :str = "relu",
  15. norm_type :str = "bn",
  16. depthwise :bool = False,
  17. use_adavgpool :bool = True,
  18. ) -> None:
  19. super().__init__()
  20. # ---------- Basic parameters ----------
  21. self.img_size = img_size
  22. self.num_classes = num_classes
  23. self.act_type = act_type
  24. self.norm_type = norm_type
  25. self.use_adavgpool = use_adavgpool
  26. self.layer_dims = [hidden_dim, hidden_dim*2, hidden_dim*4, hidden_dim*4]
  27. # ---------- Model parameters ----------
  28. self.layer_1 = nn.Sequential(
  29. ConvModule(in_dim, hidden_dim,
  30. kernel_size=3, padding=1, stride=2,
  31. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  32. ConvModule(hidden_dim, hidden_dim,
  33. kernel_size=3, padding=1, stride=1,
  34. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  35. )
  36. self.layer_2 = nn.Sequential(
  37. nn.MaxPool2d(kernel_size=2, stride=2),
  38. ConvModule(hidden_dim, hidden_dim * 2,
  39. kernel_size=3, padding=1, stride=1,
  40. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  41. ConvModule(hidden_dim * 2, hidden_dim * 2,
  42. kernel_size=3, padding=1, stride=1,
  43. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  44. )
  45. self.layer_3 = nn.Sequential(
  46. nn.MaxPool2d(kernel_size=2, stride=2),
  47. ConvModule(hidden_dim * 2, hidden_dim * 4,
  48. kernel_size=3, padding=1, stride=1,
  49. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  50. ConvModule(hidden_dim * 4, hidden_dim * 4,
  51. kernel_size=3, padding=1, stride=1,
  52. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  53. )
  54. self.layer_4 = nn.Sequential(
  55. ConvModule(hidden_dim * 4, hidden_dim * 4,
  56. kernel_size=3, padding=1, stride=1,
  57. act_type=act_type, norm_type=norm_type, depthwise=depthwise),
  58. ConvModule(hidden_dim * 4, hidden_dim * 4,
  59. kernel_size=3, padding=1, stride=1,
  60. act_type=act_type, norm_type=norm_type, depthwise=depthwise)
  61. )
  62. if use_adavgpool:
  63. self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
  64. self.fc = nn.Linear(hidden_dim * 4, num_classes)
  65. else:
  66. self.avgpool = None
  67. fc_in_dim = (img_size // 8) ** 2 * (hidden_dim * 4) # N = Co x Ho x W
  68. self.fc = nn.Linear(fc_in_dim , num_classes)
  69. def forward(self, x):
  70. """
  71. Input:
  72. x : (torch.Tensor) -> [B, C, H, W]
  73. Output:
  74. x : (torch.Tensor) -> [B, Nc], Nc is the number of the object categories.
  75. """
  76. # [B, C_in, H, W] -> [B, C1, H/2, W/2]
  77. x = self.layer_1(x)
  78. # [B, C1, H/2, W/2] -> [B, C2, H/4, W/4]
  79. x = self.layer_2(x)
  80. # [B, C2, H/4, W/4] -> [B, C3, H/8, W/8]
  81. x = self.layer_3(x)
  82. # [B, C3, H/8, W/8] -> [B, C3, H/8, W/8]
  83. x = self.layer_4(x)
  84. if self.use_adavgpool:
  85. x = self.avgpool(x)
  86. # reshape [B, Co, Ho, Wo] to [B, N], N = Co x Ho x Wo
  87. x = x.flatten(1)
  88. x = self.fc(x)
  89. return x
  90. if __name__ == "__main__":
  91. bs, img_dim, img_size = 8, 3, 28
  92. hidden_dim = 16
  93. num_classes = 10
  94. # Make an input data randomly
  95. x = torch.randn(bs, img_dim, img_size, img_size)
  96. # Build a MLP model
  97. model = ConvNet(img_size = img_size,
  98. in_dim = img_dim,
  99. hidden_dim = hidden_dim,
  100. num_classes = num_classes,
  101. act_type = 'relu',
  102. norm_type = 'bn',
  103. depthwise = False,
  104. use_adavgpool = False)
  105. # Inference
  106. output = model(x)
  107. print(output.shape)