# backbone.py (4.7 KB)
  1. import torch
  2. import torchvision
  3. from torch import nn
  4. from torchvision.models._utils import IntermediateLayerGetter
  5. try:
  6. from .basic import FrozenBatchNorm2d
  7. except:
  8. from basic import FrozenBatchNorm2d
  # IN1K MIM pretrained weights (from SparK: https://github.com/keyu-tian/SparK).
  # Maps a torchvision model name to its checkpoint URL; a value of None means
  # no pretrained checkpoint is available for that model.
  pretrained_urls = dict(
      # ResNet series
      resnet18=None,
      resnet34=None,
      resnet50="https://github.com/yjh0410/RT-ODLab/releases/download/backbone_weight/resnet50_in1k_spark_pretrained_timm_style.pth",
      resnet101=None,
      # ShuffleNet series: no entries yet
  )
  18. # ----------------- Model functions -----------------
  19. ## Build backbone network
  def build_backbone(cfg, pretrained=False):
      """Build the backbone network selected by ``cfg['backbone']``.

      Args:
          cfg: Config dict. ``cfg['backbone']`` picks the family: any name
              containing ``'resnet'`` goes to ``build_resnet``; any name
              containing ``'svnetv2'`` goes to ``build_scnetv2`` (which also
              reads ``cfg['pretrained_weight']`` when ``pretrained`` is true).
          pretrained: Whether pretrained weights should be loaded.

      Returns:
          The ``(model, feats)`` pair produced by the selected builder.

      Raises:
          NotImplementedError: If ``cfg['backbone']`` matches no known family.
      """
      print('==============================')
      backbone_name = cfg['backbone']
      print('Backbone: {}'.format(backbone_name))
      if 'resnet' in backbone_name:
          model, feats = build_resnet(cfg, pretrained)
      elif 'svnetv2' in backbone_name:
          # NOTE(review): 'svnetv2' matches neither the builder name
          # (build_scnetv2) nor the ShuffleNetv2 class; confirm the intended
          # config key before relying on this branch.
          pretrained_weight = cfg['pretrained_weight'] if pretrained else None
          model, feats = build_scnetv2(cfg, pretrained_weight)
      else:
          # '{}' (not '<>') so str.format actually inserts the offending name.
          raise NotImplementedError("Unknown backbone: {}.".format(backbone_name))
      return model, feats
  32. # ----------------- ResNet Backbone -----------------
  class ResNet(nn.Module):
      """ResNet backbone that returns the layer2/layer3/layer4 feature maps.

      A torchvision ResNet is wrapped in ``IntermediateLayerGetter`` so that
      only the outputs of ``layer2``, ``layer3`` and ``layer4`` are returned.

      Args:
          name: Name of a ``torchvision.models`` constructor, e.g. ``'resnet50'``.
          norm_type: ``'BN'`` for ``nn.BatchNorm2d`` or ``'FrozeBN'`` for
              ``FrozenBatchNorm2d``.
          pretrained: If true, load the checkpoint listed in ``pretrained_urls``.
          freeze_at: Any value ``>= 0`` enables parameter freezing; only the
              sign is inspected. Negative values leave everything trainable.
          freeze_stem_only: When freezing, freeze only parameters outside
              ``layer1``-``layer4`` (the stem); otherwise ``layer1`` is frozen too.

      Raises:
          NotImplementedError: If ``norm_type`` is not ``'BN'`` or ``'FrozeBN'``.
      """

      def __init__(self,
                   name: str,
                   norm_type: str,
                   pretrained: bool = False,
                   freeze_at: int = -1,
                   freeze_stem_only: bool = False):
          super().__init__()
          # Norm layer. Unknown values previously fell through and crashed
          # later with an UnboundLocalError on `norm_layer`.
          if norm_type == 'BN':
              norm_layer = nn.BatchNorm2d
          elif norm_type == 'FrozeBN':
              norm_layer = FrozenBatchNorm2d
          else:
              raise NotImplementedError("Unknown norm type: {}.".format(norm_type))
          # Backbone: keep only layer2..layer4 outputs, keyed "0".."2".
          backbone = getattr(torchvision.models, name)(norm_layer=norm_layer)
          return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
          self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
          # Channel counts of the three returned maps (basic-block vs bottleneck variants).
          self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
          if pretrained:
              self.load_pretrained(name)
          if freeze_at >= 0:
              # Parameters whose name contains none of these stay trainable-exempt
              # and are frozen; the body shares these modules with `backbone`.
              if freeze_stem_only:
                  trainable = ('layer1', 'layer2', 'layer3', 'layer4')
              else:
                  trainable = ('layer2', 'layer3', 'layer4')
              for param_name, parameter in backbone.named_parameters():
                  if not any(layer in param_name for layer in trainable):
                      parameter.requires_grad_(False)

      def load_pretrained(self, name):
          """Load the checkpoint registered for ``name`` into ``self.body``.

          Checkpoint entries whose key is absent from the body, or whose shape
          differs from the body's tensor, are discarded before loading. Names
          without a registered URL (or with a ``None`` URL) load nothing.
          """
          url = pretrained_urls.get(name)
          if url is None:
              print('No backbone pretrained for {}.'.format(name))
              return
          print('Loading pretrained weight from : {}'.format(url))
          checkpoint_state_dict = torch.hub.load_state_dict_from_url(
              url=url, map_location="cpu", check_hash=True)
          model_state_dict = self.body.state_dict()
          for k in list(checkpoint_state_dict.keys()):
              if k not in model_state_dict:
                  checkpoint_state_dict.pop(k)
                  print('Unused key: ', k)
              elif tuple(model_state_dict[k].shape) != tuple(checkpoint_state_dict[k].shape):
                  checkpoint_state_dict.pop(k)
                  print('Shape mismatch, skipped key: ', k)
          # strict=False: the keys dropped above would otherwise make a strict
          # load fail with "missing keys"; report them instead.
          incompatible = self.body.load_state_dict(checkpoint_state_dict, strict=False)
          if incompatible.missing_keys:
              print('Missing keys: ', incompatible.missing_keys)

      def forward(self, x):
          """Run the body and return its output feature maps as a list.

          The list follows the body's output order (keys "0", "1", "2", which
          ``return_layers`` maps from layer2, layer3 and layer4).
          """
          return list(self.body(x).values())
  def build_resnet(cfg, pretrained=False):
      """Construct a ``ResNet`` backbone from a config dict.

      Reads ``cfg['backbone']``, ``cfg['backbone_norm']``, ``cfg['freeze_at']``
      and ``cfg['freeze_stem_only']``.

      Returns:
          ``(model, model.feat_dims)``.
      """
      model = ResNet(
          name=cfg['backbone'],
          norm_type=cfg['backbone_norm'],
          pretrained=pretrained,
          freeze_at=cfg['freeze_at'],
          freeze_stem_only=cfg['freeze_stem_only'],
      )
      feat_dims = model.feat_dims
      return model, feat_dims
  102. # ----------------- ShuffleNet Backbone -----------------
  103. ## TODO: Add shufflenet-v2
  class ShuffleNetv2:
      """Placeholder for a ShuffleNet-v2 backbone; not implemented yet."""
  def build_scnetv2(cfg, pretrained_weight=None):
      """Build a ShuffleNet-v2 backbone (not implemented yet).

      Args:
          cfg: Backbone config dict (unused until implemented).
          pretrained_weight: Optional pretrained weight source (unused).

      Raises:
          NotImplementedError: Always. Raising here, instead of returning
              None, gives callers that unpack the result as ``(model, feats)``
              (e.g. ``build_backbone``) a clear error rather than a TypeError.
      """
      raise NotImplementedError("ShuffleNet-v2 backbone is not implemented yet.")
  if __name__ == '__main__':
      # Smoke test: build a FrozeBN ResNet-50 with pretrained weights, push a
      # dummy batch through it and print the feature dims and output shapes.
      demo_cfg = dict(
          backbone='resnet50',
          backbone_norm='FrozeBN',
          pretrained=True,
          freeze_at=0,
          freeze_stem_only=False,
      )
      net, dims = build_backbone(demo_cfg, demo_cfg['pretrained'])
      net.eval()
      print(dims)
      feature_maps = net(torch.ones(2, 3, 320, 320))
      for fmap in feature_maps:
          print(fmap.size())
      print(feature_maps[-1])