|
|
@@ -44,7 +44,12 @@ def build_backbone(cfg, pretrained):
|
|
|
# ----------------- ResNet Backbone -----------------
|
|
|
class ResNet(nn.Module):
|
|
|
"""ResNet backbone with frozen BatchNorm."""
|
|
|
- def __init__(self, name: str, res5_dilation: bool, norm_type: str, pretrained_weights: str = "imagenet1k_v1"):
|
|
|
+ def __init__(self,
|
|
|
+ name: str,
|
|
|
+ res5_dilation: bool,
|
|
|
+ norm_type: str,
|
|
|
+ pretrained_weights: str = "imagenet1k_v1",
|
|
|
+ freeze_stem_only: bool = False):
|
|
|
super().__init__()
|
|
|
# Pretrained
|
|
|
assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
|
|
|
@@ -73,8 +78,12 @@ class ResNet(nn.Module):
|
|
|
self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
|
|
|
# Freeze
|
|
|
for name, parameter in backbone.named_parameters():
|
|
|
- if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
|
|
- parameter.requires_grad_(False)
|
|
|
+ if freeze_stem_only:
|
|
|
+ if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
|
|
+ parameter.requires_grad_(False)
|
|
|
+ else:
|
|
|
+ if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
|
|
+ parameter.requires_grad_(False)
|
|
|
|
|
|
def forward(self, x):
|
|
|
xs = self.body(x)
|
|
|
@@ -86,7 +95,7 @@ class ResNet(nn.Module):
|
|
|
|
|
|
def build_resnet(cfg, pretrained_weight=None):
|
|
|
# ResNet series
|
|
|
- backbone = ResNet(cfg['backbone'], cfg['res5_dilation'], cfg['backbone_norm'], pretrained_weight)
|
|
|
+ backbone = ResNet(cfg['backbone'], cfg['res5_dilation'], cfg['backbone_norm'], pretrained_weight, cfg['freeze_stem_only'])
|
|
|
|
|
|
return backbone, backbone.feat_dims
|
|
|
|