backbone.py

import torch
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter

try:
    from .basic import FrozenBatchNorm2d
except ImportError:
    from basic import FrozenBatchNorm2d

# IN1K MIM pretrained weights (from SparK: https://github.com/keyu-tian/SparK)
pretrained_urls = {
    # ResNet series
    'resnet18':  None,
    'resnet34':  None,
    'resnet50':  "https://github.com/yjh0410/RT-ODLab/releases/download/backbone_weight/resnet50_in1k_spark_pretrained_timm_style.pth",
    'resnet101': None,
    # ShuffleNet series
}

# ----------------- Model functions -----------------
## Build backbone network
def build_backbone(cfg, pretrained=False):
    print('==============================')
    print('Backbone: {}'.format(cfg['backbone']))

    # ResNet
    if 'resnet' in cfg['backbone']:
        model, feats = build_resnet(cfg, pretrained)
    else:
        raise NotImplementedError("Unknown backbone: {}.".format(cfg['backbone']))

    return model, feats

# ----------------- ResNet Backbone -----------------
class ResNet(nn.Module):
    """Multi-scale ResNet backbone built on torchvision."""
    def __init__(self,
                 name: str,
                 norm_type: str,
                 pretrained: bool = False,
                 freeze_at: int = -1,
                 freeze_stem_only: bool = False):
        super().__init__()
        # Norm layer
        if norm_type == 'BN':
            norm_layer = nn.BatchNorm2d
        elif norm_type == 'FrozeBN':
            norm_layer = FrozenBatchNorm2d
        else:
            raise NotImplementedError("Unknown norm type: {}.".format(norm_type))

        # Backbone: keep the outputs of layer2 / layer3 / layer4 (C3 / C4 / C5)
        backbone = getattr(torchvision.models, name)(norm_layer=norm_layer)
        return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]

        # Load pretrained weights
        if pretrained:
            self.load_pretrained(name)

        # Freeze: only the stem, or the stem plus layer1
        if freeze_at >= 0:
            for param_name, parameter in backbone.named_parameters():
                if freeze_stem_only:
                    if 'layer1' not in param_name and 'layer2' not in param_name \
                            and 'layer3' not in param_name and 'layer4' not in param_name:
                        parameter.requires_grad_(False)
                else:
                    if 'layer2' not in param_name and 'layer3' not in param_name \
                            and 'layer4' not in param_name:
                        parameter.requires_grad_(False)

    def load_pretrained(self, name):
        url = pretrained_urls[name]
        if url is not None:
            print('Loading pretrained weight from : {}'.format(url))
            # checkpoint state dict
            checkpoint_state_dict = torch.hub.load_state_dict_from_url(
                url=url, map_location="cpu", check_hash=True)
            # model state dict
            model_state_dict = self.body.state_dict()
            # check
            for k in list(checkpoint_state_dict.keys()):
                if k in model_state_dict:
                    shape_model = tuple(model_state_dict[k].shape)
                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                    if shape_model != shape_checkpoint:
                        checkpoint_state_dict.pop(k)
                else:
                    checkpoint_state_dict.pop(k)
                    print('Unused key: ', k)
            # load the weight
            self.body.load_state_dict(checkpoint_state_dict)
        else:
            print('No backbone pretrained for {}.'.format(name))
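
    # Note: torch.hub.load_state_dict_from_url caches the downloaded checkpoint
    # locally (by default under ~/.cache/torch/hub/checkpoints), so repeated runs
    # reuse the cached file instead of re-downloading it.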
    def forward(self, x):
        xs = self.body(x)
        fmp_list = []
        for _, fmp in xs.items():
            fmp_list.append(fmp)

        return fmp_list

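
# Note: ResNet.forward returns the multi-scale features in order [C3, C4, C5]
# (strides 8 / 16 / 32). For the resnet50 config used in the __main__ block below
# with a 320x320 input, the output shapes are:
#   C3: [B,  512, 40, 40]
#   C4: [B, 1024, 20, 20]
#   C5: [B, 2048, 10, 10]
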
def build_resnet(cfg, pretrained=False):
    # ResNet series
    backbone = ResNet(name=cfg['backbone'], norm_type=cfg['backbone_norm'],
                      pretrained=pretrained, freeze_at=cfg['freeze_at'],
                      freeze_stem_only=cfg['freeze_stem_only'])
    return backbone, backbone.feat_dims

if __name__ == '__main__':
    cfg = {
        'backbone': 'resnet50',
        'backbone_norm': 'FrozeBN',
        'pretrained': True,
        'freeze_at': 0,
        'freeze_stem_only': False,
    }
    model, feat_dim = build_backbone(cfg, cfg['pretrained'])
    model.eval()
    print(feat_dim)

    x = torch.ones(2, 3, 320, 320)
    output = model(x)
    for y in output:
        print(y.size())
    print(output[-1])
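
    # Illustrative extra usage (a minimal sketch, not part of the original test):
    # the ResNet class can also be built directly, without going through build_backbone.
    tiny_backbone = ResNet(name='resnet18', norm_type='BN', pretrained=False,
                           freeze_at=-1, freeze_stem_only=False)
    tiny_backbone.eval()
    feats = tiny_backbone(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])  # channels [128, 256, 512] at strides 8 / 16 / 32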