# backbone.py
  1. import torch.nn as nn
  2. import torchvision
  3. from torchvision.models._utils import IntermediateLayerGetter
  4. from torchvision.models.resnet import (ResNet18_Weights,
  5. ResNet34_Weights,
  6. ResNet50_Weights,
  7. ResNet101_Weights)
  8. from .norm import FrozenBatchNorm2d
  9. # IN1K pretrained weights
  10. pretrained_urls = {
  11. # ResNet series
  12. 'resnet18': ResNet18_Weights,
  13. 'resnet34': ResNet34_Weights,
  14. 'resnet50': ResNet50_Weights,
  15. 'resnet101': ResNet101_Weights,
  16. }
  17. # ----------------- Model functions -----------------
  18. ## Build backbone network
  19. def build_backbone(cfg, pretrained):
  20. print('==============================')
  21. print('Backbone: {}'.format(cfg.backbone))
  22. # ResNet
  23. if 'resnet' in cfg.backbone:
  24. pretrained_weight = cfg.pretrained_weight if pretrained else None
  25. model = build_resnet(cfg, pretrained_weight)
  26. else:
  27. raise NotImplementedError("Unknown backbone: <>.".format(cfg.backbone))
  28. return model
  29. # ----------------- ResNet Backbone -----------------
  30. class ResNet(nn.Module):
  31. """ResNet backbone with frozen BatchNorm."""
  32. def __init__(self,
  33. name: str,
  34. norm_type: str,
  35. pretrained_weights: str = "imagenet1k_v1",
  36. freeze_at: int = -1,
  37. freeze_stem_only: bool = False):
  38. super().__init__()
  39. # Pretrained
  40. assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
  41. if pretrained_weights is not None:
  42. if name in ('resnet18', 'resnet34'):
  43. pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
  44. else:
  45. if pretrained_weights == "imagenet1k_v1":
  46. pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
  47. else:
  48. pretrained_weights = pretrained_urls[name].IMAGENET1K_V2
  49. else:
  50. pretrained_weights = None
  51. print('- Backbone pretrained weight: ', pretrained_weights)
  52. # Norm layer
  53. print("- Norm layer of backbone: {}".format(norm_type))
  54. if norm_type == 'BN':
  55. norm_layer = nn.BatchNorm2d
  56. elif norm_type == 'FrozeBN':
  57. norm_layer = FrozenBatchNorm2d
  58. # Backbone
  59. backbone = getattr(torchvision.models, name)(norm_layer=norm_layer, weights=pretrained_weights)
  60. return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
  61. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  62. self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
  63. # Freeze
  64. print("- Freeze at: {}".format(freeze_at))
  65. if freeze_at >= 0:
  66. for name, parameter in backbone.named_parameters():
  67. if freeze_stem_only:
  68. print("- Freeze stem layer only")
  69. if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  70. parameter.requires_grad_(False)
  71. else:
  72. print("- Freeze stem layer only + layer1")
  73. if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  74. parameter.requires_grad_(False)
  75. def forward(self, x):
  76. xs = self.body(x)
  77. fmp_list = []
  78. for name, fmp in xs.items():
  79. fmp_list.append(fmp)
  80. return fmp_list
  81. def build_resnet(cfg, pretrained_weight=None):
  82. # ResNet series
  83. backbone = ResNet(cfg.backbone,
  84. cfg.backbone_norm,
  85. pretrained_weight,
  86. cfg.freeze_at,
  87. cfg.freeze_stem_only)
  88. return backbone