resnet.py
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. """
  3. Backbone modules.
  4. """
  5. import torch
  6. import torchvision
  7. from torch import nn
  8. from torchvision.models._utils import IntermediateLayerGetter
  9. from torchvision.models.resnet import (ResNet18_Weights,
  10. ResNet34_Weights,
  11. ResNet50_Weights,
  12. ResNet101_Weights)
  13. model_urls = {
  14. # IN1K-Cls pretrained weights
  15. 'resnet18': ResNet18_Weights,
  16. 'resnet34': ResNet34_Weights,
  17. 'resnet50': ResNet50_Weights,
  18. 'resnet101': ResNet101_Weights,
  19. }
  20. spark_model_urls = {
  21. # SparK's IN1K-MAE pretrained weights
  22. 'spark_resnet18': None,
  23. 'spark_resnet34': None,
  24. 'spark_resnet50': "https://github.com/yjh0410/RT-ODLab/releases/download/backbone_weight/resnet50_in1k_spark_pretrained_timm_style.pth",
  25. 'spark_resnet101': None,
  26. }
  27. # Frozen BatchNormazlizarion
  28. class FrozenBatchNorm2d(torch.nn.Module):
  29. """
  30. BatchNorm2d where the batch statistics and the affine parameters are fixed.
  31. Copy-paste from torchvision.misc.ops with added eps before rqsrt,
  32. without which any other models than torchvision.models.resnet[18,34,50,101]
  33. produce nans.
  34. """
  35. def __init__(self, n):
  36. super(FrozenBatchNorm2d, self).__init__()
  37. self.register_buffer("weight", torch.ones(n))
  38. self.register_buffer("bias", torch.zeros(n))
  39. self.register_buffer("running_mean", torch.zeros(n))
  40. self.register_buffer("running_var", torch.ones(n))
  41. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
  42. missing_keys, unexpected_keys, error_msgs):
  43. num_batches_tracked_key = prefix + 'num_batches_tracked'
  44. if num_batches_tracked_key in state_dict:
  45. del state_dict[num_batches_tracked_key]
  46. super(FrozenBatchNorm2d, self)._load_from_state_dict(
  47. state_dict, prefix, local_metadata, strict,
  48. missing_keys, unexpected_keys, error_msgs)
  49. def forward(self, x):
  50. # move reshapes to the beginning
  51. # to make it fuser-friendly
  52. w = self.weight.reshape(1, -1, 1, 1)
  53. b = self.bias.reshape(1, -1, 1, 1)
  54. rv = self.running_var.reshape(1, -1, 1, 1)
  55. rm = self.running_mean.reshape(1, -1, 1, 1)
  56. eps = 1e-5
  57. scale = w * (rv + eps).rsqrt()
  58. bias = b - rm * scale
  59. return x * scale + bias
  60. # -------------------- ResNet series --------------------
  61. class ResNet(nn.Module):
  62. """Standard ResNet backbone."""
  63. def __init__(self,
  64. name :str = "resnet50",
  65. res5_dilation :bool = False,
  66. norm_type :str = "BN",
  67. freeze_at :int = 0,
  68. pretrained_weights :str = "imagenet1k_v1"):
  69. super().__init__()
  70. # Pretrained
  71. assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
  72. if pretrained_weights is not None:
  73. if name in ('resnet18', 'resnet34'):
  74. pretrained_weights = model_urls[name].IMAGENET1K_V1
  75. else:
  76. if pretrained_weights == "imagenet1k_v1":
  77. pretrained_weights = model_urls[name].IMAGENET1K_V1
  78. else:
  79. pretrained_weights = model_urls[name].IMAGENET1K_V2
  80. else:
  81. pretrained_weights = None
  82. print('- Backbone pretrained weight: ', pretrained_weights)
  83. # Norm layer
  84. print("- Norm layer of backbone: {}".format(norm_type))
  85. if norm_type == 'BN':
  86. norm_layer = nn.BatchNorm2d
  87. elif norm_type == 'FrozeBN':
  88. norm_layer = FrozenBatchNorm2d
  89. else:
  90. raise NotImplementedError("Unknown norm type: {}".format(norm_type))
  91. # Backbone
  92. backbone = getattr(torchvision.models, name)(
  93. replace_stride_with_dilation=[False, False, res5_dilation],
  94. norm_layer=norm_layer, weights=pretrained_weights)
  95. return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
  96. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  97. self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
  98. # Freeze
  99. print("- Freeze at {}".format(freeze_at))
  100. if freeze_at >= 0:
  101. for name, parameter in backbone.named_parameters():
  102. if freeze_at == 0: # Only freeze stem layer
  103. if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  104. parameter.requires_grad_(False)
  105. elif freeze_at == 1: # Freeze stem layer + layer1
  106. if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  107. parameter.requires_grad_(False)
  108. elif freeze_at == 2: # Freeze stem layer + layer1 + layer2
  109. if 'layer3' not in name and 'layer4' not in name:
  110. parameter.requires_grad_(False)
  111. elif freeze_at == 3: # Freeze stem layer + layer1 + layer2 + layer3
  112. if 'layer4' not in name:
  113. parameter.requires_grad_(False)
  114. else: # Freeze all resnet's layers
  115. parameter.requires_grad_(False)
  116. def forward(self, x):
  117. xs = self.body(x)
  118. fmp_list = []
  119. for name, fmp in xs.items():
  120. fmp_list.append(fmp)
  121. return fmp_list
  122. class SparkResNet(nn.Module):
  123. """ResNet backbone with SparK pretrained."""
  124. def __init__(self,
  125. name :str = "resnet50",
  126. res5_dilation :bool = False,
  127. norm_type :str = "BN",
  128. freeze_at :int = 0,
  129. pretrained :bool = False):
  130. super().__init__()
  131. # Norm layer
  132. print("- Norm layer of backbone: {}".format(norm_type))
  133. if norm_type == 'BN':
  134. norm_layer = nn.BatchNorm2d
  135. elif norm_type == 'FrozeBN':
  136. norm_layer = FrozenBatchNorm2d
  137. else:
  138. raise NotImplementedError("Unknown norm type: {}".format(norm_type))
  139. # Backbone
  140. backbone = getattr(torchvision.models, name)(
  141. replace_stride_with_dilation=[False, False, res5_dilation], norm_layer=norm_layer)
  142. return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
  143. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  144. self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
  145. # Load pretrained
  146. if pretrained:
  147. self.load_pretrained(name)
  148. # Freeze
  149. print("- Freeze at {}".format(freeze_at))
  150. if freeze_at >= 0:
  151. for name, parameter in backbone.named_parameters():
  152. if freeze_at == 0: # Only freeze stem layer
  153. if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  154. parameter.requires_grad_(False)
  155. elif freeze_at == 1: # Freeze stem layer + layer1
  156. if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
  157. parameter.requires_grad_(False)
  158. elif freeze_at == 2: # Freeze stem layer + layer1 + layer2
  159. if 'layer3' not in name and 'layer4' not in name:
  160. parameter.requires_grad_(False)
  161. elif freeze_at == 3: # Freeze stem layer + layer1 + layer2 + layer3
  162. if 'layer4' not in name:
  163. parameter.requires_grad_(False)
  164. else: # Freeze all resnet's layers
  165. parameter.requires_grad_(False)
  166. def load_pretrained(self, name):
  167. url = spark_model_urls["spark_" + name]
  168. if url is not None:
  169. print('Loading backbone pretrained weight from : {}'.format(url))
  170. # checkpoint state dict
  171. checkpoint_state_dict = torch.hub.load_state_dict_from_url(
  172. url=url, map_location="cpu", check_hash=True)
  173. # model state dict
  174. model_state_dict = self.body.state_dict()
  175. # check
  176. for k in list(checkpoint_state_dict.keys()):
  177. if k in model_state_dict:
  178. shape_model = tuple(model_state_dict[k].shape)
  179. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  180. if shape_model != shape_checkpoint:
  181. checkpoint_state_dict.pop(k)
  182. else:
  183. checkpoint_state_dict.pop(k)
  184. print('Unused key: ', k)
  185. # load the weight
  186. self.body.load_state_dict(checkpoint_state_dict)
  187. else:
  188. print('No backbone pretrained for {}.'.format(name))
  189. def forward(self, x):
  190. xs = self.body(x)
  191. fmp_list = []
  192. for name, fmp in xs.items():
  193. fmp_list.append(fmp)
  194. return fmp_list
  195. # build backbone
  196. def build_resnet(cfg):
  197. # ResNet series
  198. if cfg['pretrained_weight'] in spark_model_urls.keys():
  199. backbone = SparkResNet(
  200. name = cfg['backbone'],
  201. res5_dilation = cfg['res5_dilation'],
  202. norm_type = cfg['backbone_norm'],
  203. pretrained = cfg['pretrained'],
  204. freeze_at = cfg['freeze_at'])
  205. else:
  206. backbone = ResNet(
  207. name = cfg['backbone'],
  208. res5_dilation = cfg['res5_dilation'],
  209. norm_type = cfg['backbone_norm'],
  210. pretrained_weights = cfg['pretrained_weight'],
  211. freeze_at = cfg['freeze_at'])
  212. return backbone, backbone.feat_dims
  213. if __name__ == '__main__':
  214. cfg = {
  215. 'backbone': 'resnet50',
  216. 'backbone_norm': 'FrozeBN',
  217. 'pretrained_weight': 'imagenet1k_v1',
  218. 'res5_dilation': False,
  219. 'freeze_at': 0,
  220. }
  221. model, feat_dim = build_resnet(cfg)
  222. print(feat_dim)
  223. x = torch.randn(2, 3, 320, 320)
  224. output = model(x)
  225. for y in output:
  226. print(y.size())