Bläddra i källkod

train RT-DETR-R18 on COCO

yjh0410 1 år sedan
förälder
incheckning
b6a1ff72a3

+ 1 - 0
config/model_config/rtdetr_config.py

@@ -14,6 +14,7 @@ rtdetr_cfg = {
         'res5_dilation': False,
         'pretrained': True,
         'pretrained_weight': 'imagenet1k_v1',
+        'freeze_stem_only': True,
         'out_stride': [8, 16, 32],
         ## Image Encoder - FPN
         'fpn': 'hybrid_encoder',

+ 13 - 4
models/detectors/rtdetr/basic_modules/backbone.py

@@ -44,7 +44,12 @@ def build_backbone(cfg, pretrained):
 # ----------------- ResNet Backbone -----------------
 class ResNet(nn.Module):
     """ResNet backbone with frozen BatchNorm."""
-    def __init__(self, name: str, res5_dilation: bool, norm_type: str, pretrained_weights: str = "imagenet1k_v1"):
+    def __init__(self,
+                 name: str,
+                 res5_dilation: bool,
+                 norm_type: str,
+                 pretrained_weights: str = "imagenet1k_v1",
+                 freeze_stem_only: bool = False):
         super().__init__()
         # Pretrained
         assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
@@ -73,8 +78,12 @@ class ResNet(nn.Module):
         self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
         # Freeze
         for name, parameter in backbone.named_parameters():
-            if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
-                parameter.requires_grad_(False)
+            if freeze_stem_only:
+                if 'layer1' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+                    parameter.requires_grad_(False)
+            else:
+                if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+                    parameter.requires_grad_(False)
 
     def forward(self, x):
         xs = self.body(x)
@@ -86,7 +95,7 @@ class ResNet(nn.Module):
 
 def build_resnet(cfg, pretrained_weight=None):
     # ResNet series
-    backbone = ResNet(cfg['backbone'], cfg['res5_dilation'], cfg['backbone_norm'], pretrained_weight)
+    backbone = ResNet(cfg['backbone'], cfg['res5_dilation'], cfg['backbone_norm'], pretrained_weight, cfg['freeze_stem_only'])
 
     return backbone, backbone.feat_dims
 

+ 2 - 1
models/detectors/rtdetr/matcher.py

@@ -38,7 +38,8 @@ class HungarianMatcher(nn.Module):
         ## L1 cost: [Nq, M]
         cost_bbox = torch.cdist(out_bbox, tgt_bbox.to(out_bbox.device), p=1)
         ## GIoU cost: Nq, M]
-        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox).to(out_bbox.device))
+        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
+                                         box_cxcywh_to_xyxy(tgt_bbox).to(out_bbox.device))
 
         # Final cost: [B, Nq, M]
         C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou

+ 3 - 2
train.py

@@ -119,7 +119,8 @@ def parse_args():
                         help='number of distributed processes')
     parser.add_argument('--sybn', action='store_true', default=False, 
                         help='use sybn.')
-    
+    parser.add_argument('--find_unused_parameters', default=False, type=bool,
+                        help='set find_unused_parameters as True.')
     # Debug mode
     parser.add_argument('--debug', action='store_true', default=False, 
                         help='debug mode.')
@@ -180,7 +181,7 @@ def train():
         print('use SyncBatchNorm ...')
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
     if args.distributed:
-        model = DDP(model, device_ids=[args.gpu])
+        model = DDP(model, device_ids=[args.gpu], find_unused_parameters=args.find_unused_parameters)
         model_without_ddp = model.module
     ## Calcute Params & GFLOPs
     if distributed_utils.is_main_process:

+ 4 - 0
train.sh

@@ -9,6 +9,7 @@ RESUME=$7
 
 # MODEL setting
 IMAGE_SIZE=640
+FIND_UNUSED_PARAMS=False
 if [[ $MODEL == *"rtcdet"* ]]; then
     # Epoch setting
     MAX_EPOCH=500
@@ -21,6 +22,7 @@ elif [[ $MODEL == *"rtdetr"* ]]; then
     WP_EPOCH=-1
     EVAL_EPOCH=4
     NO_AUG_EPOCH=-1
+    FIND_UNUSED_PARAMS=True
 elif [[ $MODEL == *"yolov8"* ]]; then
     # Epoch setting
     MAX_EPOCH=500
@@ -81,6 +83,7 @@ if [ $WORLD_SIZE == 1 ]; then
             --resume ${RESUME} \
             --ema \
             --fp16 \
+            --find_unused_parameters ${FIND_UNUSED_PARAMS} \
             --multi_scale
 elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
     python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
@@ -98,6 +101,7 @@ elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
             --resume ${RESUME} \
             --ema \
             --fp16 \
+            --find_unused_parameters ${FIND_UNUSED_PARAMS} \
             --multi_scale \
             --sybn
 else