Sfoglia il codice sorgente

debug YOLOX-style Transform with Rotation

yjh0410 2 anni fa
parent
commit
73011ee555

+ 3 - 3
config/model_config/yolox_config.py

@@ -94,7 +94,7 @@ yolox_cfg = {
         # ---------------- Model config ----------------
         ## Backbone
         'backbone': 'cspdarknet',
-        'pretrained': True,
+        'pretrained': False,
         'bk_act': 'silu',
         'bk_norm': 'BN',
         'bk_dpw': False,
@@ -138,7 +138,7 @@ yolox_cfg = {
         # ---------------- Model config ----------------
         ## Backbone
         'backbone': 'cspdarknet',
-        'pretrained': True,
+        'pretrained': False,
         'bk_act': 'silu',
         'bk_norm': 'BN',
         'bk_dpw': False,
@@ -182,7 +182,7 @@ yolox_cfg = {
         # ---------------- Model config ----------------
         ## Backbone
         'backbone': 'cspdarknet',
-        'pretrained': True,
+        'pretrained': False,
         'bk_act': 'silu',
         'bk_norm': 'BN',
         'bk_dpw': False,

+ 3 - 3
engine.py

@@ -30,7 +30,7 @@ class YoloTrainer(object):
         self.epoch = 0
         self.best_map = -1.
         self.last_opt_step = 0
-        self.no_aug_epoch = 20
+        self.no_aug_epoch = args.no_aug_epoch
         self.clip_grad = 10
         self.device = device
         self.criterion = criterion
@@ -327,7 +327,7 @@ class RTMTrainer(object):
         self.device = device
         self.criterion = criterion
         self.world_size = world_size
-        self.no_aug_epoch = 20
+        self.no_aug_epoch = args.no_aug_epoch
         self.clip_grad = 35
         self.heavy_eval = False
         self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 5e-2, 'lr0': 0.001}
@@ -610,7 +610,7 @@ class DetrTrainer(object):
         self.epoch = 0
         self.best_map = -1.
         self.last_opt_step = 0
-        self.no_aug_epoch = 20
+        self.no_aug_epoch = args.no_aug_epoch
         self.clip_grad = -1
         self.device = device
         self.criterion = criterion

+ 1 - 1
models/detectors/yolox/build.py

@@ -60,5 +60,5 @@ def build_yolox(args, cfg, device, num_classes=80, trainable=False, deploy=False
     criterion = None
     if trainable:
         # build criterion for training
-        criterion = build_criterion(cfg, device, num_classes)
+        criterion = build_criterion(args, cfg, device, num_classes)
     return model, criterion

+ 62 - 5
models/detectors/yolox/loss.py

@@ -7,13 +7,18 @@ from utils.distributed_utils import get_world_size, is_dist_avail_and_initialize
 
 
 class Criterion(object):
-    def __init__(self, 
+    def __init__(self,
+                 args,
                  cfg, 
                  device, 
                  num_classes=80):
+        self.args = args
         self.cfg = cfg
         self.device = device
         self.num_classes = num_classes
+        self.max_epoch = args.max_epoch
+        self.no_aug_epoch = args.no_aug_epoch
+        self.aux_bbox_loss = False
         # loss weight
         self.loss_obj_weight = cfg['loss_obj_weight']
         self.loss_cls_weight = cfg['loss_cls_weight']
@@ -47,11 +52,26 @@ class Criterion(object):
         return loss_box
 
 
+    def loss_bboxes_aux(self, pred_reg, gt_box, anchors, stride_tensors):
+        # xyxy -> cxcy&bwbh
+        gt_cxcy = (gt_box[..., :2] + gt_box[..., 2:]) * 0.5
+        gt_bwbh = gt_box[..., 2:] - gt_box[..., :2]
+        # encode gt box
+        gt_cxcy_encode = (gt_cxcy - anchors) / stride_tensors
+        gt_bwbh_encode = torch.log(gt_bwbh / stride_tensors)
+        gt_box_encode = torch.cat([gt_cxcy_encode, gt_bwbh_encode], dim=-1)
+        # l1 loss
+        loss_box_aux = F.l1_loss(pred_reg, gt_box_encode, reduction='none')
+
+        return loss_box_aux
+
+
     def __call__(self, outputs, targets, epoch=0):        
         """
             outputs['pred_obj']: List(Tensor) [B, M, 1]
             outputs['pred_cls']: List(Tensor) [B, M, C]
             outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['pred_reg']: List(Tensor) [B, M, 4]
             outputs['strides']: List(Int) [8, 16, 32] output stride
             targets: (List) [dict{'boxes': [...], 
                                  'labels': [...], 
@@ -120,16 +140,16 @@ class Criterion(object):
             torch.distributed.all_reduce(num_fgs)
         num_fgs = (num_fgs / get_world_size()).clamp(1.0)
 
-        # ------------------ objecntness loss ------------------
+        # ------------------ Objectness loss ------------------
         loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
         loss_obj = loss_obj.sum() / num_fgs
         
-        # ------------------ classification loss ------------------
+        # ------------------ Classification loss ------------------
         cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
         loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
         loss_cls = loss_cls.sum() / num_fgs
 
-        # ------------------ regression loss ------------------
+        # ------------------ Regression loss ------------------
         box_preds_pos = box_preds.view(-1, 4)[fg_masks]
         loss_box = self.loss_bboxes(box_preds_pos, box_targets)
         loss_box = loss_box.sum() / num_fgs
@@ -146,11 +166,48 @@ class Criterion(object):
                 losses = losses
         )
 
+        # ------------------ Aux regression loss ------------------
+        loss_box_aux = None
+        if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
+            ## reg_preds
+            reg_preds = torch.cat(outputs['pred_reg'], dim=1)
+            reg_preds_pos = reg_preds.view(-1, 4)[fg_masks]
+            ## anchor tensors
+            anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
+            anchors_tensors_pos = anchors_tensors.view(-1, 2)[fg_masks]
+            ## stride tensors
+            stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
+            stride_tensors_pos = stride_tensors.view(-1, 1)[fg_masks]
+            ## aux loss
+            loss_box_aux = self.loss_bboxes_aux(reg_preds_pos, box_targets, anchors_tensors_pos, stride_tensors_pos)
+            loss_box_aux = loss_box_aux.sum() / num_fgs
+
+            losses += loss_box_aux
+
+        # Loss dict
+        if loss_box_aux is None:
+            loss_dict = dict(
+                    loss_obj = loss_obj,
+                    loss_cls = loss_cls,
+                    loss_box = loss_box,
+                    losses = losses
+            )
+        else:
+            loss_dict = dict(
+                    loss_obj = loss_obj,
+                    loss_cls = loss_cls,
+                    loss_box = loss_box,
+                    loss_box_aux = loss_box_aux,
+                    losses = losses
+                    )
+
+
         return loss_dict
     
 
-def build_criterion(cfg, device, num_classes):
+def build_criterion(args, cfg, device, num_classes):
     criterion = Criterion(
+        args=args,
         cfg=cfg,
         device=device,
         num_classes=num_classes

+ 13 - 3
models/detectors/yolox/yolox.py

@@ -205,9 +205,11 @@ class YOLOX(nn.Module):
 
             # 检测头
             all_anchors = []
+            all_strides = []
             all_obj_preds = []
             all_cls_preds = []
             all_box_preds = []
+            all_reg_preds = []
             for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
                 cls_feat, reg_feat = head(feat)
 
@@ -220,7 +222,10 @@ class YOLOX(nn.Module):
                 fmp_size = [H, W]
                 # generate anchor boxes: [M, 4]
                 anchors = self.generate_anchors(level, fmp_size)
-                
+
+                # stride tensor: [M, 1]
+                stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
+
                 # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
                 obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
                 cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
@@ -236,13 +241,18 @@ class YOLOX(nn.Module):
                 all_obj_preds.append(obj_pred)
                 all_cls_preds.append(cls_pred)
                 all_box_preds.append(box_pred)
+                all_reg_preds.append(reg_pred)
                 all_anchors.append(anchors)
+                all_strides.append(stride_tensor)
             
             # output dict
             outputs = {"pred_obj": all_obj_preds,        # List(Tensor) [B, M, 1]
                        "pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
                        "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
-                       "anchors": all_anchors,           # List(Tensor) [B, M, 2]
-                       'strides': self.stride}           # List(Int) [8, 16, 32]
+                       "pred_reg": all_reg_preds,        # List(Tensor) [B, M, 4]
+                       "anchors": all_anchors,           # List(Tensor) [M, 2]
+                       "strides": self.stride,           # List(Int) [8, 16, 32]
+                       "stride_tensors": all_strides     # List(Tensor) [M, 1]
+                       }
 
             return outputs 

+ 2 - 0
train.py

@@ -55,6 +55,8 @@ def parse_args():
                         help='warmup epoch.')
     parser.add_argument('--eval_epoch', default=10, type=int, 
                         help='after eval epoch, the model is evaluated on val dataset.')
+    parser.add_argument('--no_aug_epoch', default=20, type=int, 
+                        help='cancel strong augmentation.')
 
     # Model
     parser.add_argument('-m', '--model', default='yolov1', type=str,

+ 6 - 5
train.sh

@@ -1,14 +1,15 @@
 # Train YOLO
 python train.py \
         --cuda \
-        -d coco \
+        -d voc \
         --root /mnt/share/ssd2/dataset/ \
-        -m yolovx_l \
+        -m yolox_n \
         -bs 16 \
         -size 640 \
-        --wp_epoch 3 \
-        --max_epoch 300 \
-        --eval_epoch 10 \
+        --wp_epoch 1 \
+        --max_epoch 30 \
+        --eval_epoch 5 \
+        --no_aug_epoch 20 \
         --ema \
         --fp16 \
         --multi_scale \

+ 1 - 0
train_ddp.sh

@@ -11,6 +11,7 @@ python -m torch.distributed.run --nproc_per_node=8 train.py \
                                                     --wp_epoch 3 \
                                                     --max_epoch 300 \
                                                     --eval_epoch 10 \
+                                                    --no_aug_epoch 20 \
                                                     --ema \
                                                     --fp16 \
                                                     --sybn \