# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
import torch
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.resnet import (ResNet18_Weights,
                                       ResNet34_Weights,
                                       ResNet50_Weights,
                                       ResNet101_Weights)


model_urls = {
    # IN1K-Cls pretrained weights
    'resnet18': ResNet18_Weights,
    'resnet34': ResNet34_Weights,
    'resnet50': ResNet50_Weights,
    'resnet101': ResNet101_Weights,
}
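# Note: the values above are torchvision weight-enum classes (the multi-weight
# API introduced in torchvision 0.13), not URL strings; when use_pretrained is
# True, ResNet below selects the IMAGENET1K_V1 entry from the chosen enum.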


# Frozen BatchNormalization
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.ops.misc with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """
    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Drop the 'num_batches_tracked' buffer that regular BatchNorm2d
        # checkpoints carry, since this frozen variant does not register it.
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)
    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        # y = (x - rm) / sqrt(rv + eps) * w + b, folded into a single
        # per-channel affine transform: y = x * scale + bias
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
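

# Illustrative sketch (not part of the original module): one way to swap every
# nn.BatchNorm2d of an already-built model for FrozenBatchNorm2d while copying
# its statistics, e.g. when fine-tuning with very small batch sizes. The helper
# name `convert_frozen_bn` is an assumption, not an API used elsewhere here.
def convert_frozen_bn(module: nn.Module) -> nn.Module:
    if isinstance(module, nn.BatchNorm2d):
        frozen = FrozenBatchNorm2d(module.num_features)
        frozen.weight.data.copy_(module.weight.data)
        frozen.bias.data.copy_(module.bias.data)
        frozen.running_mean.data.copy_(module.running_mean)
        frozen.running_var.data.copy_(module.running_var)
        return frozen
    for child_name, child in module.named_children():
        setattr(module, child_name, convert_frozen_bn(child))
    return module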


# -------------------- ResNet series --------------------
class ResNet(nn.Module):
    """Standard ResNet backbone."""
    def __init__(self,
                 name: str = "resnet50",
                 res5_dilation: bool = False,
                 norm_type: str = "BN",
                 freeze_at: int = 0,
                 use_pretrained: bool = False):
        super().__init__()
        # Pretrained
        if use_pretrained:
            pretrained_weights = model_urls[name].IMAGENET1K_V1
        else:
            pretrained_weights = None

        # Norm layer
        print("- Norm layer of backbone: {}".format(norm_type))
        if norm_type == 'BN':
            norm_layer = nn.BatchNorm2d
        elif norm_type == 'FrozeBN':
            norm_layer = FrozenBatchNorm2d
        else:
            raise NotImplementedError("Unknown norm type: {}".format(norm_type))

        # Backbone: keep the C3 (layer2), C4 (layer3) and C5 (layer4) outputs
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, res5_dilation],
            norm_layer=norm_layer, weights=pretrained_weights)
        return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]

        # Freeze
        print("- Freeze at {}".format(freeze_at))
        if freeze_at >= 0:
            for param_name, parameter in backbone.named_parameters():
                if freeze_at == 0:   # Only freeze stem layer
                    if 'layer1' not in param_name and 'layer2' not in param_name and 'layer3' not in param_name and 'layer4' not in param_name:
                        parameter.requires_grad_(False)
                elif freeze_at == 1: # Freeze stem layer + layer1
                    if 'layer2' not in param_name and 'layer3' not in param_name and 'layer4' not in param_name:
                        parameter.requires_grad_(False)
                elif freeze_at == 2: # Freeze stem layer + layer1 + layer2
                    if 'layer3' not in param_name and 'layer4' not in param_name:
                        parameter.requires_grad_(False)
                elif freeze_at == 3: # Freeze stem layer + layer1 + layer2 + layer3
                    if 'layer4' not in param_name:
                        parameter.requires_grad_(False)
                else:                # Freeze all of the ResNet's layers
                    parameter.requires_grad_(False)

    def forward(self, x):
        xs = self.body(x)
        # Collect the multi-level feature maps in [C3, C4, C5] order
        fmp_list = []
        for name, fmp in xs.items():
            fmp_list.append(fmp)
        return fmp_list
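

# Note on output resolution: with res5_dilation=False the three returned maps
# sit at strides 8, 16 and 32 of the input (the usual C3/C4/C5 pyramid); with
# res5_dilation=True, layer4 trades its stride-2 downsampling for dilation, so
# the last map stays at stride 16 while ResNet.feat_dims is unchanged.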


# build backbone
def build_resnet(cfg):
    # ResNet series
    backbone = ResNet(
        name=cfg.backbone,
        res5_dilation=cfg.res5_dilation,
        norm_type=cfg.bk_norm,
        use_pretrained=cfg.use_pretrained,
        freeze_at=cfg.freeze_at)

    return backbone, backbone.feat_dims


if __name__ == '__main__':
    class FcosBaseConfig(object):
        def __init__(self):
            self.backbone = "resnet18"
            self.bk_norm = "FrozeBN"
            self.res5_dilation = False
            self.use_pretrained = True
            self.freeze_at = 0

    cfg = FcosBaseConfig()
    model, feat_dim = build_resnet(cfg)
    print(feat_dim)

    x = torch.randn(2, 3, 320, 320)
    output = model(x)
    for k in model.state_dict():
        print(k)
    for y in output:
        print(y.size())
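
    # With the demo config above (resnet18, 320x320 input, no res5 dilation),
    # the three feature maps are expected at strides 8/16/32, i.e. shapes
    # [2, 128, 40, 40], [2, 256, 20, 20] and [2, 512, 10, 10].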