| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
import torch
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter

# Prefer the package-relative import; fall back to the flat import when this
# module is executed as a script (no package context).  Bug fix: the original
# bare `except:` also hid unrelated errors (e.g. a SyntaxError inside basic.py);
# catch only the import failure we intend to recover from.
try:
    from .basic import FrozenBatchNorm2d
except ImportError:
    from basic import FrozenBatchNorm2d

# IN1K MIM pretrained weights (from SparK: https://github.com/keyu-tian/SparK)
# Maps a torchvision model name -> checkpoint URL.  A value of None means no
# pretrained weight is available for that variant (load_pretrained just logs).
pretrained_urls = {
    # ResNet series
    'resnet18': None,
    'resnet34': None,
    'resnet50': "https://github.com/yjh0410/RT-ODLab/releases/download/backbone_weight/resnet50_in1k_spark_pretrained_timm_style.pth",
    'resnet101': None,
    # ShuffleNet series
}
# ----------------- Model functions -----------------
## Build backbone network
def build_backbone(cfg, pretrained=False):
    """Build the backbone network named by ``cfg['backbone']``.

    Args:
        cfg: config dict; must contain 'backbone' plus any backbone-specific
            keys (e.g. 'backbone_norm', 'freeze_at' for ResNet, or
            'pretrained_weight' for the shufflenet branch).
        pretrained: load pretrained weights when True.

    Returns:
        (model, feats): the backbone module and its output feature dims.

    Raises:
        NotImplementedError: for an unrecognized backbone name.
    """
    print('==============================')
    print('Backbone: {}'.format(cfg['backbone']))
    # ResNet
    if 'resnet' in cfg['backbone']:
        model, feats = build_resnet(cfg, pretrained)
    elif 'svnetv2' in cfg['backbone']:
        pretrained_weight = cfg['pretrained_weight'] if pretrained else None
        model, feats = build_scnetv2(cfg, pretrained_weight)
    else:
        # Bug fix: the original message used '<>' instead of a '{}' format
        # placeholder, so the offending backbone name was never interpolated.
        raise NotImplementedError("Unknown backbone: {}.".format(cfg['backbone']))

    return model, feats
- # ----------------- ResNet Backbone -----------------
class ResNet(nn.Module):
    """ResNet backbone returning the layer2/layer3/layer4 feature maps.

    Wraps a torchvision ResNet with an ``IntermediateLayerGetter`` so the
    forward pass yields three feature maps (keys "0", "1", "2").

    Args:
        name: torchvision model name, e.g. 'resnet50'.
        norm_type: 'BN' for regular ``nn.BatchNorm2d`` or 'FrozeBN' for
            ``FrozenBatchNorm2d``.
        pretrained: if True, try to load MIM-pretrained weights listed in
            ``pretrained_urls``.
        freeze_at: when >= 0, freeze early parameters (see freeze_stem_only).
        freeze_stem_only: if True, freeze only parameters outside layer1-4
            (the stem); otherwise also freeze layer1.

    Raises:
        ValueError: for an unrecognized ``norm_type``.
    """
    def __init__(self,
                 name: str,
                 norm_type: str,
                 pretrained: bool = False,
                 freeze_at: int = -1,
                 freeze_stem_only: bool = False):
        super().__init__()
        # Norm layer
        if norm_type == 'BN':
            norm_layer = nn.BatchNorm2d
        elif norm_type == 'FrozeBN':
            norm_layer = FrozenBatchNorm2d
        else:
            # Bug fix: the original left `norm_layer` unbound for any other
            # value and crashed later with UnboundLocalError; fail fast.
            raise ValueError('Unknown norm_type: {}'.format(norm_type))
        # Backbone
        backbone = getattr(torchvision.models, name)(norm_layer=norm_layer,)
        return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        # Basic-block variants (18/34) have 4x smaller channel counts than
        # bottleneck variants (50/101/...).
        self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]

        # Load pretrained
        if pretrained:
            self.load_pretrained(name)
        # Freeze early layers.  Bug fix: the loop variable used to shadow the
        # `name` argument; renamed to param_name.
        if freeze_at >= 0:
            for param_name, parameter in backbone.named_parameters():
                if freeze_stem_only:
                    # Freeze only the stem: everything outside layer1-layer4.
                    if 'layer1' not in param_name and 'layer2' not in param_name and 'layer3' not in param_name and 'layer4' not in param_name:
                        parameter.requires_grad_(False)
                else:
                    # Freeze the stem and layer1.
                    if 'layer2' not in param_name and 'layer3' not in param_name and 'layer4' not in param_name:
                        parameter.requires_grad_(False)

    def load_pretrained(self, name):
        """Download and load the IN1K MIM-pretrained weights for `name`, if any."""
        url = pretrained_urls[name]
        if url is None:
            print('No backbone pretrained for {}.'.format(name))
            return
        print('Loading pretrained weight from : {}'.format(url))
        # checkpoint state dict
        checkpoint_state_dict = torch.hub.load_state_dict_from_url(
            url=url, map_location="cpu", check_hash=True)
        # model state dict
        model_state_dict = self.body.state_dict()
        # Drop checkpoint entries that are absent from the model, or whose
        # shapes do not match, so the load below cannot be poisoned.
        for k in list(checkpoint_state_dict.keys()):
            if k in model_state_dict:
                shape_model = tuple(model_state_dict[k].shape)
                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                if shape_model != shape_checkpoint:
                    checkpoint_state_dict.pop(k)
            else:
                checkpoint_state_dict.pop(k)
                print('Unused key: ', k)
        # NOTE(review): this is a strict load — if any key was filtered out
        # above, load_state_dict will raise on the missing key; consider
        # strict=False if partial checkpoints are expected.  TODO confirm.
        self.body.load_state_dict(checkpoint_state_dict)

    def forward(self, x):
        """Run the backbone and return the feature maps as a list [C3, C4, C5]."""
        xs = self.body(x)
        return [fmp for fmp in xs.values()]
def build_resnet(cfg, pretrained=False):
    """Construct a ResNet backbone from cfg and return it with its feature dims."""
    net = ResNet(
        name=cfg['backbone'],
        norm_type=cfg['backbone_norm'],
        pretrained=pretrained,
        freeze_at=cfg['freeze_at'],
        freeze_stem_only=cfg['freeze_stem_only'],
    )
    return net, net.feat_dims
# ----------------- ShuffleNet Backbone -----------------
## TODO: Add shufflenet-v2
class ShuffleNetv2:
    """Placeholder for a ShuffleNet-v2 backbone; not implemented yet."""
    pass
def build_scnetv2(cfg, pretrained_weight=None):
    """Build a ShuffleNet-v2 backbone (not yet implemented).

    Raises:
        NotImplementedError: always.  The original stub returned None, which
        made build_backbone's tuple unpacking fail later with an opaque
        TypeError; raising here surfaces the real problem at the call site.
    """
    raise NotImplementedError('ShuffleNet-v2 backbone is not implemented yet.')
if __name__ == '__main__':
    # Smoke test: build a frozen-BN ResNet-50 and push a dummy batch through it.
    cfg = {
        'backbone': 'resnet50',
        'backbone_norm': 'FrozeBN',
        'pretrained': True,
        'freeze_at': 0,
        'freeze_stem_only': False,
    }
    model, feat_dim = build_backbone(cfg, cfg['pretrained'])
    model.eval()
    print(feat_dim)
    dummy = torch.ones(2, 3, 320, 320)
    feats = model(dummy)
    for feat in feats:
        print(feat.size())
    print(feats[-1])
|