backbone.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import torch
  2. import torchvision
  3. from torch import nn
  4. from torchvision.models._utils import IntermediateLayerGetter
  5. from torchvision.models.resnet import (ResNet18_Weights,
  6. ResNet34_Weights,
  7. ResNet50_Weights,
  8. ResNet101_Weights)
  9. try:
  10. from .basic import FrozenBatchNorm2d
  11. except:
  12. from basic import FrozenBatchNorm2d
  13. # IN1K pretrained weights
  14. pretrained_urls = {
  15. # ResNet series
  16. 'resnet18': ResNet18_Weights,
  17. 'resnet34': ResNet34_Weights,
  18. 'resnet50': ResNet50_Weights,
  19. 'resnet101': ResNet101_Weights,
  20. # ShuffleNet series
  21. }
  22. # ----------------- Model functions -----------------
  23. ## Build backbone network
  24. def build_backbone(cfg, pretrained):
  25. print('==============================')
  26. print('Backbone: {}'.format(cfg['backbone']))
  27. # ResNet
  28. if 'resnet' in cfg['backbone']:
  29. pretrained_weight = cfg['pretrained_weight'] if pretrained else None
  30. model, feats = build_resnet(cfg, pretrained_weight)
  31. elif 'svnetv2' in cfg['backbone']:
  32. pretrained_weight = cfg['pretrained_weight'] if pretrained else None
  33. model, feats = build_scnetv2(cfg, pretrained_weight)
  34. else:
  35. raise NotImplementedError("Unknown backbone: <>.".format(cfg['backbone']))
  36. return model, feats
  37. # ----------------- ResNet Backbone -----------------
  38. class ResNet(nn.Module):
  39. """ResNet backbone with frozen BatchNorm."""
  40. def __init__(self,
  41. name: str,
  42. res5_dilation: bool,
  43. norm_type: str,
  44. pretrained_weights: str = "imagenet1k_v1",
  45. freeze_stem_only: bool = False):
  46. super().__init__()
  47. # Pretrained
  48. assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
  49. if pretrained_weights is not None:
  50. if name in ('resnet18', 'resnet34'):
  51. pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
  52. else:
  53. if pretrained_weights == "imagenet1k_v1":
  54. pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
  55. else:
  56. pretrained_weights = pretrained_urls[name].IMAGENET1K_V2
  57. else:
  58. pretrained_weights = None
  59. print('ImageNet pretrained weight: ', pretrained_weights)
  60. # Norm layer
  61. if norm_type == 'BN':
  62. norm_layer = nn.BatchNorm2d
  63. elif norm_type == 'FrozeBN':
  64. norm_layer = FrozenBatchNorm2d
  65. # Backbone
  66. backbone = getattr(torchvision.models, name)(
  67. replace_stride_with_dilation=[False, False, res5_dilation],
  68. norm_layer=norm_layer, weights=pretrained_weights)
  69. return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
  70. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  71. self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
  72. # Freeze
  73. for name, parameter in backbone.named_parameters():
  74. if freeze_stem_only:
  75. if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  76. parameter.requires_grad_(False)
  77. else:
  78. if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  79. parameter.requires_grad_(False)
  80. def forward(self, x):
  81. xs = self.body(x)
  82. fmp_list = []
  83. for name, fmp in xs.items():
  84. fmp_list.append(fmp)
  85. return fmp_list
  86. def build_resnet(cfg, pretrained_weight=None):
  87. # ResNet series
  88. backbone = ResNet(cfg['backbone'], cfg['res5_dilation'], cfg['backbone_norm'], pretrained_weight, cfg['freeze_stem_only'])
  89. return backbone, backbone.feat_dims
  90. # ----------------- ShuffleNet Backbone -----------------
  91. ## TODO: Add shufflenet-v2
  92. class ShuffleNetv2:
  93. pass
  94. def build_scnetv2(cfg, pretrained_weight=None):
  95. return
  96. if __name__ == '__main__':
  97. cfg = {
  98. 'backbone': 'resnet18',
  99. 'backbone_norm': 'BN',
  100. 'res5_dilation': False,
  101. 'pretrained': True,
  102. 'pretrained_weight': 'imagenet1k_v1',
  103. }
  104. model, feat_dim = build_backbone(cfg, cfg['pretrained'])
  105. print(feat_dim)
  106. x = torch.randn(2, 3, 320, 320)
  107. output = model(x)
  108. for y in output:
  109. print(y.size())
  110. for n, p in model.named_parameters():
  111. print(n.split(".")[-1])