浏览代码

modify rtdetr_r18 config

yjh0410 1 年之前
父节点
当前提交
e36357bed7
共有 2 个文件被更改,包括 18 次插入8 次删除
  1. 1 0
      config/model_config/rtdetr_config.py
  2. 17 8
      models/detectors/rtdetr/basic_modules/backbone.py

+ 1 - 0
config/model_config/rtdetr_config.py

@@ -14,6 +14,7 @@ rtdetr_cfg = {
         'res5_dilation': False,
         'pretrained': True,
         'pretrained_weight': 'imagenet1k_v1',
+        'freeze_at': -1,
         'freeze_stem_only': True,
         'out_stride': [8, 16, 32],
         'max_stride': 32,

+ 17 - 8
models/detectors/rtdetr/basic_modules/backbone.py

@@ -49,6 +49,7 @@ class ResNet(nn.Module):
                  res5_dilation: bool,
                  norm_type: str,
                  pretrained_weights: str = "imagenet1k_v1",
+                 freeze_at: int = -1,
                  freeze_stem_only: bool = False):
         super().__init__()
         # Pretrained
@@ -77,13 +78,14 @@ class ResNet(nn.Module):
         self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
         self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
         # Freeze
-        for name, parameter in backbone.named_parameters():
-            if freeze_stem_only:
-                if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
-                    parameter.requires_grad_(False)
-            else:
-                if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
-                    parameter.requires_grad_(False)
+        if freeze_at >= 0:
+            for name, parameter in backbone.named_parameters():
+                if freeze_stem_only:
+                    if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+                        parameter.requires_grad_(False)
+                else:
+                    if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+                        parameter.requires_grad_(False)
 
     def forward(self, x):
         xs = self.body(x)
@@ -95,7 +97,12 @@ class ResNet(nn.Module):
 
 def build_resnet(cfg, pretrained_weight=None):
     # ResNet series
-    backbone = ResNet(cfg['backbone'], cfg['res5_dilation'], cfg['backbone_norm'], pretrained_weight, cfg['freeze_stem_only'])
+    backbone = ResNet(cfg['backbone'],
+                      cfg['res5_dilation'],
+                      cfg['backbone_norm'],
+                      pretrained_weight,
+                      cfg['freeze_at'],
+                      cfg['freeze_stem_only'])
 
     return backbone, backbone.feat_dims
 
@@ -115,6 +122,8 @@ if __name__ == '__main__':
         'backbone_norm': 'BN',
         'res5_dilation': False,
         'pretrained': True,
+        'freeze_at': -1,
+        'freeze_stem_only': True,
         'pretrained_weight': 'imagenet1k_v1',
     }
     model, feat_dim = build_backbone(cfg, cfg['pretrained'])