| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- """
- Backbone modules.
- """
- import torch
- import torchvision
- from torch import nn
- from torchvision.models._utils import IntermediateLayerGetter
- from torchvision.models.resnet import (ResNet18_Weights,
- ResNet34_Weights,
- ResNet50_Weights,
- ResNet101_Weights)
# torchvision ImageNet-1K classification weights, keyed by architecture name.
# NOTE: despite the name, the values are torchvision ``Weights`` enum classes
# (not URL strings); a concrete member such as ``IMAGENET1K_V1`` is picked
# from them when a backbone is built.
model_urls = dict(
    resnet18=ResNet18_Weights,
    resnet34=ResNet34_Weights,
    resnet50=ResNet50_Weights,
    resnet101=ResNet101_Weights,
)

# SparK ImageNet-1K (MAE-style) pretrained checkpoints, keyed by
# ``"spark_" + architecture name``. ``None`` means no checkpoint is available.
spark_model_urls = dict(
    spark_resnet18=None,
    spark_resnet34=None,
    spark_resnet50="https://github.com/yjh0410/RT-ODLab/releases/download/backbone_weight/resnet50_in1k_spark_pretrained_timm_style.pth",
    spark_resnet101=None,
)
# Frozen BatchNormalization
class FrozenBatchNorm2d(torch.nn.Module):
    """BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from ``torchvision.ops.misc`` with an added ``eps`` before
    ``rsqrt``, without which any model other than
    ``torchvision.models.resnet[18,34,50,101]`` produces NaNs.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered
    as buffers, not parameters, so the module has nothing for an optimizer
    to update.

    Args:
        n: number of channels (size of dim 1 of the input).
        eps: value added to ``running_var`` before the inverse square root.
            Defaults to ``1e-5``, the value that used to be hard-coded.
    """

    def __init__(self, n, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # This module has no ``num_batches_tracked`` buffer, so drop that
        # entry (present e.g. in checkpoints of a regular nn.BatchNorm2d)
        # to keep strict loading from reporting it as unexpected.
        state_dict.pop(prefix + 'num_batches_tracked', None)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        """Apply the frozen affine normalisation ``x * scale + bias``."""
        # Reshape everything to (1, C, 1, 1) up front so it broadcasts over
        # a 4-D input with channels in dim 1 (and stays fuser-friendly).
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
# -------------------- ResNet series --------------------
class ResNet(nn.Module):
    """Standard torchvision ResNet backbone returning layer2/3/4 feature maps.

    Args:
        name: torchvision ResNet constructor name, e.g. ``"resnet50"``.
        res5_dilation: if True, ``layer4`` uses dilation instead of stride.
        norm_type: ``"BN"`` for ``nn.BatchNorm2d`` or ``"FrozeBN"`` for
            ``FrozenBatchNorm2d``.
        freeze_at: ``0`` freezes only the stem (every parameter outside
            ``layer1``..``layer4``); ``k`` in 1..3 additionally freezes
            ``layer1``..``layer{k}``; any larger value freezes every
            parameter; a negative value freezes nothing.
        pretrained_weights: ``None``, ``"imagenet1k_v1"`` or
            ``"imagenet1k_v2"``. ``resnet18``/``resnet34`` always use the V1
            weights.

    Raises:
        ValueError: if ``pretrained_weights`` is not one of the values above.
        NotImplementedError: if ``norm_type`` is unknown.
    """

    def __init__(self,
                 name: str = "resnet50",
                 res5_dilation: bool = False,
                 norm_type: str = "BN",
                 freeze_at: int = 0,
                 pretrained_weights: str = "imagenet1k_v1"):
        super().__init__()
        # Pretrained
        # Explicit check instead of ``assert``: under ``python -O`` the assert
        # is stripped and an unknown value would silently select V2 weights.
        if pretrained_weights not in (None, "imagenet1k_v1", "imagenet1k_v2"):
            raise ValueError(
                "Unknown pretrained weights: {}".format(pretrained_weights))
        if pretrained_weights is not None:
            weights_enum = model_urls[name]
            if name in ('resnet18', 'resnet34') or pretrained_weights == "imagenet1k_v1":
                pretrained_weights = weights_enum.IMAGENET1K_V1
            else:
                pretrained_weights = weights_enum.IMAGENET1K_V2
        print('- Backbone pretrained weight: ', pretrained_weights)

        # Norm layer
        print("- Norm layer of backbone: {}".format(norm_type))
        if norm_type == 'BN':
            norm_layer = nn.BatchNorm2d
        elif norm_type == 'FrozeBN':
            norm_layer = FrozenBatchNorm2d
        else:
            raise NotImplementedError("Unknown norm type: {}".format(norm_type))

        # Backbone
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, res5_dilation],
            norm_layer=norm_layer, weights=pretrained_weights)
        return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        # Channel counts of the layer2 / layer3 / layer4 outputs.
        self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]

        # Freeze
        print("- Freeze at {}".format(freeze_at))
        if freeze_at >= 0:
            # Stages in this tuple stay trainable; every other parameter
            # (the stem plus the first ``freeze_at`` stages) is frozen.
            # Slicing past the end yields () so freeze_at >= 4 freezes all.
            trainable_stages = ('layer1', 'layer2', 'layer3', 'layer4')[freeze_at:]
            for param_name, parameter in backbone.named_parameters():
                if not any(stage in param_name for stage in trainable_stages):
                    parameter.requires_grad_(False)

    def forward(self, x):
        """Return the layer2, layer3 and layer4 feature maps as a list."""
        # The body's output is keyed "0", "1", "2" per ``return_layers``.
        return list(self.body(x).values())
class SparkResNet(nn.Module):
    """ResNet backbone that can be initialised from SparK pretrained weights.

    Args:
        name: torchvision ResNet constructor name, e.g. ``"resnet50"``.
        res5_dilation: if True, ``layer4`` uses dilation instead of stride.
        norm_type: ``"BN"`` for ``nn.BatchNorm2d`` or ``"FrozeBN"`` for
            ``FrozenBatchNorm2d``.
        freeze_at: ``0`` freezes only the stem (every parameter outside
            ``layer1``..``layer4``); ``k`` in 1..3 additionally freezes
            ``layer1``..``layer{k}``; any larger value freezes every
            parameter; a negative value freezes nothing.
        pretrained: if True, load the checkpoint registered in
            ``spark_model_urls`` under ``"spark_" + name`` when there is one.

    Raises:
        NotImplementedError: if ``norm_type`` is unknown.
    """

    def __init__(self,
                 name: str = "resnet50",
                 res5_dilation: bool = False,
                 norm_type: str = "BN",
                 freeze_at: int = 0,
                 pretrained: bool = False):
        super().__init__()
        # Norm layer
        print("- Norm layer of backbone: {}".format(norm_type))
        if norm_type == 'BN':
            norm_layer = nn.BatchNorm2d
        elif norm_type == 'FrozeBN':
            norm_layer = FrozenBatchNorm2d
        else:
            raise NotImplementedError("Unknown norm type: {}".format(norm_type))

        # Backbone (randomly initialised; SparK weights are loaded below)
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, res5_dilation], norm_layer=norm_layer)
        return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        # Channel counts of the layer2 / layer3 / layer4 outputs.
        self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]

        # Load pretrained
        if pretrained:
            self.load_pretrained(name)

        # Freeze
        print("- Freeze at {}".format(freeze_at))
        if freeze_at >= 0:
            # Stages in this tuple stay trainable; every other parameter
            # (the stem plus the first ``freeze_at`` stages) is frozen.
            # Slicing past the end yields () so freeze_at >= 4 freezes all.
            trainable_stages = ('layer1', 'layer2', 'layer3', 'layer4')[freeze_at:]
            for param_name, parameter in backbone.named_parameters():
                if not any(stage in param_name for stage in trainable_stages):
                    parameter.requires_grad_(False)

    def load_pretrained(self, name):
        """Load the SparK checkpoint for ``name`` into ``self.body``.

        Checkpoint entries whose key is absent from the model, or whose shape
        differs from the model's, are dropped before loading. The remaining
        entries are loaded non-strictly, and any model entries that received
        no weight are reported. Architectures without a registered checkpoint
        are skipped with a message.
        """
        url = spark_model_urls.get("spark_" + name)
        if url is None:
            print('No backbone pretrained for {}.'.format(name))
            return

        print('Loading backbone pretrained weight from : {}'.format(url))
        # checkpoint state dict
        checkpoint_state_dict = torch.hub.load_state_dict_from_url(
            url=url, map_location="cpu", check_hash=True)
        # model state dict
        model_state_dict = self.body.state_dict()

        # Drop entries the model cannot accept.
        for k in list(checkpoint_state_dict.keys()):
            if k not in model_state_dict:
                checkpoint_state_dict.pop(k)
                print('Unused key: ', k)
            elif tuple(model_state_dict[k].shape) != tuple(checkpoint_state_dict[k].shape):
                checkpoint_state_dict.pop(k)
                print('Shape mismatch, skipped key: ', k)

        # Non-strict: a strict load would fail with "Missing key(s)" for
        # exactly the entries that were filtered out above.
        missing_keys, _ = self.body.load_state_dict(checkpoint_state_dict, strict=False)
        if missing_keys:
            print('Keys not initialised from checkpoint: ', missing_keys)

    def forward(self, x):
        """Return the layer2, layer3 and layer4 feature maps as a list."""
        # The body's output is keyed "0", "1", "2" per ``return_layers``.
        return list(self.body(x).values())
# build backbone
def build_resnet(cfg):
    """Construct a ResNet backbone described by the config dict ``cfg``.

    If ``cfg['pretrained_weight']`` is a key of ``spark_model_urls``
    (e.g. ``"spark_resnet50"``), a ``SparkResNet`` is built and the boolean
    ``cfg['pretrained']`` is read. Otherwise a ``ResNet`` is built with
    ``cfg['pretrained_weight']`` passed as ``pretrained_weights``.

    Returns:
        Tuple ``(backbone, backbone.feat_dims)``.
    """
    use_spark = cfg['pretrained_weight'] in spark_model_urls

    # Options shared by both backbone classes (looked up in the same order
    # as the keyword arguments are passed).
    kwargs = {
        'name': cfg['backbone'],
        'res5_dilation': cfg['res5_dilation'],
        'norm_type': cfg['backbone_norm'],
    }
    if use_spark:
        backbone_cls = SparkResNet
        kwargs['pretrained'] = cfg['pretrained']
    else:
        backbone_cls = ResNet
        kwargs['pretrained_weights'] = cfg['pretrained_weight']
    kwargs['freeze_at'] = cfg['freeze_at']

    backbone = backbone_cls(**kwargs)
    return backbone, backbone.feat_dims
if __name__ == '__main__':
    # Smoke test: build a FrozenBN ResNet-50 and print its feature shapes.
    cfg = dict(
        backbone='resnet50',
        backbone_norm='FrozeBN',
        pretrained_weight='imagenet1k_v1',
        res5_dilation=False,
        freeze_at=0,
    )
    model, feat_dim = build_resnet(cfg)
    print(feat_dim)

    dummy_input = torch.randn(2, 3, 320, 320)
    for feature_map in model(dummy_input):
        print(feature_map.size())
|