modify RT-PlainDETR's Trainer

yjh0410 1 year ago
parent
commit
6df1f9fad1

+ 2 - 2
config/model_config/rtpdetr_config.py

@@ -18,7 +18,7 @@ rtpdetr_cfg = {
         'freeze_stem_only': False,
         'hidden_dim': 256,
         'en_num_heads': 8,
-        'en_num_layers': 1,
+        'en_num_layers': 6,
         'en_mlp_ratio': 4.0,
         'en_dropout': 0.0,
         'en_act': 'gelu',
@@ -35,7 +35,7 @@ rtpdetr_cfg = {
         'proposal_feature_levels': 3,
         'proposal_tgt_strides': [8, 16, 32],
         'num_queries_one2one': 300,
-        'num_queries_one2many': 10,
+        'num_queries_one2many': 1500,
         # ---------------- Assignment config ----------------
         'matcher_hpy': {'cost_class': 2.0,
                         'cost_bbox': 1.0,
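
Raising num_queries_one2many from 10 to 1500 (together with deepening the encoder from 1 to 6 layers, mirrored in models/detectors/rtpdetr/rtpdetr.py below) brings the config in line with the H-DETR-style hybrid-matching recipe that PlainDETR adopts: a large auxiliary query group is matched one-to-many during training and dropped at inference. A minimal sketch of how one2many targets are usually built under that recipe; build_one2many_targets and k_repeat are illustrative names, not functions from this repo:

    import torch

    def build_one2many_targets(targets, k_repeat=6):
        # Repeat each ground-truth box/label k times so that several
        # queries in the one2many group can be assigned to the same object.
        return [{"boxes": t["boxes"].repeat(k_repeat, 1),
                 "labels": t["labels"].repeat(k_repeat)} for t in targets]

    tgts = [{"boxes": torch.rand(4, 4), "labels": torch.randint(0, 80, (4,))}]
    assert build_one2many_targets(tgts)[0]["boxes"].shape == (24, 4)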

+ 379 - 266
engine.py

@@ -1121,7 +1121,7 @@ class RTCTrainer(object):
         self.train_transform, self.trans_cfg = build_transform(
             args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
         self.train_loader.dataset.transform = self.train_transform
-   
+
 ## Real-time DETR Trainer
 class RTDetrTrainer(object):
     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
@@ -1135,6 +1135,7 @@ class RTDetrTrainer(object):
         self.grad_accumulate = args.grad_accumulate
         self.clip_grad = 0.1
         self.heavy_eval = False
+        self.normalize_bbox = True
         # close AMP for RT-DETR
         self.args.fp16 = False
         # weak augmentation stage
@@ -1316,7 +1317,7 @@ class RTDetrTrainer(object):
             # Visualize train targets
             if self.args.vis_tgt:
                 targets = self.box_cxcywh_to_xyxy(targets)
-                vis_data(images, targets, normalized_bbox=True,
+                vis_data(images, targets, normalized_bbox=self.normalize_bbox,
                          pixel_mean=self.trans_cfg['pixel_mean'], pixel_std=self.trans_cfg['pixel_std'])
                 targets = self.box_xyxy_to_cxcywh(targets)
 
@@ -1374,9 +1375,10 @@ class RTDetrTrainer(object):
             tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
             min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
             keep = (min_tgt_size >= min_box_size)
-            # normalize box
-            boxes[:, [0, 2]] = boxes[:, [0, 2]] / img_size
-            boxes[:, [1, 3]] = boxes[:, [1, 3]] / img_size
+            if self.normalize_bbox:
+                # normalize box
+                boxes[:, [0, 2]] = boxes[:, [0, 2]] / img_size
+                boxes[:, [1, 3]] = boxes[:, [1, 3]] / img_size
 
             tgt["boxes"] = boxes[keep]
             tgt["labels"] = labels[keep]
@@ -1415,9 +1417,10 @@ class RTDetrTrainer(object):
             tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
             min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
             keep = (min_tgt_size >= min_box_size)
-            # normalize box
-            boxes[:, [0, 2]] = boxes[:, [0, 2]] / new_img_size
-            boxes[:, [1, 3]] = boxes[:, [1, 3]] / new_img_size
+            if self.normalize_bbox:
+                # normalize box
+                boxes[:, [0, 2]] = boxes[:, [0, 2]] / new_img_size
+                boxes[:, [1, 3]] = boxes[:, [1, 3]] / new_img_size
 
             tgt["boxes"] = boxes[keep]
             tgt["labels"] = labels[keep]
@@ -1485,159 +1488,29 @@ class RTDetrTrainer(object):
         
         self.train_transform.set_weak_augment()
         self.train_loader.dataset.transform = self.train_transform
-        
+
 ## Real-time PlainDETR Trainer
-class RTPDetrTrainer(object):
+class RTPDetrTrainer(RTDetrTrainer):
     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
+        super().__init__(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
         # ------------------- Basic parameters -------------------
-        self.args = args
-        self.epoch = 0
-        self.best_map = -1.
-        self.device = device
-        self.criterion = criterion
-        self.world_size = world_size
-        self.grad_accumulate = args.grad_accumulate
-        self.clip_grad = 0.1
-        self.heavy_eval = False
-        # close AMP for RT-DETR
-        self.args.fp16 = False
-        # weak augmentatino stage
-        self.second_stage = False
-        self.second_stage_epoch = -1
-        # path to save model
-        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
-        os.makedirs(self.path_to_save, exist_ok=True)
-
-        # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
-        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.0001, 'backbone_lr_ratio': 0.1}
-        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.1, 'warmup_iters': 2000} # no lr decay
-        self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
-
-        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
-        self.data_cfg  = data_cfg
-        self.model_cfg = model_cfg
-        self.trans_cfg = trans_cfg
-
-        # ---------------------------- Build Transform ----------------------------
-        self.train_transform, self.trans_cfg = build_transform(
-            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
-        self.val_transform, _ = build_transform(
-            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
-        if self.trans_cfg["mosaic_prob"] > 0.5:
-            self.second_stage_epoch = 5
-
-        # ---------------------------- Build Dataset & Dataloader ----------------------------
-        self.dataset, self.dataset_info = build_dataset(args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
-        self.train_loader = build_dataloader(args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
-
-        # ---------------------------- Build Evaluator ----------------------------
-        self.evaluator = build_evluator(args, self.data_cfg, self.val_transform, self.device)
-
-        # ---------------------------- Build Grad. Scaler ----------------------------
-        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+        ## Reset optimizer hyper-parameters
+        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 0.05, 'lr0': 0.0002, 'backbone_lr_ratio': 0.1}
+        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 1.0, 'warmup_iters': 1000} # no lr decay
+        self.normalize_bbox = False
 
         # ---------------------------- Build Optimizer ----------------------------
+        print("- Re-build oprimizer")
         self.optimizer_dict['lr0'] *= self.args.batch_size / 16.  # auto lr scaling
         self.optimizer, self.start_epoch = build_rtdetr_optimizer(self.optimizer_dict, model, self.args.resume)
 
         # ---------------------------- Build LR Scheduler ----------------------------
+        print("- Re-build lr scheduler")
         self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.max_epoch)
         self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
         if self.args.resume and self.args.resume != 'None':
             self.lr_scheduler.step()
 
-        # ---------------------------- Build Model-EMA ----------------------------
-        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
-            print('Build ModelEMA ...')
-            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
-        else:
-            self.model_ema = None
-
-    def train(self, model):
-        for epoch in range(self.start_epoch, self.args.max_epoch):
-            if self.args.distributed:
-                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
-
-            # check second stage
-            if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
-                self.check_second_stage()
-                # save model of the last mosaic epoch
-                weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
-                checkpoint_path = os.path.join(self.path_to_save, weight_name)
-                print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
-                torch.save({'model': model.state_dict(),
-                            'mAP': round(self.evaluator.map*100, 1),
-                            'optimizer': self.optimizer.state_dict(),
-                            'epoch': self.epoch,
-                            'args': self.args}, 
-                            checkpoint_path)
-
-            # train one epoch
-            self.epoch = epoch
-            self.train_one_epoch(model)
-
-            # eval one epoch
-            if self.heavy_eval:
-                model_eval = model.module if self.args.distributed else model
-                self.eval(model_eval)
-            else:
-                model_eval = model.module if self.args.distributed else model
-                if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
-                    self.eval(model_eval)
-
-            if self.args.debug:
-                print("For debug mode, we only train 1 epoch")
-                break
-
-    def eval(self, model):
-        # chech model
-        model_eval = model if self.model_ema is None else self.model_ema.ema
-
-        if distributed_utils.is_main_process():
-            # check evaluator
-            if self.evaluator is None:
-                print('No evaluator ... save model and go on training.')
-                print('Saving state, epoch: {}'.format(self.epoch))
-                weight_name = '{}_no_eval.pth'.format(self.args.model)
-                checkpoint_path = os.path.join(self.path_to_save, weight_name)
-                torch.save({'model': model_eval.state_dict(),
-                            'mAP': -1.,
-                            'optimizer': self.optimizer.state_dict(),
-                            'epoch': self.epoch,
-                            'args': self.args}, 
-                            checkpoint_path)               
-            else:
-                print('eval ...')
-                # set eval mode
-                model_eval.eval()
-
-                # evaluate
-                with torch.no_grad():
-                    self.evaluator.evaluate(model_eval)
-
-                # save model
-                cur_map = self.evaluator.map
-                if cur_map > self.best_map:
-                    # update best-map
-                    self.best_map = cur_map
-                    # save model
-                    print('Saving state, epoch:', self.epoch)
-                    weight_name = '{}_best.pth'.format(self.args.model)
-                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
-                    torch.save({'model': model_eval.state_dict(),
-                                'mAP': round(self.best_map*100, 1),
-                                'optimizer': self.optimizer.state_dict(),
-                                'epoch': self.epoch,
-                                'args': self.args}, 
-                                checkpoint_path)                      
-
-                # set train mode.
-                model_eval.train()
-
-        if self.args.distributed:
-            # wait for all processes to synchronize
-            dist.barrier()
-
     def train_one_epoch(self, model):
         metric_logger = MetricLogger(delimiter="  ")
         metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
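
This hunk collapses roughly 250 duplicated lines into a subclass: super().__init__ builds everything with the RT-DETR defaults, then the child swaps in its own optimizer hyper-parameters (AdamW with weight decay 0.05, base lr 2e-4, a flat cosine schedule with lrf=1.0) and rebuilds the optimizer and LR scheduler, so both get constructed twice per run. A simplified sketch of that build-then-rebuild pattern, using stand-in classes rather than the real trainers:

    class BaseTrainer:
        def __init__(self):
            self.optimizer_dict = {"lr0": 1e-4, "weight_decay": 1e-4}
            self.build_optimizer()                # first build, base defaults

        def build_optimizer(self):
            print("build optimizer:", self.optimizer_dict)

    class PlainTrainer(BaseTrainer):
        def __init__(self):
            super().__init__()                    # builds with base defaults
            self.optimizer_dict = {"lr0": 2e-4, "weight_decay": 0.05}
            self.build_optimizer()                # re-build with overrides

    PlainTrainer()   # logs two builds, mirroring the "- Re-build optimizer" print
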
@@ -1685,18 +1558,12 @@ class RTPDetrTrainer(object):
                 outputs = model(images)
                 # Compute loss
                 loss_dict = self.criterion(outputs, targets)
-                loss_weight_dict = self.criterion.weight_dict
-                losses = sum(loss_dict[k] * loss_weight_dict[k] for k in loss_dict.keys() if k in loss_weight_dict)
-
+                losses = sum(loss_dict.values())
                 # Grad Accumulate
                 if self.grad_accumulate > 1:
                     losses /= self.grad_accumulate
 
-                # Reduce losses over all GPUs for logging purposes
                 loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
-                loss_dict_reduced_scaled = {k: v * loss_weight_dict[k] for k, v in loss_dict_reduced.items() if k in loss_weight_dict}
-                losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
-                loss_value = losses_reduced_scaled.item()
 
             # Backward
             self.scaler.scale(losses).backward()
@@ -1718,7 +1585,7 @@ class RTPDetrTrainer(object):
                     self.model_ema.update(model)
 
             # Update log
-            metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled)
+            metric_logger.update(loss=losses.item(), **loss_dict_reduced)
             metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
             metric_logger.update(grad_norm=grad_norm)
             metric_logger.update(size=img_size)
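
The trainer can now sum loss_dict directly because the criterion pre-multiplies each term by its weight (see the loss.py hunk below), so the old reduce-and-rescale logging block becomes a plain reduce. A before/after sketch with dummy numbers showing the two totals agree:

    import torch

    loss_dict   = {"loss_cls": torch.tensor(0.8), "loss_box": torch.tensor(0.2)}
    weight_dict = {"loss_cls": 2.0, "loss_box": 5.0}

    # Old: the trainer applied the weights itself
    old_total = sum(loss_dict[k] * weight_dict[k] for k in loss_dict)

    # New: the criterion returns pre-weighted terms; the trainer just sums
    pre_weighted = {k: v * weight_dict[k] for k, v in loss_dict.items()}
    new_total = sum(pre_weighted.values())

    assert torch.isclose(old_total, new_total)   # 0.8*2.0 + 0.2*5.0 = 2.6
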
@@ -1731,120 +1598,366 @@ class RTPDetrTrainer(object):
         if not self.second_stage:
             self.lr_scheduler.step()
         
-    def refine_targets(self, img_size, targets, min_box_size):
-        # rescale targets
-        for tgt in targets:
-            boxes = tgt["boxes"].clone()
-            labels = tgt["labels"].clone()
-            # refine tgt
-            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
-            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
-            keep = (min_tgt_size >= min_box_size)
 
-            tgt["boxes"] = boxes[keep]
-            tgt["labels"] = labels[keep]
+# ## Real-time PlainDETR Trainer
+# class RTPDetrTrainer(object):
+#     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
+#         # ------------------- Basic parameters -------------------
+#         self.args = args
+#         self.epoch = 0
+#         self.best_map = -1.
+#         self.device = device
+#         self.criterion = criterion
+#         self.world_size = world_size
+#         self.grad_accumulate = args.grad_accumulate
+#         self.clip_grad = 0.1
+#         self.heavy_eval = False
+#         # close AMP for RT-DETR
+#         self.args.fp16 = False
+#         # weak augmentatino stage
+#         self.second_stage = False
+#         self.second_stage_epoch = -1
+#         # path to save model
+#         self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
+#         os.makedirs(self.path_to_save, exist_ok=True)
+
+#         # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
+#         self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.0001, 'backbone_lr_ratio': 0.1}
+#         self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.1, 'warmup_iters': 2000} # no lr decay
+#         self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
+
+#         # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
+#         self.data_cfg  = data_cfg
+#         self.model_cfg = model_cfg
+#         self.trans_cfg = trans_cfg
+
+#         # ---------------------------- Build Transform ----------------------------
+#         self.train_transform, self.trans_cfg = build_transform(
+#             args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+#         self.val_transform, _ = build_transform(
+#             args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
+#         if self.trans_cfg["mosaic_prob"] > 0.5:
+#             self.second_stage_epoch = 5
+
+#         # ---------------------------- Build Dataset & Dataloader ----------------------------
+#         self.dataset, self.dataset_info = build_dataset(args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
+#         self.train_loader = build_dataloader(args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
+
+#         # ---------------------------- Build Evaluator ----------------------------
+#         self.evaluator = build_evluator(args, self.data_cfg, self.val_transform, self.device)
+
+#         # ---------------------------- Build Grad. Scaler ----------------------------
+#         self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+
+#         # ---------------------------- Build Optimizer ----------------------------
+#         self.optimizer_dict['lr0'] *= self.args.batch_size / 16.  # auto lr scaling
+#         self.optimizer, self.start_epoch = build_rtdetr_optimizer(self.optimizer_dict, model, self.args.resume)
+
+#         # ---------------------------- Build LR Scheduler ----------------------------
+#         self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.max_epoch)
+#         self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
+#         if self.args.resume and self.args.resume != 'None':
+#             self.lr_scheduler.step()
+
+#         # ---------------------------- Build Model-EMA ----------------------------
+#         if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
+#             print('Build ModelEMA ...')
+#             self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
+#         else:
+#             self.model_ema = None
+
+#     def train(self, model):
+#         for epoch in range(self.start_epoch, self.args.max_epoch):
+#             if self.args.distributed:
+#                 self.train_loader.batch_sampler.sampler.set_epoch(epoch)
+
+#             # check second stage
+#             if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
+#                 self.check_second_stage()
+#                 # save model of the last mosaic epoch
+#                 weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
+#                 checkpoint_path = os.path.join(self.path_to_save, weight_name)
+#                 print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
+#                 torch.save({'model': model.state_dict(),
+#                             'mAP': round(self.evaluator.map*100, 1),
+#                             'optimizer': self.optimizer.state_dict(),
+#                             'epoch': self.epoch,
+#                             'args': self.args}, 
+#                             checkpoint_path)
+
+#             # train one epoch
+#             self.epoch = epoch
+#             self.train_one_epoch(model)
+
+#             # eval one epoch
+#             if self.heavy_eval:
+#                 model_eval = model.module if self.args.distributed else model
+#                 self.eval(model_eval)
+#             else:
+#                 model_eval = model.module if self.args.distributed else model
+#                 if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
+#                     self.eval(model_eval)
+
+#             if self.args.debug:
+#                 print("For debug mode, we only train 1 epoch")
+#                 break
+
+#     def eval(self, model):
+#         # chech model
+#         model_eval = model if self.model_ema is None else self.model_ema.ema
+
+#         if distributed_utils.is_main_process():
+#             # check evaluator
+#             if self.evaluator is None:
+#                 print('No evaluator ... save model and go on training.')
+#                 print('Saving state, epoch: {}'.format(self.epoch))
+#                 weight_name = '{}_no_eval.pth'.format(self.args.model)
+#                 checkpoint_path = os.path.join(self.path_to_save, weight_name)
+#                 torch.save({'model': model_eval.state_dict(),
+#                             'mAP': -1.,
+#                             'optimizer': self.optimizer.state_dict(),
+#                             'epoch': self.epoch,
+#                             'args': self.args}, 
+#                             checkpoint_path)               
+#             else:
+#                 print('eval ...')
+#                 # set eval mode
+#                 model_eval.eval()
+
+#                 # evaluate
+#                 with torch.no_grad():
+#                     self.evaluator.evaluate(model_eval)
+
+#                 # save model
+#                 cur_map = self.evaluator.map
+#                 if cur_map > self.best_map:
+#                     # update best-map
+#                     self.best_map = cur_map
+#                     # save model
+#                     print('Saving state, epoch:', self.epoch)
+#                     weight_name = '{}_best.pth'.format(self.args.model)
+#                     checkpoint_path = os.path.join(self.path_to_save, weight_name)
+#                     torch.save({'model': model_eval.state_dict(),
+#                                 'mAP': round(self.best_map*100, 1),
+#                                 'optimizer': self.optimizer.state_dict(),
+#                                 'epoch': self.epoch,
+#                                 'args': self.args}, 
+#                                 checkpoint_path)                      
+
+#                 # set train mode.
+#                 model_eval.train()
+
+#         if self.args.distributed:
+#             # wait for all processes to synchronize
+#             dist.barrier()
+
+#     def train_one_epoch(self, model):
+#         metric_logger = MetricLogger(delimiter="  ")
+#         metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
+#         metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
+#         metric_logger.add_meter('grad_norm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
+#         header = 'Epoch: [{} / {}]'.format(self.epoch, self.args.max_epoch)
+#         epoch_size = len(self.train_loader)
+#         print_freq = 10
+
+#         # basic parameters
+#         epoch_size = len(self.train_loader)
+#         img_size = self.args.img_size
+#         nw = self.lr_schedule_dict['warmup_iters']
+
+#         # Train one epoch
+#         for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
+#             ni = iter_i + self.epoch * epoch_size
+#             # Warmup
+#             if ni <= nw:
+#                 xi = [0, nw]  # x interp
+#                 for x in self.optimizer.param_groups:
+#                     x['lr'] = np.interp(ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
+                                
+#             # To device
+#             images = images.to(self.device, non_blocking=True).float()
+
+#             # Multi scale
+#             if self.args.multi_scale:
+#                 images, targets, img_size = self.rescale_image_targets(
+#                     images, targets, self.model_cfg['max_stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
+#             else:
+#                 targets = self.refine_targets(targets, self.args.min_box_size)
+
+#             # xyxy -> cxcywh
+#             targets = self.box_xyxy_to_cxcywh(targets)
+                
+#             # Visualize train targets
+#             if self.args.vis_tgt:
+#                 targets = self.box_cxcywh_to_xyxy(targets)
+#                 vis_data(images, targets, pixel_mean=self.trans_cfg['pixel_mean'], pixel_std=self.trans_cfg['pixel_std'])
+#                 targets = self.box_xyxy_to_cxcywh(targets)
+
+#             # Inference
+#             with torch.cuda.amp.autocast(enabled=self.args.fp16):
+#                 outputs = model(images)
+#                 # Compute loss
+#                 loss_dict = self.criterion(outputs, targets)
+#                 loss_weight_dict = self.criterion.weight_dict
+#                 losses = sum(loss_dict[k] * loss_weight_dict[k] for k in loss_dict.keys() if k in loss_weight_dict)
+
+#                 # Grad Accumulate
+#                 if self.grad_accumulate > 1:
+#                     losses /= self.grad_accumulate
+
+#                 # Reduce losses over all GPUs for logging purposes
+#                 loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
+#                 loss_dict_reduced_scaled = {k: v * loss_weight_dict[k] for k, v in loss_dict_reduced.items() if k in loss_weight_dict}
+#                 losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
+#                 loss_value = losses_reduced_scaled.item()
+
+#             # Backward
+#             self.scaler.scale(losses).backward()
+
+#             # Optimize
+#             if ni % self.grad_accumulate == 0:
+#                 grad_norm = None
+#                 if self.clip_grad > 0:
+#                     # unscale gradients
+#                     self.scaler.unscale_(self.optimizer)
+#                     # clip gradients
+#                     grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
+#                 # optimizer.step
+#                 self.scaler.step(self.optimizer)
+#                 self.scaler.update()
+#                 self.optimizer.zero_grad()
+#                 # ema
+#                 if self.model_ema is not None:
+#                     self.model_ema.update(model)
+
+#             # Update log
+#             metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled)
+#             metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
+#             metric_logger.update(grad_norm=grad_norm)
+#             metric_logger.update(size=img_size)
+
+#             if self.args.debug:
+#                 print("For debug mode, we only train 1 iteration")
+#                 break
+
+#         # LR Schedule
+#         if not self.second_stage:
+#             self.lr_scheduler.step()
         
-        return targets
-
-    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
-        """
-            Deployed for Multi scale trick.
-        """
-        if isinstance(stride, int):
-            max_stride = stride
-        elif isinstance(stride, list):
-            max_stride = max(stride)
-
-        # During training phase, the shape of input image is square.
-        old_img_size = images.shape[-1]
-        new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
-        new_img_size = new_img_size // max_stride * max_stride  # size
-        if new_img_size / old_img_size != 1:
-            # interpolate
-            images = torch.nn.functional.interpolate(
-                                input=images, 
-                                size=new_img_size, 
-                                mode='bilinear', 
-                                align_corners=False)
-        # rescale targets
-        for tgt in targets:
-            boxes = tgt["boxes"].clone()
-            labels = tgt["labels"].clone()
-            boxes = torch.clamp(boxes, 0, old_img_size)
-            # rescale box
-            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
-            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
-            # refine tgt
-            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
-            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
-            keep = (min_tgt_size >= min_box_size)
-
-            tgt["boxes"] = boxes[keep]
-            tgt["labels"] = labels[keep]
-
-        return images, targets, new_img_size
-
-    def box_xyxy_to_cxcywh(self, targets):
-        # rescale targets
-        for tgt in targets:
-            boxes_xyxy = tgt["boxes"].clone()
-            # rescale box
-            cxcy = (boxes_xyxy[..., :2] + boxes_xyxy[..., 2:]) * 0.5
-            bwbh = boxes_xyxy[..., 2:] - boxes_xyxy[..., :2]
-            boxes_bwbh = torch.cat([cxcy, bwbh], dim=-1)
-
-            tgt["boxes"] = boxes_bwbh
-
-        return targets
-
-    def box_cxcywh_to_xyxy(self, targets):
-        # rescale targets
-        for tgt in targets:
-            boxes_cxcywh = tgt["boxes"].clone()
-            # rescale box
-            x1y1 = boxes_cxcywh[..., :2] - boxes_cxcywh[..., 2:] * 0.5
-            x2y2 = boxes_cxcywh[..., :2] + boxes_cxcywh[..., 2:] * 0.5
-            boxes_bwbh = torch.cat([x1y1, x2y2], dim=-1)
-
-            tgt["boxes"] = boxes_bwbh
-
-        return targets
-
-    def check_second_stage(self):
-        # set second stage
-        print('============== Second stage of Training ==============')
-        self.second_stage = True
-
-        # close mosaic augmentation
-        if self.train_loader.dataset.mosaic_prob > 0.:
-            print(' - Close < Mosaic Augmentation > ...')
-            self.train_loader.dataset.mosaic_prob = 0.
-            self.heavy_eval = True
-
-        # close mixup augmentation
-        if self.train_loader.dataset.mixup_prob > 0.:
-            print(' - Close < Mixup Augmentation > ...')
-            self.train_loader.dataset.mixup_prob = 0.
-            self.heavy_eval = True
-
-        # close rotation augmentation
-        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
-            print(' - Close < degress of rotation > ...')
-            self.trans_cfg['degrees'] = 0.0
-        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
-            print(' - Close < shear of rotation >...')
-            self.trans_cfg['shear'] = 0.0
-        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
-            print(' - Close < perspective of rotation > ...')
-            self.trans_cfg['perspective'] = 0.0
-
-        # build a new transform for second stage
-        print(' - Rebuild transforms ...')
-        self.train_transform, self.trans_cfg = build_transform(
-            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+#     def refine_targets(self, targets, min_box_size):
+#         # rescale targets
+#         for tgt in targets:
+#             boxes = tgt["boxes"].clone()
+#             labels = tgt["labels"].clone()
+#             # refine tgt
+#             tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+#             min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+#             keep = (min_tgt_size >= min_box_size)
+
+#             tgt["boxes"] = boxes[keep]
+#             tgt["labels"] = labels[keep]
         
-        self.train_transform.set_weak_augment()
-        self.train_loader.dataset.transform = self.train_transform
+#         return targets
+
+#     def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
+#         """
+#             Deployed for Multi scale trick.
+#         """
+#         if isinstance(stride, int):
+#             max_stride = stride
+#         elif isinstance(stride, list):
+#             max_stride = max(stride)
+
+#         # During training phase, the shape of input image is square.
+#         old_img_size = images.shape[-1]
+#         new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
+#         new_img_size = new_img_size // max_stride * max_stride  # size
+#         if new_img_size / old_img_size != 1:
+#             # interpolate
+#             images = torch.nn.functional.interpolate(
+#                                 input=images, 
+#                                 size=new_img_size, 
+#                                 mode='bilinear', 
+#                                 align_corners=False)
+#         # rescale targets
+#         for tgt in targets:
+#             boxes = tgt["boxes"].clone()
+#             labels = tgt["labels"].clone()
+#             boxes = torch.clamp(boxes, 0, old_img_size)
+#             # rescale box
+#             boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
+#             boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
+#             # refine tgt
+#             tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+#             min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+#             keep = (min_tgt_size >= min_box_size)
+
+#             tgt["boxes"] = boxes[keep]
+#             tgt["labels"] = labels[keep]
+
+#         return images, targets, new_img_size
+
+#     def box_xyxy_to_cxcywh(self, targets):
+#         # rescale targets
+#         for tgt in targets:
+#             boxes_xyxy = tgt["boxes"].clone()
+#             # rescale box
+#             cxcy = (boxes_xyxy[..., :2] + boxes_xyxy[..., 2:]) * 0.5
+#             bwbh = boxes_xyxy[..., 2:] - boxes_xyxy[..., :2]
+#             boxes_bwbh = torch.cat([cxcy, bwbh], dim=-1)
+
+#             tgt["boxes"] = boxes_bwbh
+
+#         return targets
+
+#     def box_cxcywh_to_xyxy(self, targets):
+#         # rescale targets
+#         for tgt in targets:
+#             boxes_cxcywh = tgt["boxes"].clone()
+#             # rescale box
+#             x1y1 = boxes_cxcywh[..., :2] - boxes_cxcywh[..., 2:] * 0.5
+#             x2y2 = boxes_cxcywh[..., :2] + boxes_cxcywh[..., 2:] * 0.5
+#             boxes_bwbh = torch.cat([x1y1, x2y2], dim=-1)
+
+#             tgt["boxes"] = boxes_bwbh
+
+#         return targets
+
+#     def check_second_stage(self):
+#         # set second stage
+#         print('============== Second stage of Training ==============')
+#         self.second_stage = True
+
+#         # close mosaic augmentation
+#         if self.train_loader.dataset.mosaic_prob > 0.:
+#             print(' - Close < Mosaic Augmentation > ...')
+#             self.train_loader.dataset.mosaic_prob = 0.
+#             self.heavy_eval = True
+
+#         # close mixup augmentation
+#         if self.train_loader.dataset.mixup_prob > 0.:
+#             print(' - Close < Mixup Augmentation > ...')
+#             self.train_loader.dataset.mixup_prob = 0.
+#             self.heavy_eval = True
+
+#         # close rotation augmentation
+#         if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
+#             print(' - Close < degress of rotation > ...')
+#             self.trans_cfg['degrees'] = 0.0
+#         if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
+#             print(' - Close < shear of rotation >...')
+#             self.trans_cfg['shear'] = 0.0
+#         if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
+#             print(' - Close < perspective of rotation > ...')
+#             self.trans_cfg['perspective'] = 0.0
+
+#         # build a new transform for second stage
+#         print(' - Rebuild transforms ...')
+#         self.train_transform, self.trans_cfg = build_transform(
+#             args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        
+#         self.train_transform.set_weak_augment()
+#         self.train_loader.dataset.transform = self.train_transform
         
 
 # ----------------------- Det + Seg trainers -----------------------

+ 7 - 20
models/detectors/rtpdetr/loss.py

@@ -42,17 +42,6 @@ class Criterion(nn.Module):
         self.weight_dict = {'loss_cls':  cfg['loss_coeff']['class'],
                             'loss_box':  cfg['loss_coeff']['bbox'],
                             'loss_giou': cfg['loss_coeff']['giou']}
-        if aux_loss:
-            aux_weight_dict = {}
-            for i in range(cfg['de_num_layers'] - 1):
-                aux_weight_dict.update({k + f'_{i}': v for k, v in self.weight_dict.items()})
-            self.weight_dict.update(aux_weight_dict)
-        # ------------- One2many loss weight -------------
-        if cfg['num_queries_one2many'] > 0:
-            one2many_loss_weight = {}
-            for k, v in self.weight_dict.items():
-                one2many_loss_weight[k+"_one2many"] = v
-            self.weight_dict.update(one2many_loss_weight)
 
     def _get_src_permutation_idx(self, indices):
         # permute predictions following indices
@@ -163,9 +152,9 @@ class Criterion(nn.Module):
         losses = {}
         for loss in self.losses:
             kwargs = {}
-            losses.update(
-                self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs)
-            )
+            l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs)
+            l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
+            losses.update(l_dict)
 
         # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
         if "aux_outputs" in outputs:
@@ -173,9 +162,8 @@ class Criterion(nn.Module):
                 indices = self.matcher(aux_outputs, targets)
                 for loss in self.losses:
                     kwargs = {}
-                    l_dict = self.get_loss(
-                        loss, aux_outputs, targets, indices, num_boxes, **kwargs
-                    )
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
+                    l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
                     l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                     losses.update(l_dict)
 
@@ -187,9 +175,8 @@ class Criterion(nn.Module):
             indices = self.matcher(enc_outputs, bin_targets)
             for loss in self.losses:
                 kwargs = {}
-                l_dict = self.get_loss(
-                    loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs
-                )
+                l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
+                l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
                 l_dict = {k + "_enc": v for k, v in l_dict.items()}
                 losses.update(l_dict)
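
Weighting each component inside the criterion, before the "_{i}" / "_enc" (and "_one2many") suffixes are appended, is what makes the deleted constructor code unnecessary: weight_dict no longer has to enumerate every suffixed key. A condensed sketch of the difference, with dummy values:

    weight_dict = {"loss_cls": 2.0, "loss_box": 5.0}
    raw = {"loss_cls": 0.8, "loss_box": 0.2}

    # New order: weight first, then suffix freely per aux / encoder branch
    weighted = {k: v * weight_dict[k] for k, v in raw.items() if k in weight_dict}
    aux_0 = {k + "_0": v for k, v in weighted.items()}     # no 'loss_cls_0' weight needed
    enc   = {k + "_enc": v for k, v in weighted.items()}   # no 'loss_cls_enc' weight needed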
 

+ 1 - 1
models/detectors/rtpdetr/rtpdetr.py

@@ -362,7 +362,7 @@ if __name__ == '__main__':
         'freeze_stem_only': False,
         'hidden_dim': 256,
         'en_num_heads': 8,
-        'en_num_layers': 1,
+        'en_num_layers': 6,
         'en_mlp_ratio': 4.0,
         'en_dropout': 0.0,
         'en_act': 'gelu',

+ 2 - 2
models/detectors/rtpdetr/rtpdetr_decoder.py

@@ -270,8 +270,8 @@ class PlainDETRTransformer(nn.Module):
 
         # Prepare input for decoder
         memory = src_flatten
-        bs, _, c = memory.shape
-       
+        bs, seq_l, c = memory.shape
+
         # Two stage trick
         if self.training:
             self.two_stage_num_proposals = self.num_queries_one2one + self.num_queries_one2many
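
Renaming the unpacked dimension to seq_l makes the flattened-memory sequence length explicit, and the two-stage branch right below sizes the proposal set per mode: with this commit's config that is 300 + 1500 = 1800 queries during training, and presumably just the 300 one2one queries at inference (the else branch falls outside this hunk). A small sketch of that mode switch, using the commit's numbers:

    num_queries_one2one, num_queries_one2many = 300, 1500

    def two_stage_num_proposals(training: bool) -> int:
        # Training keeps the auxiliary one2many group; inference drops it.
        return num_queries_one2one + (num_queries_one2many if training else 0)

    assert two_stage_num_proposals(training=True)  == 1800
    assert two_stage_num_proposals(training=False) == 300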