
modify RTPDetrTrainer

yjh0410 1 year ago
parent
commit
d09e797abd
1 changed file with 20 additions and 16 deletions

engine.py  +20 −16

@@ -1494,8 +1494,9 @@ class RTPDetrTrainer(RTDetrTrainer):
         super().__init__(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
         # ------------------- Basic parameters -------------------
         ## Reset optimizer hyper-parameters
-        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 0.05, 'lr0': 0.0002, 'backbone_lr_ratio': 0.1}
-        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 1.0, 'warmup_iters': 1000}
+        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 0.0001, 'lr0': 0.0001, 'backbone_lr_ratio': 0.1}
+        self.warmup_dict = {'warmup': 'linear', 'warmup_iters': 2000, 'warmup_factor': 0.00066667}
+        self.lr_schedule_dict = {'lr_scheduler': 'step', 'lr_epoch': [self.args.max_epoch // 12 * 11]}
         self.normalize_bbox = False
 
         # ---------------------------- Build Optimizer ----------------------------
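Note: the replaced dicts switch RT-PDETR to AdamW with weight decay 1e-4, base lr 1e-4, a backbone lr ratio of 0.1, a 2000-iteration linear warmup, and a single step decay at 11/12 of max_epoch. A minimal sketch of how optimizer_dict could drive the optimizer rebuild, assuming a hypothetical build_rtdetr_optimizer helper (not the repo's actual function):

# Sketch only: an assumed param-group split driven by optimizer_dict.
# 'build_rtdetr_optimizer' is a hypothetical name, not the repo's helper.
import torch

def build_rtdetr_optimizer(model, optimizer_dict):
    base_lr = optimizer_dict['lr0']
    backbone_lr = base_lr * optimizer_dict['backbone_lr_ratio']
    backbone_params, other_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (backbone_params if 'backbone' in name else other_params).append(param)
    assert optimizer_dict['optimizer'] == 'adamw'
    return torch.optim.AdamW(
        [{'params': backbone_params, 'lr': backbone_lr},
         {'params': other_params,   'lr': base_lr}],
        lr=base_lr, weight_decay=optimizer_dict['weight_decay'])
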
@@ -1505,10 +1506,8 @@ class RTPDetrTrainer(RTDetrTrainer):
 
         # ---------------------------- Build LR Scheduler ----------------------------
         print("- Re-build lr scheduler -")
-        self.lr_scheduler, self.lf = build_lambda_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.max_epoch)
-        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
-        if self.args.resume and self.args.resume != 'None':
-            self.lr_scheduler.step()
+        self.wp_lr_scheduler = build_wp_lr_scheduler(self.warmup_dict, self.optimizer_dict['lr0'])
+        self.lr_scheduler    = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.resume)
 
     def train_one_epoch(self, model):
         metric_logger = MetricLogger(delimiter="  ")
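Note: the diff calls wp_lr_scheduler(ni, self.optimizer) during warmup and wp_lr_scheduler.set_lr(self.optimizer, lr0, lr0) once warmup ends. A minimal sketch of a linear warmup object compatible with those two call sites, assuming warmup_factor 0.00066667 ≈ 1/1500 is the starting lr multiplier (this is an assumed shape, not the repo's build_wp_lr_scheduler):

# Sketch only: a linear warmup helper matching the call sites in this diff.
class LinearWarmUpScheduler:
    def __init__(self, base_lr, warmup_iters=2000, warmup_factor=0.00066667):
        self.base_lr = base_lr
        self.warmup_iters = warmup_iters
        self.warmup_factor = warmup_factor

    def set_lr(self, optimizer, lr, base_lr):
        # Scale every param group relative to its own initial lr so the
        # backbone keeps its lr ratio during and after warmup.
        for group in optimizer.param_groups:
            init_lr = group.get('initial_lr', base_lr)
            group['lr'] = lr * init_lr / base_lr

    def __call__(self, iter_i, optimizer):
        # Linearly ramp the lr from base_lr * warmup_factor up to base_lr.
        alpha = iter_i / self.warmup_iters
        factor = self.warmup_factor * (1.0 - alpha) + alpha
        self.set_lr(optimizer, self.base_lr * factor, self.base_lr)

def build_wp_lr_scheduler(warmup_dict, base_lr):
    assert warmup_dict['warmup'] == 'linear'
    return LinearWarmUpScheduler(base_lr,
                                 warmup_dict['warmup_iters'],
                                 warmup_dict['warmup_factor'])

Calling set_lr(optimizer, lr0, lr0) after the last warmup iteration restores each group's lr to its initial value, which matches the `ni == nw` branch in this hunk.
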
@@ -1522,19 +1521,25 @@ class RTPDetrTrainer(RTDetrTrainer):
         # basic parameters
         epoch_size = len(self.train_loader)
         img_size = self.args.img_size
-        nw = self.lr_schedule_dict['warmup_iters']
+        nw = self.warmup_dict['warmup_iters']
+        lr_warmup_stage = True
 
         # Train one epoch
         for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
             ni = iter_i + self.epoch * epoch_size
-            # Warmup
-            if ni <= nw:
-                xi = [0, nw]  # x interp
-                for x in self.optimizer.param_groups:
-                    x['lr'] = np.interp(ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
-                                
+            # WarmUp
+            if ni < nw and lr_warmup_stage:
+                self.wp_lr_scheduler(ni, self.optimizer)
+            elif ni == nw and lr_warmup_stage:
+                print('Warmup stage is over.')
+                lr_warmup_stage = False
+                self.wp_lr_scheduler.set_lr(self.optimizer, self.optimizer_dict['lr0'], self.optimizer_dict['lr0'])
+                                            
             # To device
             images = images.to(self.device, non_blocking=True).float()
+            for tgt in targets:
+                tgt['boxes'] = tgt['boxes'].to(self.device)
+                tgt['labels'] = tgt['labels'].to(self.device)
 
             # Multi scale
             if self.args.multi_scale:
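Note: lr_schedule_dict requests a 'step' schedule with one milestone at max_epoch // 12 * 11. A minimal sketch of build_lr_scheduler under that assumption, using torch's MultiStepLR with an assumed gamma of 0.1 (the repo's actual helper and its resume handling may differ):

# Sketch only: an assumed epoch-level step scheduler for
# {'lr_scheduler': 'step', 'lr_epoch': [max_epoch // 12 * 11]}.
import torch

def build_lr_scheduler(lr_schedule_dict, optimizer, resume=None):
    assert lr_schedule_dict['lr_scheduler'] == 'step'
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=lr_schedule_dict['lr_epoch'], gamma=0.1)
    # When resuming, the scheduler state would normally be restored from the
    # checkpoint; this sketch only shows the construction.
    return lr_scheduler
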
@@ -1585,7 +1590,7 @@ class RTPDetrTrainer(RTDetrTrainer):
 
             # Update log
             metric_logger.update(loss=losses.item(), **loss_dict_reduced)
-            metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
+            metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])
             metric_logger.update(grad_norm=grad_norm)
             metric_logger.update(size=img_size)
 
@@ -1594,8 +1599,7 @@ class RTPDetrTrainer(RTDetrTrainer):
                 break
 
         # LR Schedule
-        if not self.second_stage:
-            self.lr_scheduler.step()
+        self.lr_scheduler.step()
         
 
 # Build Trainer