backbone.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import torch
  2. import torchvision
  3. from torch import nn
  4. from torchvision.models._utils import IntermediateLayerGetter
  5. from torchvision.models.resnet import (ResNet18_Weights,
  6. ResNet34_Weights,
  7. ResNet50_Weights,
  8. ResNet101_Weights)
  9. try:
  10. from .basic import FrozenBatchNorm2d
  11. except:
  12. from basic import FrozenBatchNorm2d
  13. # IN1K pretrained weights
  14. pretrained_urls = {
  15. # ResNet series
  16. 'resnet18': ResNet18_Weights,
  17. 'resnet34': ResNet34_Weights,
  18. 'resnet50': ResNet50_Weights,
  19. 'resnet101': ResNet101_Weights,
  20. # ShuffleNet series
  21. }
  22. # ----------------- Model functions -----------------
  23. ## Build backbone network
  24. def build_backbone(cfg, pretrained):
  25. print('==============================')
  26. print('Backbone: {}'.format(cfg['backbone']))
  27. # ResNet
  28. if 'resnet' in cfg['backbone']:
  29. pretrained_weight = cfg['pretrained_weight'] if pretrained else None
  30. model, feats = build_resnet(cfg, pretrained_weight)
  31. elif 'svnetv2' in cfg['backbone']:
  32. pretrained_weight = cfg['pretrained_weight'] if pretrained else None
  33. model, feats = build_scnetv2(cfg, pretrained_weight)
  34. else:
  35. raise NotImplementedError("Unknown backbone: <>.".format(cfg['backbone']))
  36. return model, feats
  37. # ----------------- ResNet Backbone -----------------
  38. class ResNet(nn.Module):
  39. """ResNet backbone with frozen BatchNorm."""
  40. def __init__(self,
  41. name: str,
  42. res5_dilation: bool,
  43. norm_type: str,
  44. pretrained_weights: str = "imagenet1k_v1",
  45. freeze_at: int = -1,
  46. freeze_stem_only: bool = False):
  47. super().__init__()
  48. # Pretrained
  49. assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
  50. if pretrained_weights is not None:
  51. if name in ('resnet18', 'resnet34'):
  52. pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
  53. else:
  54. if pretrained_weights == "imagenet1k_v1":
  55. pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
  56. else:
  57. pretrained_weights = pretrained_urls[name].IMAGENET1K_V2
  58. else:
  59. pretrained_weights = None
  60. print('ImageNet pretrained weight: ', pretrained_weights)
  61. # Norm layer
  62. if norm_type == 'BN':
  63. norm_layer = nn.BatchNorm2d
  64. elif norm_type == 'FrozeBN':
  65. norm_layer = FrozenBatchNorm2d
  66. # Backbone
  67. backbone = getattr(torchvision.models, name)(
  68. replace_stride_with_dilation=[False, False, res5_dilation],
  69. norm_layer=norm_layer, weights=pretrained_weights)
  70. return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
  71. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  72. self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
  73. # Freeze
  74. if freeze_at >= 0:
  75. for name, parameter in backbone.named_parameters():
  76. if freeze_stem_only:
  77. if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  78. parameter.requires_grad_(False)
  79. else:
  80. if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  81. parameter.requires_grad_(False)
  82. def forward(self, x):
  83. xs = self.body(x)
  84. fmp_list = []
  85. for name, fmp in xs.items():
  86. fmp_list.append(fmp)
  87. return fmp_list
  88. def build_resnet(cfg, pretrained_weight=None):
  89. # ResNet series
  90. backbone = ResNet(cfg['backbone'],
  91. cfg['res5_dilation'],
  92. cfg['backbone_norm'],
  93. pretrained_weight,
  94. cfg['freeze_at'],
  95. cfg['freeze_stem_only'])
  96. return backbone, backbone.feat_dims
  97. # ----------------- ShuffleNet Backbone -----------------
  98. ## TODO: Add shufflenet-v2
  99. class ShuffleNetv2:
  100. pass
  101. def build_scnetv2(cfg, pretrained_weight=None):
  102. return
  103. if __name__ == '__main__':
  104. cfg = {
  105. 'backbone': 'resnet18',
  106. 'backbone_norm': 'BN',
  107. 'res5_dilation': False,
  108. 'pretrained': True,
  109. 'freeze_at': -1,
  110. 'freeze_stem_only': True,
  111. 'pretrained_weight': 'imagenet1k_v1',
  112. }
  113. model, feat_dim = build_backbone(cfg, cfg['pretrained'])
  114. print(feat_dim)
  115. x = torch.randn(2, 3, 320, 320)
  116. output = model(x)
  117. for y in output:
  118. print(y.size())
  119. for n, p in model.named_parameters():
  120. print(n.split(".")[-1])