backbone.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import torch
  2. import torchvision
  3. from torch import nn
  4. from torchvision.models._utils import IntermediateLayerGetter
  5. from torchvision.models.resnet import (ResNet18_Weights,
  6. ResNet34_Weights,
  7. ResNet50_Weights,
  8. ResNet101_Weights)
  9. try:
  10. from .basic import FrozenBatchNorm2d
  11. except:
  12. from basic import FrozenBatchNorm2d
  13. # IN1K pretrained weights
  14. pretrained_urls = {
  15. # ResNet series
  16. 'resnet18': ResNet18_Weights,
  17. 'resnet34': ResNet34_Weights,
  18. 'resnet50': ResNet50_Weights,
  19. 'resnet101': ResNet101_Weights,
  20. # ShuffleNet series
  21. }
  22. # ----------------- Model functions -----------------
  23. ## Build backbone network
  24. def build_backbone(cfg, pretrained):
  25. print('==============================')
  26. print('Backbone: {}'.format(cfg['backbone']))
  27. # ResNet
  28. if 'resnet' in cfg['backbone']:
  29. pretrained_weight = cfg['pretrained_weight'] if pretrained else None
  30. model, feats = build_resnet(cfg, pretrained_weight)
  31. elif 'svnetv2' in cfg['backbone']:
  32. pretrained_weight = cfg['pretrained_weight'] if pretrained else None
  33. model, feats = build_scnetv2(cfg, pretrained_weight)
  34. else:
  35. raise NotImplementedError("Unknown backbone: <>.".format(cfg['backbone']))
  36. return model, feats
  37. # ----------------- ResNet Backbone -----------------
  38. class ResNet(nn.Module):
  39. """ResNet backbone with frozen BatchNorm."""
  40. def __init__(self,
  41. name: str,
  42. norm_type: str,
  43. pretrained_weights: str = "imagenet1k_v1",
  44. freeze_at: int = -1,
  45. freeze_stem_only: bool = False):
  46. super().__init__()
  47. # Pretrained
  48. assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
  49. if pretrained_weights is not None:
  50. if name in ('resnet18', 'resnet34'):
  51. pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
  52. else:
  53. if pretrained_weights == "imagenet1k_v1":
  54. pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
  55. else:
  56. pretrained_weights = pretrained_urls[name].IMAGENET1K_V2
  57. else:
  58. pretrained_weights = None
  59. print('ImageNet pretrained weight: ', pretrained_weights)
  60. # Norm layer
  61. if norm_type == 'BN':
  62. norm_layer = nn.BatchNorm2d
  63. elif norm_type == 'FrozeBN':
  64. norm_layer = FrozenBatchNorm2d
  65. # Backbone
  66. backbone = getattr(torchvision.models, name)(norm_layer=norm_layer, weights=pretrained_weights)
  67. return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
  68. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  69. self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
  70. # Freeze
  71. if freeze_at >= 0:
  72. for name, parameter in backbone.named_parameters():
  73. if freeze_stem_only:
  74. if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  75. parameter.requires_grad_(False)
  76. else:
  77. if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  78. parameter.requires_grad_(False)
  79. def forward(self, x):
  80. xs = self.body(x)
  81. fmp_list = []
  82. for name, fmp in xs.items():
  83. fmp_list.append(fmp)
  84. return fmp_list
  85. def build_resnet(cfg, pretrained_weight=None):
  86. # ResNet series
  87. backbone = ResNet(cfg['backbone'],
  88. cfg['backbone_norm'],
  89. pretrained_weight,
  90. cfg['freeze_at'],
  91. cfg['freeze_stem_only'])
  92. return backbone, backbone.feat_dims
  93. # ----------------- ShuffleNet Backbone -----------------
  94. ## TODO: Add shufflenet-v2
  95. class ShuffleNetv2:
  96. pass
  97. def build_scnetv2(cfg, pretrained_weight=None):
  98. return
  99. if __name__ == '__main__':
  100. cfg = {
  101. 'backbone': 'resnet18',
  102. 'backbone_norm': 'BN',
  103. 'pretrained': True,
  104. 'freeze_at': -1,
  105. 'freeze_stem_only': True,
  106. 'pretrained_weight': 'imagenet1k_v1',
  107. }
  108. model, feat_dim = build_backbone(cfg, cfg['pretrained'])
  109. print(feat_dim)
  110. x = torch.randn(2, 3, 320, 320)
  111. output = model(x)
  112. for y in output:
  113. print(y.size())
  114. # for n, p in model.named_parameters():
  115. # print(n.split(".")[-1])