resnet.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule, PlainResBlock, BottleneckResBlock
  5. except:
  6. from modules import ConvModule, PlainResBlock, BottleneckResBlock
  7. class ResNet(nn.Module):
  8. def __init__(self,
  9. in_dim,
  10. block,
  11. expansion = 1.0,
  12. num_blocks = [2, 2, 2, 2],
  13. num_classes = 1000,
  14. ) -> None:
  15. super().__init__()
  16. # ----------- Basic parameters -----------
  17. self.expansion = expansion
  18. self.num_blocks = num_blocks
  19. self.feat_dims = [64, # C2 level
  20. round(64 * expansion), # C2 level
  21. round(128 * expansion), # C3 level
  22. round(256 * expansion), # C4 level
  23. round(512 * expansion), # C5 level
  24. ]
  25. # ----------- Model parameters -----------
  26. ## Backbone
  27. self.layer_1 = nn.Sequential(
  28. ConvModule(in_dim, self.feat_dims[0],
  29. kernel_size=7, padding=3, stride=2,
  30. act_type='relu', norm_type='bn', depthwise=False),
  31. nn.MaxPool2d(kernel_size=(3, 3), padding=(1, 1), stride=(2, 2))
  32. )
  33. self.layer_2 = self.make_layer(block, self.feat_dims[0], self.feat_dims[1], depth=num_blocks[0], downsample=False)
  34. self.layer_3 = self.make_layer(block, self.feat_dims[1], self.feat_dims[2], depth=num_blocks[1], downsample=True)
  35. self.layer_4 = self.make_layer(block, self.feat_dims[2], self.feat_dims[3], depth=num_blocks[2], downsample=True)
  36. self.layer_5 = self.make_layer(block, self.feat_dims[3], self.feat_dims[4], depth=num_blocks[3], downsample=True)
  37. ## Classifier
  38. self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  39. self.fc = nn.Linear(self.feat_dims[4] , num_classes)
  40. def make_layer(self, block, in_dim, out_dim, depth=1, downsample=False):
  41. stage_blocks = []
  42. for i in range(depth):
  43. if i == 0:
  44. stride = 2 if downsample else 1
  45. inter_dim = round(out_dim / self.expansion)
  46. stage_blocks.append(block(in_dim, inter_dim, out_dim, stride))
  47. else:
  48. stride = 1
  49. inter_dim = round(out_dim / self.expansion)
  50. stage_blocks.append(block(out_dim, inter_dim, out_dim, stride))
  51. layers = nn.Sequential(*stage_blocks)
  52. return layers
  53. def forward(self, x):
  54. x = self.layer_1(x)
  55. x = self.layer_2(x)
  56. x = self.layer_3(x)
  57. x = self.layer_4(x)
  58. x = self.layer_5(x)
  59. x = self.avgpool(x)
  60. x = x.flatten(1)
  61. x = self.fc(x)
  62. return x
  63. def build_resnet(model_name='resnet18', img_dim=3):
  64. if model_name == 'resnet18':
  65. model = ResNet(in_dim=img_dim,
  66. block=PlainResBlock,
  67. expansion=1.0,
  68. num_blocks=[2, 2, 2, 2],
  69. )
  70. elif model_name == 'resnet50':
  71. model = ResNet(in_dim=img_dim,
  72. block=BottleneckResBlock,
  73. expansion=4.0,
  74. num_blocks=[3, 4, 6, 3],
  75. )
  76. else:
  77. raise NotImplementedError("Unknown resnet: {}".format(model_name))
  78. return model
  79. if __name__=='__main__':
  80. import time
  81. # 构建ResNet模型
  82. model = build_resnet(model_name='resnet18')
  83. # 打印模型结构
  84. print(model)
  85. # 随即成生数据
  86. x = torch.randn(1, 3, 224, 224)
  87. # 模型前向推理
  88. t0 = time.time()
  89. output = model(x)
  90. t1 = time.time()
  91. print('Time: ', t1 - t0)