فهرست منبع

use AdamW to train my YOLOv8

yjh0410 2 سال پیش
والد
کامیت
7ac6145b19
2فایلهای تغییر یافته به همراه2 افزوده شده و 9 حذف شده
  1. 1 8
      engine.py
  2. 1 1
      train_multi_gpus.sh

+ 1 - 8
engine.py

@@ -815,7 +815,6 @@ class RTCTrainer(object):
         else:
             self.model_ema = None
 
-
     def train(self, model):
         for epoch in range(self.start_epoch, self.args.max_epoch):
             if self.args.distributed:
@@ -862,7 +861,6 @@ class RTCTrainer(object):
                 if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                     self.eval(model_eval)
 
-
     def eval(self, model):
         # chech model
         model_eval = model if self.model_ema is None else self.model_ema.ema
@@ -914,7 +912,6 @@ class RTCTrainer(object):
             # wait for all processes to synchronize
             dist.barrier()
 
-
     def train_one_epoch(self, model):
         # basic parameters
         epoch_size = len(self.train_loader)
@@ -1007,7 +1004,6 @@ class RTCTrainer(object):
         # LR Schedule
         if not self.second_stage:
             self.lr_scheduler.step()
-        
 
     def refine_targets(self, targets, min_box_size):
         # rescale targets
@@ -1024,7 +1020,6 @@ class RTCTrainer(object):
         
         return targets
 
-
     def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
         """
             Deployed for Multi scale trick.
@@ -1063,7 +1058,6 @@ class RTCTrainer(object):
 
         return images, targets, new_img_size
 
-
     def check_second_stage(self):
         # set second stage
         print('============== Second stage of Training ==============')
@@ -1098,7 +1092,6 @@ class RTCTrainer(object):
             args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
         self.train_loader.dataset.transform = self.train_transform
         
-
     def check_third_stage(self):
         # set third stage
         print('============== Third stage of Training ==============')
@@ -1117,7 +1110,7 @@ class RTCTrainer(object):
         self.train_transform, self.trans_cfg = build_transform(
             args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
         self.train_loader.dataset.transform = self.train_transform
-        
+   
 
 # RTRDet Trainer
 class RTRTrainer(object):

+ 1 - 1
train_multi_gpus.sh

@@ -30,7 +30,7 @@ python -m torch.distributed.run --nproc_per_node=8 train.py \
                                                     --wp_epoch 3 \
                                                     --max_epoch 500 \
                                                     --eval_epoch 10 \
-                                                    --no_aug_epoch 10 \
+                                                    --no_aug_epoch 20 \
                                                     --ema \
                                                     --fp16 \
                                                     --sybn \