resnet.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. """
  3. Backbone modules.
  4. """
  5. import torch
  6. import torchvision
  7. from torch import nn
  8. from torchvision.models._utils import IntermediateLayerGetter
  9. from torchvision.models.resnet import (ResNet18_Weights,
  10. ResNet34_Weights,
  11. ResNet50_Weights,
  12. ResNet101_Weights)
  13. model_urls = {
  14. # IN1K-Cls pretrained weights
  15. 'resnet18': ResNet18_Weights,
  16. 'resnet34': ResNet34_Weights,
  17. 'resnet50': ResNet50_Weights,
  18. 'resnet101': ResNet101_Weights,
  19. }
  20. # Frozen BatchNormalization
  21. class FrozenBatchNorm2d(torch.nn.Module):
  22. """
  23. BatchNorm2d where the batch statistics and the affine parameters are fixed.
  24. Copy-paste from torchvision.misc.ops with added eps before rqsrt,
  25. without which any other models than torchvision.models.resnet[18,34,50,101]
  26. produce nans.
  27. """
  28. def __init__(self, n):
  29. super(FrozenBatchNorm2d, self).__init__()
  30. self.register_buffer("weight", torch.ones(n))
  31. self.register_buffer("bias", torch.zeros(n))
  32. self.register_buffer("running_mean", torch.zeros(n))
  33. self.register_buffer("running_var", torch.ones(n))
  34. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
  35. missing_keys, unexpected_keys, error_msgs):
  36. num_batches_tracked_key = prefix + 'num_batches_tracked'
  37. if num_batches_tracked_key in state_dict:
  38. del state_dict[num_batches_tracked_key]
  39. super(FrozenBatchNorm2d, self)._load_from_state_dict(
  40. state_dict, prefix, local_metadata, strict,
  41. missing_keys, unexpected_keys, error_msgs)
  42. def forward(self, x):
  43. # move reshapes to the beginning
  44. # to make it fuser-friendly
  45. w = self.weight.reshape(1, -1, 1, 1)
  46. b = self.bias.reshape(1, -1, 1, 1)
  47. rv = self.running_var.reshape(1, -1, 1, 1)
  48. rm = self.running_mean.reshape(1, -1, 1, 1)
  49. eps = 1e-5
  50. scale = w * (rv + eps).rsqrt()
  51. bias = b - rm * scale
  52. return x * scale + bias
  53. # -------------------- ResNet series --------------------
  54. class ResNet(nn.Module):
  55. """Standard ResNet backbone."""
  56. def __init__(self,
  57. name :str = "resnet50",
  58. res5_dilation :bool = False,
  59. norm_type :str = "BN",
  60. freeze_at :int = 0,
  61. use_pretrained :bool = False):
  62. super().__init__()
  63. # Pretrained
  64. if use_pretrained:
  65. pretrained_weights = model_urls[name].IMAGENET1K_V1
  66. else:
  67. pretrained_weights = None
  68. # Norm layer
  69. print("- Norm layer of backbone: {}".format(norm_type))
  70. if norm_type == 'BN':
  71. norm_layer = nn.BatchNorm2d
  72. elif norm_type == 'FrozeBN':
  73. norm_layer = FrozenBatchNorm2d
  74. else:
  75. raise NotImplementedError("Unknown norm type: {}".format(norm_type))
  76. # Backbone
  77. backbone = getattr(torchvision.models, name)(
  78. replace_stride_with_dilation=[False, False, res5_dilation],
  79. norm_layer=norm_layer, weights=pretrained_weights)
  80. return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
  81. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  82. self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
  83. # Freeze
  84. print("- Freeze at {}".format(freeze_at))
  85. if freeze_at >= 0:
  86. for name, parameter in backbone.named_parameters():
  87. if freeze_at == 0: # Only freeze stem layer
  88. if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  89. parameter.requires_grad_(False)
  90. elif freeze_at == 1: # Freeze stem layer + layer1
  91. if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  92. parameter.requires_grad_(False)
  93. elif freeze_at == 2: # Freeze stem layer + layer1 + layer2
  94. if 'layer3' not in name and 'layer4' not in name:
  95. parameter.requires_grad_(False)
  96. elif freeze_at == 3: # Freeze stem layer + layer1 + layer2 + layer3
  97. if 'layer4' not in name:
  98. parameter.requires_grad_(False)
  99. else: # Freeze all resnet's layers
  100. parameter.requires_grad_(False)
  101. def forward(self, x):
  102. xs = self.body(x)
  103. fmp_list = []
  104. for name, fmp in xs.items():
  105. fmp_list.append(fmp)
  106. return fmp_list
  107. # build backbone
  108. def build_resnet(cfg):
  109. # ResNet series
  110. backbone = ResNet(
  111. name = cfg.backbone,
  112. res5_dilation = cfg.res5_dilation,
  113. norm_type = cfg.bk_norm,
  114. use_pretrained = cfg.use_pretrained,
  115. freeze_at = cfg.freeze_at)
  116. return backbone, backbone.feat_dims
  117. if __name__ == '__main__':
  118. class FcosBaseConfig(object):
  119. def __init__(self):
  120. self.backbone = "resnet18"
  121. self.bk_norm = "FrozeBN"
  122. self.res5_dilation = False
  123. self.use_pretrained = True
  124. self.freeze_at = 0
  125. cfg = FcosBaseConfig()
  126. model, feat_dim = build_resnet(cfg)
  127. print(feat_dim)
  128. x = torch.randn(2, 3, 320, 320)
  129. output = model(x)
  130. for y in output:
  131. print(y.size())