# backbone.py
  1. import torch.nn as nn
  2. import torchvision
  3. from torchvision.models._utils import IntermediateLayerGetter
  4. from torchvision.models import resnet
  5. try:
  6. from .norm import FrozenBatchNorm2d
  7. except:
  8. from norm import FrozenBatchNorm2d
  9. # IN1K pretrained weights
  10. pretrained_urls = {
  11. # ResNet series
  12. 'resnet18': resnet.ResNet18_Weights,
  13. 'resnet34': resnet.ResNet34_Weights,
  14. 'resnet50': resnet.ResNet50_Weights,
  15. 'resnet101': resnet.ResNet101_Weights,
  16. }
  17. # ----------------- Model functions -----------------
  18. ## Build backbone network
  19. def build_backbone(cfg, pretrained):
  20. print('==============================')
  21. print('Backbone: {}'.format(cfg.backbone))
  22. # ResNet
  23. if 'resnet' in cfg.backbone:
  24. pretrained_weight = cfg.pretrained_weight if pretrained else None
  25. model = build_resnet(cfg, pretrained_weight)
  26. else:
  27. raise NotImplementedError("Unknown backbone: <>.".format(cfg.backbone))
  28. return model
  29. # ----------------- ResNet Backbone -----------------
  30. class ResNet(nn.Module):
  31. """ResNet backbone with frozen BatchNorm."""
  32. def __init__(self,
  33. name: str,
  34. norm_type: str,
  35. pretrained_weights: str = "imagenet1k_v1",
  36. freeze_at: int = -1,
  37. freeze_stem_only: bool = False):
  38. super().__init__()
  39. # Pretrained
  40. assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
  41. if pretrained_weights is not None:
  42. if name == "resnet18":
  43. pretrained_weights = resnet.ResNet18_Weights.IMAGENET1K_V1
  44. elif name == "resnet34":
  45. pretrained_weights = resnet.ResNet34_Weights.IMAGENET1K_V1
  46. elif name == "resnet50" and pretrained_weights == "imagenet1k_v1":
  47. pretrained_weights = resnet.ResNet50_Weights.IMAGENET1K_V1
  48. elif name == "resnet50" and pretrained_weights == "imagenet1k_v2":
  49. pretrained_weights = resnet.ResNet50_Weights.IMAGENET1K_V2
  50. elif name == "resnet101" and pretrained_weights == "imagenet1k_v1":
  51. pretrained_weights = resnet.ResNet101_Weights.IMAGENET1K_V1
  52. elif name == "resnet101" and pretrained_weights == "imagenet1k_v2":
  53. pretrained_weights = resnet.ResNet101_Weights.IMAGENET1K_V2
  54. else:
  55. pretrained_weights = None
  56. print('- Backbone pretrained weight: ', pretrained_weights)
  57. # Norm layer
  58. print("- Norm layer of backbone: {}".format(norm_type))
  59. if norm_type == 'BN':
  60. norm_layer = nn.BatchNorm2d
  61. elif norm_type == 'FrozeBN':
  62. norm_layer = FrozenBatchNorm2d
  63. # Backbone
  64. backbone = getattr(torchvision.models, name)(norm_layer=norm_layer, weights=pretrained_weights)
  65. return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
  66. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  67. self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
  68. # Freeze
  69. print("- Freeze at: {}".format(freeze_at))
  70. if freeze_at >= 0:
  71. for name, parameter in backbone.named_parameters():
  72. if freeze_stem_only:
  73. print("- Freeze stem layer only")
  74. if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  75. parameter.requires_grad_(False)
  76. else:
  77. print("- Freeze stem layer only + layer1")
  78. if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  79. parameter.requires_grad_(False)
  80. def forward(self, x):
  81. xs = self.body(x)
  82. fmp_list = []
  83. for name, fmp in xs.items():
  84. fmp_list.append(fmp)
  85. return fmp_list
  86. def build_resnet(cfg, pretrained_weight=None):
  87. # ResNet series
  88. backbone = ResNet(cfg.backbone,
  89. cfg.backbone_norm,
  90. pretrained_weight,
  91. cfg.freeze_at,
  92. cfg.freeze_stem_only)
  93. return backbone