|
|
@@ -49,6 +49,7 @@ class ResNet(nn.Module):
|
|
|
res5_dilation: bool,
|
|
|
norm_type: str,
|
|
|
pretrained_weights: str = "imagenet1k_v1",
|
|
|
+ freeze_at: int = -1,
|
|
|
freeze_stem_only: bool = False):
|
|
|
super().__init__()
|
|
|
# Pretrained
|
|
|
@@ -77,13 +78,14 @@ class ResNet(nn.Module):
|
|
|
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
|
|
self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
|
|
|
# Freeze
|
|
|
- for name, parameter in backbone.named_parameters():
|
|
|
- if freeze_stem_only:
|
|
|
- if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
|
|
- parameter.requires_grad_(False)
|
|
|
- else:
|
|
|
- if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
|
|
- parameter.requires_grad_(False)
|
|
|
+ if freeze_at >= 0:
|
|
|
+ for name, parameter in backbone.named_parameters():
|
|
|
+ if freeze_stem_only:
|
|
|
+ if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
|
|
+ parameter.requires_grad_(False)
|
|
|
+ else:
|
|
|
+ if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
|
|
+ parameter.requires_grad_(False)
|
|
|
|
|
|
def forward(self, x):
|
|
|
xs = self.body(x)
|
|
|
@@ -95,7 +97,12 @@ class ResNet(nn.Module):
|
|
|
|
|
|
def build_resnet(cfg, pretrained_weight=None):
|
|
|
# ResNet series
|
|
|
- backbone = ResNet(cfg['backbone'], cfg['res5_dilation'], cfg['backbone_norm'], pretrained_weight, cfg['freeze_stem_only'])
|
|
|
+ backbone = ResNet(cfg['backbone'],
|
|
|
+ cfg['res5_dilation'],
|
|
|
+ cfg['backbone_norm'],
|
|
|
+ pretrained_weight,
|
|
|
+ cfg['freeze_at'],
|
|
|
+ cfg['freeze_stem_only'])
|
|
|
|
|
|
return backbone, backbone.feat_dims
|
|
|
|
|
|
@@ -115,6 +122,8 @@ if __name__ == '__main__':
|
|
|
'backbone_norm': 'BN',
|
|
|
'res5_dilation': False,
|
|
|
'pretrained': True,
|
|
|
+ 'freeze_at': -1,
|
|
|
+ 'freeze_stem_only': True,
|
|
|
'pretrained_weight': 'imagenet1k_v1',
|
|
|
}
|
|
|
model, feat_dim = build_backbone(cfg, cfg['pretrained'])
|