| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- import torch.nn as nn
- import torchvision
- from torchvision.models._utils import IntermediateLayerGetter
- from torchvision.models import resnet
- try:
- from .norm import FrozenBatchNorm2d
- except:
- from norm import FrozenBatchNorm2d
- # ----------------- Model functions -----------------
- ## Build backbone network
- def build_backbone(cfg, pretrained):
- print('==============================')
- print('Backbone: {}'.format(cfg.backbone))
- # ResNet
- if 'resnet' in cfg.backbone:
- pretrained_weight = cfg.pretrained_weight if pretrained else None
- model = build_resnet(cfg, pretrained_weight)
- else:
- raise NotImplementedError("Unknown backbone: <>.".format(cfg.backbone))
-
- return model
- # ----------------- ResNet Backbone -----------------
- class ResNet(nn.Module):
- """ResNet backbone with frozen BatchNorm."""
- def __init__(self,
- name: str,
- norm_type: str,
- pretrained_weights: str = "imagenet1k_v1",
- freeze_at: int = -1,
- freeze_stem_only: bool = False):
- super().__init__()
- # Pretrained
- assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
- if pretrained_weights is not None:
- if name == "resnet18":
- pretrained_weights = resnet.ResNet18_Weights.IMAGENET1K_V1
- elif name == "resnet34":
- pretrained_weights = resnet.ResNet34_Weights.IMAGENET1K_V1
- elif name == "resnet50" and pretrained_weights == "imagenet1k_v1":
- pretrained_weights = resnet.ResNet50_Weights.IMAGENET1K_V1
- elif name == "resnet50" and pretrained_weights == "imagenet1k_v2":
- pretrained_weights = resnet.ResNet50_Weights.IMAGENET1K_V2
- elif name == "resnet101" and pretrained_weights == "imagenet1k_v1":
- pretrained_weights = resnet.ResNet101_Weights.IMAGENET1K_V1
- elif name == "resnet101" and pretrained_weights == "imagenet1k_v2":
- pretrained_weights = resnet.ResNet101_Weights.IMAGENET1K_V2
- else:
- pretrained_weights = None
- print('- Backbone pretrained weight: ', pretrained_weights)
- # Norm layer
- print("- Norm layer of backbone: {}".format(norm_type))
- if norm_type == 'BN':
- norm_layer = nn.BatchNorm2d
- elif norm_type == 'FrozeBN':
- norm_layer = FrozenBatchNorm2d
- # Backbone
- backbone = getattr(torchvision.models, name)(norm_layer=norm_layer, weights=pretrained_weights)
- return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
- self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
- self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
- # Freeze
- print("- Freeze at: {}".format(freeze_at))
- if freeze_at >= 0:
- for name, parameter in backbone.named_parameters():
- if freeze_stem_only:
- print("- Freeze stem layer only")
- if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
- parameter.requires_grad_(False)
- else:
- print("- Freeze stem layer only + layer1")
- if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
- parameter.requires_grad_(False)
- def forward(self, x):
- xs = self.body(x)
- fmp_list = []
- for name, fmp in xs.items():
- fmp_list.append(fmp)
- return fmp_list
- def build_resnet(cfg, pretrained_weight=None):
- # ResNet series
- backbone = ResNet(cfg.backbone,
- cfg.backbone_norm,
- pretrained_weight,
- cfg.freeze_at,
- cfg.freeze_stem_only)
- return backbone
|