backbone.py 3.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import torch.nn as nn
  2. import torchvision
  3. from torchvision.models._utils import IntermediateLayerGetter
  4. from torchvision.models import resnet
  5. try:
  6. from .norm import FrozenBatchNorm2d
  7. except:
  8. from norm import FrozenBatchNorm2d
  9. # ----------------- Model functions -----------------
  10. ## Build backbone network
  11. def build_backbone(cfg, pretrained):
  12. print('==============================')
  13. print('Backbone: {}'.format(cfg.backbone))
  14. # ResNet
  15. if 'resnet' in cfg.backbone:
  16. pretrained_weight = cfg.pretrained_weight if pretrained else None
  17. model = build_resnet(cfg, pretrained_weight)
  18. else:
  19. raise NotImplementedError("Unknown backbone: <>.".format(cfg.backbone))
  20. return model
  21. # ----------------- ResNet Backbone -----------------
  22. class ResNet(nn.Module):
  23. """ResNet backbone with frozen BatchNorm."""
  24. def __init__(self,
  25. name: str,
  26. norm_type: str,
  27. pretrained_weights: str = "imagenet1k_v1",
  28. freeze_at: int = -1,
  29. freeze_stem_only: bool = False):
  30. super().__init__()
  31. # Pretrained
  32. assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
  33. if pretrained_weights is not None:
  34. if name == "resnet18":
  35. pretrained_weights = resnet.ResNet18_Weights.IMAGENET1K_V1
  36. elif name == "resnet34":
  37. pretrained_weights = resnet.ResNet34_Weights.IMAGENET1K_V1
  38. elif name == "resnet50" and pretrained_weights == "imagenet1k_v1":
  39. pretrained_weights = resnet.ResNet50_Weights.IMAGENET1K_V1
  40. elif name == "resnet50" and pretrained_weights == "imagenet1k_v2":
  41. pretrained_weights = resnet.ResNet50_Weights.IMAGENET1K_V2
  42. elif name == "resnet101" and pretrained_weights == "imagenet1k_v1":
  43. pretrained_weights = resnet.ResNet101_Weights.IMAGENET1K_V1
  44. elif name == "resnet101" and pretrained_weights == "imagenet1k_v2":
  45. pretrained_weights = resnet.ResNet101_Weights.IMAGENET1K_V2
  46. else:
  47. pretrained_weights = None
  48. print('- Backbone pretrained weight: ', pretrained_weights)
  49. # Norm layer
  50. print("- Norm layer of backbone: {}".format(norm_type))
  51. if norm_type == 'BN':
  52. norm_layer = nn.BatchNorm2d
  53. elif norm_type == 'FrozeBN':
  54. norm_layer = FrozenBatchNorm2d
  55. # Backbone
  56. backbone = getattr(torchvision.models, name)(norm_layer=norm_layer, weights=pretrained_weights)
  57. return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
  58. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  59. self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
  60. # Freeze
  61. print("- Freeze at: {}".format(freeze_at))
  62. if freeze_at >= 0:
  63. for name, parameter in backbone.named_parameters():
  64. if freeze_stem_only:
  65. print("- Freeze stem layer only")
  66. if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  67. parameter.requires_grad_(False)
  68. else:
  69. print("- Freeze stem layer only + layer1")
  70. if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  71. parameter.requires_grad_(False)
  72. def forward(self, x):
  73. xs = self.body(x)
  74. fmp_list = []
  75. for name, fmp in xs.items():
  76. fmp_list.append(fmp)
  77. return fmp_list
  78. def build_resnet(cfg, pretrained_weight=None):
  79. # ResNet series
  80. backbone = ResNet(cfg.backbone,
  81. cfg.backbone_norm,
  82. pretrained_weight,
  83. cfg.freeze_at,
  84. cfg.freeze_stem_only)
  85. return backbone