yjh0410 2 лет назад
Родитель
Сommit
78e820fb04
2 измененных файлов с 69 добавлено и 48 удалено
  1. 39 18
      engine.py
  2. 30 30
      train_multi_gpus.sh

+ 39 - 18
engine.py

@@ -29,14 +29,17 @@ class Yolov8Trainer(object):
         self.args = args
         self.epoch = 0
         self.best_map = -1.
-        self.last_opt_step = 0
-        self.no_aug_epoch = args.no_aug_epoch
-        self.clip_grad = 10
         self.device = device
         self.criterion = criterion
         self.world_size = world_size
         self.heavy_eval = False
+        self.last_opt_step = 0
+        self.clip_grad = 10
+        # weak augmentatino stage
         self.second_stage = False
+        self.third_stage = False
+        self.second_stage_epoch = args.no_aug_epoch
+        self.third_stage_epoch = args.no_aug_epoch // 2
         # path to save model
         self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
         os.makedirs(self.path_to_save, exist_ok=True)
@@ -87,26 +90,38 @@ class Yolov8Trainer(object):
         else:
             self.model_ema = None
 
-
     def train(self, model):
         for epoch in range(self.start_epoch, self.args.max_epoch):
             if self.args.distributed:
                 self.train_loader.batch_sampler.sampler.set_epoch(epoch)
 
             # check second stage
-            if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
+            if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
                 self.check_second_stage()
                 # save model of the last mosaic epoch
                 weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
                 checkpoint_path = os.path.join(self.path_to_save, weight_name)
-                if not os.path.exists(checkpoint_path):
-                    print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
-                    torch.save({'model': model.state_dict(),
-                                'mAP': round(self.evaluator.map*100, 1),
-                                'optimizer': self.optimizer.state_dict(),
-                                'epoch': self.epoch,
-                                'args': self.args}, 
-                                checkpoint_path)                      
+                print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
+                torch.save({'model': model.state_dict(),
+                            'mAP': round(self.evaluator.map*100, 1),
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)
+
+            # check third stage
+            if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
+                self.check_third_stage()
+                # save model of the last mosaic epoch
+                weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                print('Saving state of the last weak augment epoch-{}.'.format(self.epoch + 1))
+                torch.save({'model': model.state_dict(),
+                            'mAP': round(self.evaluator.map*100, 1),
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)
 
             # train one epoch
             self.epoch = epoch
@@ -121,7 +136,6 @@ class Yolov8Trainer(object):
                 if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                     self.eval(model_eval)
 
-
     def eval(self, model):
         # chech model
         model_eval = model if self.model_ema is None else self.model_ema.ema
@@ -173,7 +187,6 @@ class Yolov8Trainer(object):
             # wait for all processes to synchronize
             dist.barrier()
 
-
     def train_one_epoch(self, model):
         # basic parameters
         epoch_size = len(self.train_loader)
@@ -266,7 +279,6 @@ class Yolov8Trainer(object):
         
         self.lr_scheduler.step()
         
-
     def check_second_stage(self):
         # set second stage
         print('============== Second stage of Training ==============')
@@ -295,6 +307,17 @@ class Yolov8Trainer(object):
             print(' - Close < perspective of rotation > ...')
             self.trans_cfg['perspective'] = 0.0
 
+        # build a new transform for second stage
+        print(' - Rebuild transforms ...')
+        self.train_transform, self.trans_cfg = build_transform(
+            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.train_loader.dataset.transform = self.train_transform
+        
+    def check_third_stage(self):
+        # set third stage
+        print('============== Third stage of Training ==============')
+        self.third_stage = True
+
         # close random affine
         if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
             print(' - Close < translate of affine > ...')
@@ -309,7 +332,6 @@ class Yolov8Trainer(object):
             args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
         self.train_loader.dataset.transform = self.train_transform
         
-
     def refine_targets(self, targets, min_box_size):
         # rescale targets
         for tgt in targets:
@@ -325,7 +347,6 @@ class Yolov8Trainer(object):
         
         return targets
 
-
     def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
         """
             Deployed for Multi scale trick.

+ 30 - 30
train_multi_gpus.sh

@@ -1,60 +1,60 @@
 # -------------------------- Train RTCDet --------------------------
-python -m torch.distributed.run --nproc_per_node=8 train.py \
-                                                    --cuda \
-                                                    -dist \
-                                                    -d coco \
-                                                    --root /data/datasets/ \
-                                                    -m rtcdet_s \
-                                                    -bs 128 \
-                                                    -size 640 \
-                                                    --wp_epoch 3 \
-                                                    --max_epoch 300 \
-                                                    --eval_epoch 10 \
-                                                    --no_aug_epoch 20 \
-                                                    --ema \
-                                                    --fp16 \
-                                                    --sybn \
-                                                    --multi_scale \
-                                                    --save_folder weights/ \
-                                                    # --load_cache \
-                                                    # --resume weights/coco/yolox_l/yolox_l_best.pth \
-
-# -------------------------- Train YOLOX & YOLOv7 --------------------------
 # python -m torch.distributed.run --nproc_per_node=8 train.py \
 #                                                     --cuda \
 #                                                     -dist \
 #                                                     -d coco \
 #                                                     --root /data/datasets/ \
-#                                                     -m rtcdet_n \
-#                                                     -bs 64 \
+#                                                     -m rtcdet_s \
+#                                                     -bs 128 \
 #                                                     -size 640 \
 #                                                     --wp_epoch 3 \
 #                                                     --max_epoch 300 \
 #                                                     --eval_epoch 10 \
-#                                                     --no_aug_epoch 15 \
+#                                                     --no_aug_epoch 20 \
 #                                                     --ema \
 #                                                     --fp16 \
 #                                                     --sybn \
 #                                                     --multi_scale \
+#                                                     --save_folder weights/ \
 #                                                     # --load_cache \
 #                                                     # --resume weights/coco/yolox_l/yolox_l_best.pth \
 
-# -------------------------- Train YOLOv1~v5 --------------------------
+# -------------------------- Train YOLOX & YOLOv7 --------------------------
 # python -m torch.distributed.run --nproc_per_node=8 train.py \
 #                                                     --cuda \
 #                                                     -dist \
 #                                                     -d coco \
 #                                                     --root /data/datasets/ \
-#                                                     -m yolov5_l\
-#                                                     -bs 128 \
+#                                                     -m rtcdet_n \
+#                                                     -bs 64 \
 #                                                     -size 640 \
 #                                                     --wp_epoch 3 \
 #                                                     --max_epoch 300 \
 #                                                     --eval_epoch 10 \
-#                                                     --no_aug_epoch 10 \
+#                                                     --no_aug_epoch 15 \
 #                                                     --ema \
 #                                                     --fp16 \
 #                                                     --sybn \
 #                                                     --multi_scale \
-#                                                     # --load_cache
-#                                                     # --resume weights/coco/yolov5_l/yolov5_l_best.pth \
+#                                                     # --load_cache \
+#                                                     # --resume weights/coco/yolox_l/yolox_l_best.pth \
+
+# -------------------------- Train YOLOv1~v5 --------------------------
+python -m torch.distributed.run --nproc_per_node=8 train.py \
+                                                    --cuda \
+                                                    -dist \
+                                                    -d coco \
+                                                    --root /data/datasets/ \
+                                                    -m yolov8_n\
+                                                    -bs 128 \
+                                                    -size 640 \
+                                                    --wp_epoch 3 \
+                                                    --max_epoch 500 \
+                                                    --eval_epoch 10 \
+                                                    --no_aug_epoch 10 \
+                                                    --ema \
+                                                    --fp16 \
+                                                    --sybn \
+                                                    --multi_scale \
+                                                    # --load_cache
+                                                    # --resume weights/coco/yolov5_l/yolov5_l_best.pth \