|
|
@@ -29,14 +29,17 @@ class Yolov8Trainer(object):
|
|
|
self.args = args
|
|
|
self.epoch = 0
|
|
|
self.best_map = -1.
|
|
|
- self.last_opt_step = 0
|
|
|
- self.no_aug_epoch = args.no_aug_epoch
|
|
|
- self.clip_grad = 10
|
|
|
self.device = device
|
|
|
self.criterion = criterion
|
|
|
self.world_size = world_size
|
|
|
self.heavy_eval = False
|
|
|
+ self.last_opt_step = 0
|
|
|
+ self.clip_grad = 10
|
|
|
+ # weak augmentatino stage
|
|
|
self.second_stage = False
|
|
|
+ self.third_stage = False
|
|
|
+ self.second_stage_epoch = args.no_aug_epoch
|
|
|
+ self.third_stage_epoch = args.no_aug_epoch // 2
|
|
|
# path to save model
|
|
|
self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
|
|
|
os.makedirs(self.path_to_save, exist_ok=True)
|
|
|
@@ -87,26 +90,38 @@ class Yolov8Trainer(object):
|
|
|
else:
|
|
|
self.model_ema = None
|
|
|
|
|
|
-
|
|
|
def train(self, model):
|
|
|
for epoch in range(self.start_epoch, self.args.max_epoch):
|
|
|
if self.args.distributed:
|
|
|
self.train_loader.batch_sampler.sampler.set_epoch(epoch)
|
|
|
|
|
|
# check second stage
|
|
|
- if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
|
|
|
+ if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
|
|
|
self.check_second_stage()
|
|
|
# save model of the last mosaic epoch
|
|
|
weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
|
|
|
checkpoint_path = os.path.join(self.path_to_save, weight_name)
|
|
|
- if not os.path.exists(checkpoint_path):
|
|
|
- print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
|
|
|
- torch.save({'model': model.state_dict(),
|
|
|
- 'mAP': round(self.evaluator.map*100, 1),
|
|
|
- 'optimizer': self.optimizer.state_dict(),
|
|
|
- 'epoch': self.epoch,
|
|
|
- 'args': self.args},
|
|
|
- checkpoint_path)
|
|
|
+ print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
|
|
|
+ torch.save({'model': model.state_dict(),
|
|
|
+ 'mAP': round(self.evaluator.map*100, 1),
|
|
|
+ 'optimizer': self.optimizer.state_dict(),
|
|
|
+ 'epoch': self.epoch,
|
|
|
+ 'args': self.args},
|
|
|
+ checkpoint_path)
|
|
|
+
|
|
|
+ # check third stage
|
|
|
+ if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
|
|
|
+ self.check_third_stage()
|
|
|
+ # save model of the last mosaic epoch
|
|
|
+ weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
|
|
|
+ checkpoint_path = os.path.join(self.path_to_save, weight_name)
|
|
|
+ print('Saving state of the last weak augment epoch-{}.'.format(self.epoch + 1))
|
|
|
+ torch.save({'model': model.state_dict(),
|
|
|
+ 'mAP': round(self.evaluator.map*100, 1),
|
|
|
+ 'optimizer': self.optimizer.state_dict(),
|
|
|
+ 'epoch': self.epoch,
|
|
|
+ 'args': self.args},
|
|
|
+ checkpoint_path)
|
|
|
|
|
|
# train one epoch
|
|
|
self.epoch = epoch
|
|
|
@@ -121,7 +136,6 @@ class Yolov8Trainer(object):
|
|
|
if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
|
|
|
self.eval(model_eval)
|
|
|
|
|
|
-
|
|
|
def eval(self, model):
|
|
|
# chech model
|
|
|
model_eval = model if self.model_ema is None else self.model_ema.ema
|
|
|
@@ -173,7 +187,6 @@ class Yolov8Trainer(object):
|
|
|
# wait for all processes to synchronize
|
|
|
dist.barrier()
|
|
|
|
|
|
-
|
|
|
def train_one_epoch(self, model):
|
|
|
# basic parameters
|
|
|
epoch_size = len(self.train_loader)
|
|
|
@@ -266,7 +279,6 @@ class Yolov8Trainer(object):
|
|
|
|
|
|
self.lr_scheduler.step()
|
|
|
|
|
|
-
|
|
|
def check_second_stage(self):
|
|
|
# set second stage
|
|
|
print('============== Second stage of Training ==============')
|
|
|
@@ -295,6 +307,17 @@ class Yolov8Trainer(object):
|
|
|
print(' - Close < perspective of rotation > ...')
|
|
|
self.trans_cfg['perspective'] = 0.0
|
|
|
|
|
|
+ # build a new transform for second stage
|
|
|
+ print(' - Rebuild transforms ...')
|
|
|
+ self.train_transform, self.trans_cfg = build_transform(
|
|
|
+ args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
|
|
|
+ self.train_loader.dataset.transform = self.train_transform
|
|
|
+
|
|
|
+ def check_third_stage(self):
|
|
|
+ # set third stage
|
|
|
+ print('============== Third stage of Training ==============')
|
|
|
+ self.third_stage = True
|
|
|
+
|
|
|
# close random affine
|
|
|
if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
|
|
|
print(' - Close < translate of affine > ...')
|
|
|
@@ -309,7 +332,6 @@ class Yolov8Trainer(object):
|
|
|
args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
|
|
|
self.train_loader.dataset.transform = self.train_transform
|
|
|
|
|
|
-
|
|
|
def refine_targets(self, targets, min_box_size):
|
|
|
# rescale targets
|
|
|
for tgt in targets:
|
|
|
@@ -325,7 +347,6 @@ class Yolov8Trainer(object):
|
|
|
|
|
|
return targets
|
|
|
|
|
|
-
|
|
|
def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
|
|
|
"""
|
|
|
Deployed for Multi scale trick.
|