Browse Source

design a more general Trainer
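
The free functions train_one_epoch / val_one_epoch in engine.py are folded into a
single Trainer class that owns the optimizer, LR scheduler, EMA, criterion, and grad
scaler, so train.py only builds the pieces and hands them over. A minimal sketch of
the new call pattern, assuming the surrounding objects are built as in the updated
train.py below:

    from engine import Trainer

    trainer = Trainer(args, device, model_cfg, model_ema, optimizer,
                      lf, lr_scheduler, criterion, scaler)
    optimizer.zero_grad()
    for epoch in range(start_epoch, args.max_epoch):
        trainer.train_one_epoch(model, train_loader)   # also steps the scheduler and trainer.epoch
        if epoch % args.eval_epoch == 0 or epoch == args.max_epoch - 1:
            trainer.eval_one_epoch(model_without_ddp, evaluator)   # tracks trainer.best_map, saves checkpoints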

yjh0410 2 years ago
parent
commit
3e40e32989
2 changed files with 287 additions and 321 deletions
  1. engine.py (+232 -225)
  2. train.py (+55 -96)

+ 232 - 225
engine.py

@@ -3,7 +3,6 @@ import torch.distributed as dist
 
 import time
 import os
-import math
 import numpy as np
 import random
 
@@ -11,232 +10,240 @@ from utils import distributed_utils
 from utils.vis_tools import vis_data
 
 
-def refine_targets(targets, min_box_size):
-    # rescale targets
-    for tgt in targets:
-        boxes = tgt["boxes"].clone()
-        labels = tgt["labels"].clone()
-        # refine tgt
-        tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
-        min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
-        keep = (min_tgt_size >= min_box_size)
-
-        tgt["boxes"] = boxes[keep]
-        tgt["labels"] = labels[keep]
-    
-    return targets
-
-
-def rescale_image_targets(images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
-    """
-        Deployed for Multi scale trick.
-    """
-    if isinstance(stride, int):
-        max_stride = stride
-    elif isinstance(stride, list):
-        max_stride = max(stride)
-
-    # During training phase, the shape of input image is square.
-    old_img_size = images.shape[-1]
-    new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
-    new_img_size = new_img_size // max_stride * max_stride  # size
-    if new_img_size / old_img_size != 1:
-        # interpolate
-        images = torch.nn.functional.interpolate(
-                            input=images, 
-                            size=new_img_size, 
-                            mode='bilinear', 
-                            align_corners=False)
-    # rescale targets
-    for tgt in targets:
-        boxes = tgt["boxes"].clone()
-        labels = tgt["labels"].clone()
-        boxes = torch.clamp(boxes, 0, old_img_size)
-        # rescale box
-        boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
-        boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
-        # refine tgt
-        tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
-        min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
-        keep = (min_tgt_size >= min_box_size)
-
-        tgt["boxes"] = boxes[keep]
-        tgt["labels"] = labels[keep]
-
-    return images, targets, new_img_size
-
-
-def train_one_epoch(epoch,
-                    total_epochs,
-                    args, 
-                    device, 
-                    ema,
-                    model,
-                    criterion,
-                    cfg, 
-                    dataloader, 
-                    optimizer,
-                    scheduler,
-                    lf,
-                    scaler,
-                    last_opt_step):
-    epoch_size = len(dataloader)
-    img_size = args.img_size
-    t0 = time.time()
-    nw = epoch_size * args.wp_epoch
-    accumulate = accumulate = max(1, round(64 / args.batch_size))
-
-    # train one epoch
-    for iter_i, (images, targets) in enumerate(dataloader):
-        ni = iter_i + epoch * epoch_size
-        # Warmup
-        if ni <= nw:
-            xi = [0, nw]  # x interp
-            accumulate = max(1, np.interp(ni, xi, [1, 64 / args.batch_size]).round())
-            for j, x in enumerate(optimizer.param_groups):
-                # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
-                x['lr'] = np.interp(
-                    ni, xi, [cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
-                if 'momentum' in x:
-                    x['momentum'] = np.interp(ni, xi, [cfg['warmup_momentum'], cfg['momentum']])
-                            
-        # to device
-        images = images.to(device, non_blocking=True).float() / 255.
-
-        # multi scale
-        if args.multi_scale:
-            images, targets, img_size = rescale_image_targets(
-                images, targets, model.stride, args.min_box_size, cfg['multi_scale'])
-        else:
-            targets = refine_targets(targets, args.min_box_size)
-            
-        # visualize train targets
-        if args.vis_tgt:
-            vis_data(images*255, targets)
-
-        # inference
-        with torch.cuda.amp.autocast(enabled=args.fp16):
-            outputs = model(images)
-            # loss
-            loss_dict = criterion(outputs, targets, epoch)
-            losses = loss_dict['losses']
-            losses *= images.shape[0]  # loss * bs
-
-            # reduce            
-            loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
-
-            if args.distributed:
-                # gradient averaged between devices in DDP mode
-                losses *= distributed_utils.get_world_size()
-
-        # check loss
-        try:
-            if torch.isnan(losses):
-                print('loss is NAN !!')
-                continue
-        except:
-            print(loss_dict)
-
-        # backward
-        scaler.scale(losses).backward()
-
-        # Optimize
-        if ni - last_opt_step >= accumulate:
-            if cfg['clip_grad'] > 0:
-                # unscale gradients
-                scaler.unscale_(optimizer)
-                # clip gradients
-                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg['clip_grad'])
-            # optimizer.step
-            scaler.step(optimizer)
-            scaler.update()
-            optimizer.zero_grad()
-            # ema
-            if ema:
-                ema.update(model)
-            last_opt_step = ni
-
-        # display
-        if distributed_utils.is_main_process() and iter_i % 10 == 0:
-            t1 = time.time()
-            cur_lr = [param_group['lr']  for param_group in optimizer.param_groups]
-            # basic infor
-            log =  '[Epoch: {}/{}]'.format(epoch+1, total_epochs)
-            log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
-            log += '[lr: {:.6f}]'.format(cur_lr[2])
-            # loss infor
-            for k in loss_dict_reduced.keys():
-                if k == 'losses' and args.distributed:
-                    world_size = distributed_utils.get_world_size()
-                    log += '[{}: {:.2f}]'.format(k, loss_dict[k] / world_size)
-                else:
-                    log += '[{}: {:.2f}]'.format(k, loss_dict[k])
-
-            # other infor
-            log += '[time: {:.2f}]'.format(t1 - t0)
-            log += '[size: {}]'.format(img_size)
-
-            # print log infor
-            print(log, flush=True)
-            
-            t0 = time.time()
-    
-    scheduler.step()
-
-    return last_opt_step
-
-
-def val_one_epoch(args, 
-                  model, 
-                  evaluator,
-                  optimizer,
-                  epoch,
-                  best_map,
-                  path_to_save):
-    if distributed_utils.is_main_process():
-        # check evaluator
-        if evaluator is None:
-            print('No evaluator ... save model and go on training.')
-            print('Saving state, epoch: {}'.format(epoch + 1))
-            weight_name = '{}_no_eval.pth'.format(args.model)
-            checkpoint_path = os.path.join(path_to_save, weight_name)
-            torch.save({'model': model.state_dict(),
-                        'mAP': -1.,
-                        'optimizer': optimizer.state_dict(),
-                        'epoch': epoch,
-                        'args': args}, 
-                        checkpoint_path)                      
-            
-        else:
-            print('eval ...')
-            # set eval mode
-            model.trainable = False
-            model.eval()
-
-            # evaluate
-            evaluator.evaluate(model)
-
-            cur_map = evaluator.map
-            if cur_map > best_map:
-                # update best-map
-                best_map = cur_map
-                # save model
-                print('Saving state, epoch:', epoch + 1)
-                weight_name = '{}_best.pth'.format(args.model)
+
+class Trainer(object):
+    def __init__(self, args, device, cfg, model_ema, optimizer, lf, lr_scheduler, criterion, scaler):
+        # ------------------- basic parameters -------------------
+        self.args = args
+        self.cfg = cfg
+        self.device = device
+        self.epoch = 0
+        self.best_map = -1.
+        # ------------------- core modules -------------------
+        self.model_ema = model_ema
+        self.optimizer = optimizer
+        self.lf = lf
+        self.lr_scheduler = lr_scheduler
+        self.criterion = criterion
+        self.scaler = scaler
+        self.last_opt_step = 0
+
+
+    def train_one_epoch(self, model, train_loader):
+        # basic parameters
+        epoch_size = len(train_loader)
+        img_size = self.args.img_size
+        t0 = time.time()
+        nw = epoch_size * self.args.wp_epoch
+        accumulate = max(1, round(64 / self.args.batch_size))
+
+        # train one epoch
+        for iter_i, (images, targets) in enumerate(train_loader):
+            ni = iter_i + self.epoch * epoch_size
+            # Warmup
+            if ni <= nw:
+                xi = [0, nw]  # x interp
+                accumulate = max(1, np.interp(ni, xi, [1, 64 / self.args.batch_size]).round())
+                for j, x in enumerate(self.optimizer.param_groups):
+                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                    x['lr'] = np.interp(
+                        ni, xi, [self.cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
+                    if 'momentum' in x:
+                        x['momentum'] = np.interp(ni, xi, [self.cfg['warmup_momentum'], self.cfg['momentum']])
+                                
+            # to device
+            images = images.to(self.device, non_blocking=True).float() / 255.
+
+            # multi scale
+            if self.args.multi_scale:
+                images, targets, img_size = self.rescale_image_targets(
+                    images, targets, model.stride, self.args.min_box_size, self.cfg['multi_scale'])
+            else:
+                targets = self.refine_targets(targets, self.args.min_box_size)
+                
+            # visualize train targets
+            if self.args.vis_tgt:
+                vis_data(images*255, targets)
+
+            # inference
+            with torch.cuda.amp.autocast(enabled=self.args.fp16):
+                outputs = model(images)
+                # loss
+                loss_dict = self.criterion(outputs=outputs, targets=targets)
+                losses = loss_dict['losses']
+                losses *= images.shape[0]  # loss * bs
+
+                # reduce            
+                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
+
+                if self.args.distributed:
+                    # gradient averaged between devices in DDP mode
+                    losses *= distributed_utils.get_world_size()
+
+            # check loss
+            try:
+                if torch.isnan(losses):
+                    print('loss is NaN !!')
+                    continue
+            except Exception:
+                print(loss_dict)
+
+            # backward
+            self.scaler.scale(losses).backward()
+
+            # Optimize
+            if ni - self.last_opt_step >= accumulate:
+                if self.cfg['clip_grad'] > 0:
+                    # unscale gradients
+                    self.scaler.unscale_(self.optimizer)
+                    # clip gradients
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg['clip_grad'])
+                # optimizer.step
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+                self.optimizer.zero_grad()
+                # ema
+                if self.model_ema is not None:
+                    self.model_ema.update(model)
+                self.last_opt_step = ni
+
+            # display
+            if distributed_utils.is_main_process() and iter_i % 10 == 0:
+                t1 = time.time()
+                cur_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]
+                # basic info
+                log =  '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
+                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
+                log += '[lr: {:.6f}]'.format(cur_lr[2])
+                # loss info
+                for k in loss_dict_reduced.keys():
+                    if k == 'losses' and self.args.distributed:
+                        world_size = distributed_utils.get_world_size()
+                        log += '[{}: {:.2f}]'.format(k, loss_dict[k] / world_size)
+                    else:
+                        log += '[{}: {:.2f}]'.format(k, loss_dict[k])
+
+                # other info
+                log += '[time: {:.2f}]'.format(t1 - t0)
+                log += '[size: {}]'.format(img_size)
+
+                # print log info
+                print(log, flush=True)
+                
+                t0 = time.time()
+        
+        self.lr_scheduler.step()
+        self.epoch += 1
+        
+
+    @torch.no_grad()
+    def eval_one_epoch(self, model, evaluator):
+        # check model
+        model_eval = model if self.model_ema is None else self.model_ema.ema
+
+        # path to save model
+        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
+        os.makedirs(path_to_save, exist_ok=True)
+
+        if distributed_utils.is_main_process():
+            # check evaluator
+            if evaluator is None:
+                print('No evaluator ... save model and go on training.')
+                print('Saving state, epoch: {}'.format(self.epoch + 1))
+                weight_name = '{}_no_eval.pth'.format(self.args.model)
                 checkpoint_path = os.path.join(path_to_save, weight_name)
-                torch.save({'model': model.state_dict(),
-                            'mAP': round(best_map*100, 1),
-                            'optimizer': optimizer.state_dict(),
-                            'epoch': epoch,
-                            'args': args}, 
+                torch.save({'model': model_eval.state_dict(),
+                            'mAP': -1.,
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
                             checkpoint_path)                      
+                
+            else:
+                print('eval ...')
+                # set eval mode
+                model_eval.trainable = False
+                model_eval.eval()
 
-            # set train mode.
-            model.trainable = True
-            model.train()
+                # evaluate
+                evaluator.evaluate(model_eval)
 
-    if args.distributed:
-        # wait for all processes to synchronize
-        dist.barrier()
+                # save model
+                cur_map = evaluator.map
+                if cur_map > self.best_map:
+                    # update best-map
+                    self.best_map = cur_map
+                    # save model
+                    print('Saving state, epoch:', self.epoch + 1)
+                    weight_name = '{}_best.pth'.format(self.args.model)
+                    checkpoint_path = os.path.join(path_to_save, weight_name)
+                    torch.save({'model': model_eval.state_dict(),
+                                'mAP': round(self.best_map*100, 1),
+                                'optimizer': self.optimizer.state_dict(),
+                                'epoch': self.epoch,
+                                'args': self.args}, 
+                                checkpoint_path)                      
+
+                # set train mode.
+                model_eval.trainable = True
+                model_eval.train()
+
+        if self.args.distributed:
+            # wait for all processes to synchronize
+            dist.barrier()
+
+
+    def refine_targets(self, targets, min_box_size):
+        # filter out boxes smaller than min_box_size
+        for tgt in targets:
+            boxes = tgt["boxes"].clone()
+            labels = tgt["labels"].clone()
+            # refine tgt
+            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= min_box_size)
+
+            tgt["boxes"] = boxes[keep]
+            tgt["labels"] = labels[keep]
+        
+        return targets
+
+
+    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
+        """
+            Deployed for the multi-scale trick.
+        """
+        if isinstance(stride, int):
+            max_stride = stride
+        elif isinstance(stride, list):
+            max_stride = max(stride)
+
+        # During the training phase, the input images are square.
+        old_img_size = images.shape[-1]
+        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]), int(old_img_size * multi_scale_range[1]) + max_stride)
+        new_img_size = new_img_size // max_stride * max_stride  # snap down to a multiple of max_stride
+        if new_img_size / old_img_size != 1:
+            # interpolate
+            images = torch.nn.functional.interpolate(
+                                input=images, 
+                                size=new_img_size, 
+                                mode='bilinear', 
+                                align_corners=False)
+        # rescale targets
+        for tgt in targets:
+            boxes = tgt["boxes"].clone()
+            labels = tgt["labels"].clone()
+            boxes = torch.clamp(boxes, 0, old_img_size)
+            # rescale box
+            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
+            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
+            # refine tgt
+            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= min_box_size)
+
+            tgt["boxes"] = boxes[keep]
+            tgt["labels"] = labels[keep]
+
+        return images, targets, new_img_size
 
-    return best_map
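
For reference, the size selection inside rescale_image_targets draws a random size
from [0.5x, 1.5x] of the current input and snaps it down to a multiple of the model's
max stride. A standalone sketch with assumed example values (640 input, stride 32,
neither taken from a specific config in this repo):

    import random

    old_img_size = 640   # assumed square training size
    max_stride = 32      # assumed model max stride
    lo, hi = 0.5, 1.5    # cfg['multi_scale'] range
    new_img_size = random.randrange(int(old_img_size * lo), int(old_img_size * hi) + max_stride)
    new_img_size = new_img_size // max_stride * max_stride  # snap down to a stride multiple
    assert new_img_size % max_stride == 0 and 320 <= new_img_size <= 960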

+ 55 - 96
train.py

@@ -21,7 +21,6 @@ from evaluator.build import build_evluator
 from utils.solver.optimizer import build_optimizer
 from utils.solver.lr_scheduler import build_lr_scheduler
 
-from engine import train_one_epoch, val_one_epoch
 # ----------------- Config Components -----------------
 from config import build_dataset_config, build_model_config, build_trans_config
 
@@ -31,6 +30,9 @@ from dataset.build import build_dataset, build_transform
 # ----------------- Model Components -----------------
 from models.detectors import build_model
 
+# ----------------- Train Components -----------------
+from engine import Trainer
+
 
 def parse_args():
     parser = argparse.ArgumentParser(description='YOLO-Tutorial')
@@ -116,7 +118,7 @@ def train():
     print("Setting Arguments.. : ", args)
     print("----------------------------------------------------------")
 
-    # dist
+    # ---------------------------- Build DDP ----------------------------
     world_size = distributed_utils.get_world_size()
     per_gpu_batch = args.batch_size // world_size
     print('World size: {}'.format(world_size))
@@ -124,11 +126,7 @@ def train():
         distributed_utils.init_distributed_mode(args)
         print("git:\n  {}\n".format(distributed_utils.get_sha()))
 
-    # path to save model
-    path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
-    os.makedirs(path_to_save, exist_ok=True)
-
-    # cuda
+    # ---------------------------- Build CUDA ----------------------------
     if args.cuda:
         print('use cuda')
         # cudnn.benchmark = True
@@ -136,157 +134,118 @@ def train():
     else:
         device = torch.device("cpu")
 
-    # Dataset & Model & Trans Config
+    # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
     data_cfg = build_dataset_config(args)
     model_cfg = build_model_config(args)
     trans_cfg = build_trans_config(model_cfg['trans_type'])
 
-    # Transform
+    # ---------------------------- Build Transform ----------------------------
     train_transform, trans_cfg = build_transform(
         args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
     val_transform, _ = build_transform(
         args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
 
-    # Dataset
+    # ---------------------------- Build Dataset & Dataloader ----------------------------
     dataset, dataset_info = build_dataset(args, data_cfg, trans_cfg, train_transform, is_train=True)
+    train_loader = build_dataloader(args, dataset, per_gpu_batch, CollateFunc())
 
-    # Dataloader
-    dataloader = build_dataloader(args, dataset, per_gpu_batch, CollateFunc())
-
-    # Evaluator
+    # ---------------------------- Build Evaluator ----------------------------
     evaluator = build_evluator(args, data_cfg, val_transform, device)
 
-    # Build model
-    model, criterion = build_model(
-        args=args, 
-        model_cfg=model_cfg,
-        device=device,
-        num_classes=dataset_info['num_classes'],
-        trainable=True,
-        )
+    # ---------------------------- Build Model ----------------------------
+    model, criterion = build_model(args, model_cfg, device, dataset_info['num_classes'], True)
     model = model.to(device).train()
+    if args.sybn and args.distributed:
+        print('use SyncBatchNorm ...')
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
-    # DDP
+    # ---------------------------- Build DDP Model ----------------------------
     model_without_ddp = model
     if args.distributed:
         model = DDP(model, device_ids=[args.gpu])
         model_without_ddp = model.module
 
-    # SyncBatchNorm
-    if args.sybn and args.distributed:
-        print('use SyncBatchNorm ...')
-        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-
-    # compute FLOPs and Params
+    # ---------------------------- Calculate Params & GFLOPs ----------------------------
     if distributed_utils.is_main_process():
         model_copy = deepcopy(model_without_ddp)
         model_copy.trainable = False
         model_copy.eval()
-        compute_flops(model=model_copy, 
-                         img_size=args.img_size, 
-                         device=device)
+        compute_flops(model=model_copy,
+                      img_size=args.img_size,
+                      device=device)
         del model_copy
     if args.distributed:
         # wait for all processes to synchronize
         dist.barrier()
 
-    # amp
+    # ---------------------------- Build Grad. Scaler ----------------------------
     scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
 
-    # batch size
-    total_bs = args.batch_size
-    accumulate = max(1, round(64 / total_bs))
+    # ---------------------------- Build Optimizer ----------------------------
+    accumulate = max(1, round(64 / args.batch_size))
     print('Grad_Accumulate: ', accumulate)
-
-    # optimizer
-    model_cfg['weight_decay'] *= total_bs * accumulate / 64
+    model_cfg['weight_decay'] *= args.batch_size * accumulate / 64
     optimizer, start_epoch = build_optimizer(model_cfg, model_without_ddp, model_cfg['lr0'], args.resume)
 
-    # Scheduler
-    total_epochs = args.max_epoch + args.wp_epoch
-    scheduler, lf = build_lr_scheduler(model_cfg, optimizer, total_epochs)
-    scheduler.last_epoch = start_epoch - 1  # do not move
+    # ---------------------------- Build LR Scheduler ----------------------------
+    args.max_epoch += args.wp_epoch
+    lr_scheduler, lf = build_lr_scheduler(model_cfg, optimizer, args.max_epoch)
+    lr_scheduler.last_epoch = start_epoch - 1  # do not move
     if args.resume:
-        scheduler.step()
+        lr_scheduler.step()
 
-    # EMA
+    # ---------------------------- Build Model-EMA ----------------------------
     if args.ema and distributed_utils.get_rank() in [-1, 0]:
         print('Build ModelEMA ...')
-        ema = ModelEMA(model, decay=model_cfg['ema_decay'], tau=model_cfg['ema_tau'], updates=start_epoch * len(dataloader))
+        model_ema = ModelEMA(model, model_cfg['ema_decay'], model_cfg['ema_tau'], start_epoch * len(train_loader))
     else:
-        ema = None
+        model_ema = None
+
+    # ---------------------------- Build Trainer ----------------------------
+    trainer = Trainer(args, device, model_cfg, model_ema, optimizer, lf, lr_scheduler, criterion, scaler)
 
     # start training loop
-    best_map = -1.0
-    last_opt_step = -1
     heavy_eval = False
     optimizer.zero_grad()
     
-    # eval before training
+    # --------------------------------- Main process for training ---------------------------------
+    ## Eval before training
     if args.eval_first and distributed_utils.is_main_process():
         # to check whether the evaluator can work
-        model_eval = ema.ema if ema else model_without_ddp
-        val_one_epoch(
-            args=args, model=model_eval, evaluator=evaluator, optimizer=optimizer,
-            epoch=0, best_map=best_map, path_to_save=path_to_save)
+        model_eval = model_without_ddp
+        trainer.eval_one_epoch(model_eval, evaluator)
 
-    # start to train
-    for epoch in range(start_epoch, total_epochs):
+    ## Start Training
+    for epoch in range(start_epoch, args.max_epoch):
         if args.distributed:
-            dataloader.batch_sampler.sampler.set_epoch(epoch)
+            train_loader.batch_sampler.sampler.set_epoch(epoch)
 
         # check second stage
-        if epoch >= (total_epochs - model_cfg['no_aug_epoch'] - 1):
+        if epoch >= (args.max_epoch - model_cfg['no_aug_epoch'] - 1):
             # close mosaic augmentation
-            if dataloader.dataset.mosaic_prob > 0.:
+            if train_loader.dataset.mosaic_prob > 0.:
                 print('close Mosaic Augmentation ...')
-                dataloader.dataset.mosaic_prob = 0.
+                train_loader.dataset.mosaic_prob = 0.
                 heavy_eval = True
             # close mixup augmentation
-            if dataloader.dataset.mixup_prob > 0.:
+            if train_loader.dataset.mixup_prob > 0.:
                 print('close Mixup Augmentation ...')
-                dataloader.dataset.mixup_prob = 0.
+                train_loader.dataset.mixup_prob = 0.
                 heavy_eval = True
 
         # train one epoch
-        last_opt_step = train_one_epoch(
-            epoch=epoch,
-            total_epochs=total_epochs,
-            args=args, 
-            device=device,
-            ema=ema, 
-            model=model,
-            criterion=criterion,
-            cfg=model_cfg, 
-            dataloader=dataloader, 
-            optimizer=optimizer,
-            scheduler=scheduler,
-            lf=lf,
-            scaler=scaler,
-            last_opt_step=last_opt_step)
-
-        # eval
+        trainer.train_one_epoch(model, train_loader)
+
+        # eval one epoch
         if heavy_eval:
-            best_map = val_one_epoch(
-                            args=args, 
-                            model=ema.ema if ema else model_without_ddp, 
-                            evaluator=evaluator,
-                            optimizer=optimizer,
-                            epoch=epoch,
-                            best_map=best_map,
-                            path_to_save=path_to_save)
+            trainer.eval_one_epoch(model_without_ddp, evaluator)
         else:
-            if (epoch % args.eval_epoch) == 0 or (epoch == total_epochs - 1):
-                best_map = val_one_epoch(
-                                args=args, 
-                                model=ema.ema if ema else model_without_ddp, 
-                                evaluator=evaluator,
-                                optimizer=optimizer,
-                                epoch=epoch,
-                                best_map=best_map,
-                                path_to_save=path_to_save)
+            if (epoch % args.eval_epoch) == 0 or (epoch == args.max_epoch - 1):
+                trainer.eval_one_epoch(model_without_ddp, evaluator)
 
     # Empty cache after train loop
+    del trainer
     if args.cuda:
         torch.cuda.empty_cache()
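
Note on the warmup block in Trainer.train_one_epoch: both the gradient-accumulation
count and the per-group learning rates are linear ramps computed with np.interp over
the first nw = epoch_size * wp_epoch iterations. A minimal sketch with assumed
example values:

    import numpy as np

    epoch_size, wp_epoch, batch_size = 100, 3, 16   # assumed values
    nw = epoch_size * wp_epoch                      # warmup length in iterations
    ni = 150                                        # current global iteration
    # accumulation ramps from 1 toward 64 / batch_size
    accumulate = max(1, np.interp(ni, [0, nw], [1, 64 / batch_size]).round())
    # the bias group (j == 0) decays from warmup_bias_lr to the scheduled lr,
    # while the other groups rise from 0.0 to the scheduled lr
    warmup_bias_lr, scheduled_lr = 0.1, 0.01
    bias_lr = np.interp(ni, [0, nw], [warmup_bias_lr, scheduled_lr])
    print(accumulate, bias_lr)   # -> 2.0 and ~0.055 halfway through warmup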