|
@@ -2,6 +2,7 @@ import torch
|
|
|
import torch.distributed as dist
|
|
import torch.distributed as dist
|
|
|
|
|
|
|
|
import os
|
|
import os
|
|
|
|
|
+import numpy as np
|
|
|
import random
|
|
import random
|
|
|
|
|
|
|
|
# ----------------- Extra Components -----------------
|
|
# ----------------- Extra Components -----------------
|
|
@@ -11,7 +12,7 @@ from utils.vis_tools import vis_data
|
|
|
|
|
|
|
|
# ----------------- Optimizer & LrScheduler Components -----------------
|
|
# ----------------- Optimizer & LrScheduler Components -----------------
|
|
|
from utils.solver.optimizer import build_yolo_optimizer, build_rtdetr_optimizer
|
|
from utils.solver.optimizer import build_yolo_optimizer, build_rtdetr_optimizer
|
|
|
-from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler
|
|
|
|
|
|
|
+from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler, build_lambda_lr_scheduler
|
|
|
|
|
|
|
|
|
|
|
|
|
class YoloTrainer(object):
|
|
class YoloTrainer(object):
|
|
@@ -67,9 +68,10 @@ class YoloTrainer(object):
|
|
|
self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
|
|
self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
|
|
|
|
|
|
|
|
# ---------------------------- Build LR Scheduler ----------------------------
|
|
# ---------------------------- Build LR Scheduler ----------------------------
|
|
|
- warmup_iters = cfg.warmup_epoch * len(self.train_loader)
|
|
|
|
|
- self.lr_scheduler_warmup = LinearWarmUpLrScheduler(warmup_iters, cfg.base_lr, cfg.warmup_bias_lr, cfg.warmup_momentum)
|
|
|
|
|
- self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
|
|
|
|
|
|
|
+ self.lr_scheduler, self.lf = build_lambda_lr_scheduler(cfg, self.optimizer, cfg.max_epoch)
|
|
|
|
|
+ self.lr_scheduler.last_epoch = self.start_epoch - 1 # do not move
|
|
|
|
|
+ if self.args.resume and self.args.resume != 'None':
|
|
|
|
|
+ self.lr_scheduler.step()
|
|
|
|
|
|
|
|
def train(self, model):
|
|
def train(self, model):
|
|
|
for epoch in range(self.start_epoch, self.cfg.max_epoch):
|
|
for epoch in range(self.start_epoch, self.cfg.max_epoch):
|
|
@@ -95,8 +97,7 @@ class YoloTrainer(object):
|
|
|
self.train_one_epoch(model)
|
|
self.train_one_epoch(model)
|
|
|
|
|
|
|
|
# LR Schedule
|
|
# LR Schedule
|
|
|
- if (epoch + 1) > self.cfg.warmup_epoch:
|
|
|
|
|
- self.lr_scheduler.step()
|
|
|
|
|
|
|
+ self.lr_scheduler.step()
|
|
|
|
|
|
|
|
# eval one epoch
|
|
# eval one epoch
|
|
|
if self.heavy_eval:
|
|
if self.heavy_eval:
|
|
@@ -145,7 +146,6 @@ class YoloTrainer(object):
|
|
|
'model': model_eval.state_dict(),
|
|
'model': model_eval.state_dict(),
|
|
|
'mAP': round(cur_map*100, 1),
|
|
'mAP': round(cur_map*100, 1),
|
|
|
'optimizer': self.optimizer.state_dict(),
|
|
'optimizer': self.optimizer.state_dict(),
|
|
|
- 'lr_scheduler': self.lr_scheduler.state_dict(),
|
|
|
|
|
'epoch': self.epoch,
|
|
'epoch': self.epoch,
|
|
|
'args': self.args,
|
|
'args': self.args,
|
|
|
}
|
|
}
|
|
@@ -177,11 +177,14 @@ class YoloTrainer(object):
|
|
|
for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
|
|
for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
|
|
|
ni = iter_i + self.epoch * epoch_size
|
|
ni = iter_i + self.epoch * epoch_size
|
|
|
# Warmup
|
|
# Warmup
|
|
|
- if nw > 0 and ni < nw:
|
|
|
|
|
- self.lr_scheduler_warmup(ni, self.optimizer)
|
|
|
|
|
- elif ni == nw:
|
|
|
|
|
- print("Warmup stage is over.")
|
|
|
|
|
- self.lr_scheduler_warmup.set_lr(self.optimizer, self.cfg.base_lr)
|
|
|
|
|
|
|
+ if ni <= nw:
|
|
|
|
|
+ xi = [0, nw] # x interp
|
|
|
|
|
+ for j, x in enumerate(self.optimizer.param_groups):
|
|
|
|
|
+ # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
|
|
|
|
+ x['lr'] = np.interp(
|
|
|
|
|
+ ni, xi, [self.cfg.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
|
|
|
|
|
+ if 'momentum' in x:
|
|
|
|
|
+ x['momentum'] = np.interp(ni, xi, [self.cfg.warmup_momentum, self.cfg.momentum])
|
|
|
|
|
|
|
|
# To device
|
|
# To device
|
|
|
images = images.to(self.device, non_blocking=True).float()
|
|
images = images.to(self.device, non_blocking=True).float()
|