backbone.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import torch
  2. import torchvision
  3. from torch import nn
  4. from torchvision.models._utils import IntermediateLayerGetter
  5. from torchvision.models.resnet import (ResNet18_Weights,
  6. ResNet34_Weights,
  7. ResNet50_Weights,
  8. ResNet101_Weights)
  9. try:
  10. from .basic import FrozenBatchNorm2d
  11. except:
  12. from basic import FrozenBatchNorm2d
  13. # IN1K pretrained weights
  14. pretrained_urls = {
  15. # ResNet series
  16. 'resnet18': ResNet18_Weights,
  17. 'resnet34': ResNet34_Weights,
  18. 'resnet50': ResNet50_Weights,
  19. 'resnet101': ResNet101_Weights,
  20. # ShuffleNet series
  21. }
  22. # ----------------- Model functions -----------------
  23. ## Build backbone network
  24. def build_backbone(cfg, pretrained):
  25. print('==============================')
  26. print('Backbone: {}'.format(cfg['backbone']))
  27. # ResNet
  28. if 'resnet' in cfg['backbone']:
  29. pretrained_weight = cfg['pretrained_weight'] if pretrained else None
  30. model, feats = build_resnet(cfg, pretrained_weight)
  31. elif 'svnetv2' in cfg['backbone']:
  32. pretrained_weight = cfg['pretrained_weight'] if pretrained else None
  33. model, feats = build_scnetv2(cfg, pretrained_weight)
  34. else:
  35. raise NotImplementedError("Unknown backbone: <>.".format(cfg['backbone']))
  36. return model, feats
  37. # ----------------- ResNet Backbone -----------------
  38. class ResNet(nn.Module):
  39. """ResNet backbone with frozen BatchNorm."""
  40. def __init__(self, name: str, res5_dilation: bool, norm_type: str, pretrained_weights: str = "imagenet1k_v1"):
  41. super().__init__()
  42. # Pretrained
  43. assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
  44. if pretrained_weights is not None:
  45. if name in ('resnet18', 'resnet34'):
  46. pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
  47. else:
  48. if pretrained_weights == "imagenet1k_v1":
  49. pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
  50. else:
  51. pretrained_weights = pretrained_urls[name].IMAGENET1K_V2
  52. else:
  53. pretrained_weights = None
  54. print('ImageNet pretrained weight: ', pretrained_weights)
  55. # Norm layer
  56. if norm_type == 'BN':
  57. norm_layer = nn.BatchNorm2d
  58. elif norm_type == 'FrozeBN':
  59. norm_layer = FrozenBatchNorm2d
  60. # Backbone
  61. backbone = getattr(torchvision.models, name)(
  62. replace_stride_with_dilation=[False, False, res5_dilation],
  63. norm_layer=norm_layer, weights=pretrained_weights)
  64. return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
  65. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  66. self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
  67. # Freeze
  68. for name, parameter in backbone.named_parameters():
  69. if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  70. parameter.requires_grad_(False)
  71. def forward(self, x):
  72. xs = self.body(x)
  73. fmp_list = []
  74. for name, fmp in xs.items():
  75. fmp_list.append(fmp)
  76. return fmp_list
  77. def build_resnet(cfg, pretrained_weight=None):
  78. # ResNet series
  79. backbone = ResNet(cfg['backbone'], cfg['res5_dilation'], cfg['backbone_norm'], pretrained_weight)
  80. return backbone, backbone.feat_dims
  81. # ----------------- ShuffleNet Backbone -----------------
  82. ## TODO: Add shufflenet-v2
  83. class ShuffleNetv2:
  84. pass
  85. def build_scnetv2(cfg, pretrained_weight=None):
  86. return
  87. if __name__ == '__main__':
  88. cfg = {
  89. 'backbone': 'resnet18',
  90. 'backbone_norm': 'BN',
  91. 'res5_dilation': False,
  92. 'pretrained': True,
  93. 'pretrained_weight': 'imagenet1k_v1',
  94. }
  95. model, feat_dim = build_backbone(cfg, cfg['pretrained'])
  96. print(feat_dim)
  97. x = torch.randn(2, 3, 320, 320)
  98. output = model(x)
  99. for y in output:
  100. print(y.size())
  101. for n, p in model.named_parameters():
  102. print(n.split(".")[-1])