yjh0410 vor 2 Jahren
Ursprung
Commit
55f473645f
2 geänderte Dateien mit 29 neuen und 7 gelöschten Zeilen
  1. 25 7
      engine.py
  2. 4 0
      train.py

+ 25 - 7
engine.py

@@ -136,6 +136,10 @@ class Yolov8Trainer(object):
                 if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                     self.eval(model_eval)
 
+            if self.args.debug:
+                print("For debug mode, we only train 1 epoch")
+                break
+
     def eval(self, model):
         # chech model
         model_eval = model if self.model_ema is None else self.model_ema.ema
@@ -277,6 +281,10 @@ class Yolov8Trainer(object):
                 
                 t0 = time.time()
         
+            if self.args.debug:
+                print("For debug mode, we only train 1 iteration")
+                break
+
         self.lr_scheduler.step()
         
     def check_second_stage(self):
@@ -499,6 +507,9 @@ class YoloxTrainer(object):
                 if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                     self.eval(model_eval)
 
+            if self.args.debug:
+                print("For debug mode, we only train 1 epoch")
+                break
 
     def eval(self, model):
         # chech model
@@ -551,7 +562,6 @@ class YoloxTrainer(object):
             # wait for all processes to synchronize
             dist.barrier()
 
-
     def train_one_epoch(self, model):
         # basic parameters
         epoch_size = len(self.train_loader)
@@ -633,12 +643,15 @@ class YoloxTrainer(object):
                 print(log, flush=True)
                 
                 t0 = time.time()
-        
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 iteration")
+                break
+
         # LR Schedule
         if not self.second_stage:
             self.lr_scheduler.step()
         
-
     def check_second_stage(self):
         # set second stage
         print('============== Second stage of Training ==============')
@@ -673,7 +686,6 @@ class YoloxTrainer(object):
             args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
         self.train_loader.dataset.transform = self.train_transform
         
-
     def check_third_stage(self):
         # set third stage
         print('============== Third stage of Training ==============')
@@ -693,7 +705,6 @@ class YoloxTrainer(object):
             args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
         self.train_loader.dataset.transform = self.train_transform
         
-
     def refine_targets(self, targets, min_box_size):
         # rescale targets
         for tgt in targets:
@@ -709,7 +720,6 @@ class YoloxTrainer(object):
         
         return targets
 
-
     def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
         """
             Deployed for Multi scale trick.
@@ -861,6 +871,10 @@ class RTCTrainer(object):
                 if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                     self.eval(model_eval)
 
+            if self.args.debug:
+                print("For debug mode, we only train 1 epoch")
+                break
+
     def eval(self, model):
         # chech model
         model_eval = model if self.model_ema is None else self.model_ema.ema
@@ -1000,7 +1014,11 @@ class RTCTrainer(object):
                 print(log, flush=True)
                 
                 t0 = time.time()
-        
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 iteration")
+                break
+
         # LR Schedule
         if not self.second_stage:
             self.lr_scheduler.step()

+ 4 - 0
train.py

@@ -106,6 +106,10 @@ def parse_args():
                         help='number of distributed processes')
     parser.add_argument('--sybn', action='store_true', default=False, 
                         help='use sybn.')
+    
+    # Debug mode
+    parser.add_argument('--debug', action='store_true', default=False, 
+                        help='debug mode.')
 
     return parser.parse_args()