from copy import deepcopy
import os
import time
import math
import argparse
import datetime

# ---------------- Timm components ----------------
from timm.data.mixup import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

# ---------------- Torch components ----------------
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# ---------------- Dataset components ----------------
from data import build_dataset, build_dataloader

# ---------------- Model components ----------------
from models import build_model

# ---------------- Utils components ----------------
from utils import distributed_utils
from utils.ema import ModelEMA
from utils.misc import setup_seed, print_rank_0, load_model, save_model
from utils.misc import NativeScalerWithGradNormCount as NativeScaler
from utils.optimzer import build_optimizer
from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
from utils.com_flops_params import FLOPs_and_Params

# ---------------- Training engine ----------------
from engine import train_one_epoch, evaluate


def parse_args():
    parser = argparse.ArgumentParser()
    # Input
    parser.add_argument('--img_size', type=int, default=224,
                        help='input image size.')
    parser.add_argument('--img_dim', type=int, default=3,
                        help='3 for RGB; 1 for Gray.')
    parser.add_argument('--num_classes', type=int, default=1000,
                        help='number of classes.')
    # Basic
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed.')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch size on all GPUs')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers')
    parser.add_argument('--path_to_save', type=str, default='weights/',
                        help='path to save trained model.')
    parser.add_argument('--tfboard', action='store_true', default=False,
                        help='use tensorboard')
    parser.add_argument('--eval', action='store_true', default=False,
                        help='evaluate model.')
    # Epoch
    parser.add_argument('--wp_epoch', type=int, default=20,
                        help='warmup epochs')
    parser.add_argument('--start_epoch', type=int, default=0,
                        help='start epoch')
    parser.add_argument('--max_epoch', type=int, default=300,
                        help='max epoch')
    parser.add_argument('--eval_epoch', type=int, default=10,
                        help='evaluation interval (epochs)')
    # Dataset
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset name')
    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
                        help='path to dataset folder')
    # Model
    parser.add_argument('-m', '--model', type=str, default='rtcnet_n',
                        help='model name')
    parser.add_argument('--resume', default=None, type=str,
                        help='checkpoint to resume training from')
    parser.add_argument('--ema', action='store_true', default=False,
                        help='use ema.')
    parser.add_argument('--drop_path', type=float, default=0.1,
                        help='drop path rate')
    # Optimizer
    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
                        help='optimizer type, e.g. adamw, sgd')
    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='step',
                        help='lr scheduler: cosine, step')
    parser.add_argument('-mt', '--momentum', type=float, default=0.9,
                        help='momentum')
    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
                        help='weight decay')
    parser.add_argument('--batch_base', type=int, default=256,
                        help='base batch size')
    parser.add_argument('--base_lr', type=float, default=1e-3,
                        help='learning rate for training model')
    parser.add_argument('--min_lr', type=float, default=1e-6,
                        help='the final lr')
    parser.add_argument('--grad_accumulate', type=int, default=1,
                        help='gradient accumulation steps')
    parser.add_argument('--max_grad_norm', type=float, default=None,
                        help='Clip gradient norm (default: None, no clipping)')
    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy, e.g. "v0", "original" or "rand-m9-mstd0.5-inc1" (default: None)')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    # Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')
    # Mixup params
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0.')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0.')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup_prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup_mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
    # DDP
    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
                        help='distributed training')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--sybn', action='store_true', default=False,
                        help='use SyncBatchNorm.')
    parser.add_argument('--local_rank', default=-1, type=int,
                        help='local rank.')

    return parser.parse_args()


def main():
    args = parse_args()
    # set random seed
    setup_seed(args.seed)

    # Path to save model
    path_to_save = os.path.join(args.path_to_save, args.dataset, args.model)
    os.makedirs(path_to_save, exist_ok=True)
    args.output_dir = path_to_save

    # ------------------------- Build DDP environment -------------------------
    ## LOCAL_RANK is the global rank of the process, in the range [0, world_size - 1].
    ## LOCAL_PROCESS_RANK is the index of the GPU on each machine, not the global rank.
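    ## Note (assumption, not from the original source): with --distributed, this script is
    ## expected to be started by a multi-process launcher such as torchrun or
    ## `python -m torch.distributed.launch`, which sets the RANK / WORLD_SIZE / LOCAL_RANK
    ## environment variables consumed by distributed_utils.init_distributed_mode().
    ## Illustrative launch command (flags are this script's own arguments, paths are placeholders):
    ##   torchrun --nproc_per_node=8 train.py --cuda -dist --dataset cifar10 --root /path/to/dataset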
    local_rank = local_process_rank = -1
    if args.distributed:
        distributed_utils.init_distributed_mode(args)
        print("git:\n  {}\n".format(distributed_utils.get_sha()))
        try:
            # Multiple machines & multiple GPUs (world size > 8)
            local_rank = torch.distributed.get_rank()
            local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0'))
        except:
            # Single machine & multiple GPUs (world size <= 8)
            local_rank = local_process_rank = torch.distributed.get_rank()
    print_rank_0(args)
    args.world_size = distributed_utils.get_world_size()
    print('World size: {}'.format(distributed_utils.get_world_size()))
    print("LOCAL RANK: ", local_rank)
    print("LOCAL_PROCESS_RANK: ", local_process_rank)

    # ------------------------- Build CUDA -------------------------
    if args.cuda:
        if torch.cuda.is_available():
            cudnn.benchmark = True
            device = torch.device("cuda")
        else:
            print('There is no available GPU.')
            args.cuda = False
            device = torch.device("cpu")
    else:
        device = torch.device("cpu")

    # ------------------------- Build Tensorboard -------------------------
    tblogger = None
    if local_rank <= 0 and args.tfboard:
        print('use tensorboard')
        from torch.utils.tensorboard import SummaryWriter
        time_stamp = time.strftime('%Y-%m-%d_%H:%M:%S', time.localtime(time.time()))
        log_path = os.path.join('log/', args.dataset, time_stamp)
        os.makedirs(log_path, exist_ok=True)
        tblogger = SummaryWriter(log_path)

    # ------------------------- Build Dataset -------------------------
    train_dataset = build_dataset(args, is_train=True)
    val_dataset = build_dataset(args, is_train=False)

    # ------------------------- Build Dataloader -------------------------
    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
    val_dataloader = build_dataloader(args, val_dataset, is_train=False)

    print('=================== Dataset Information ===================')
    print("Dataset: ", args.dataset)
    print('- train dataset size : ', len(train_dataset))
    print('- val dataset size : ', len(val_dataset))

    # ------------------------- Build Model -------------------------
    model = build_model(args)
    model.train().to(device)
    print(model)
    if local_rank <= 0:
        model_copy = deepcopy(model)
        model_copy.eval()
        FLOPs_and_Params(model_copy, args.img_size, args.img_dim, device)
        model_copy.train()
        del model_copy
    if args.distributed:
        # wait for all processes to synchronize
        dist.barrier()

    # ------------------------- Build DDP Model -------------------------
    model_without_ddp = model
    if args.distributed:
        if args.sybn:
            # convert BatchNorm to SyncBatchNorm before wrapping with DDP
            print('use SyncBatchNorm ...')
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # ------------------------- Mixup augmentation config -------------------------
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print_rank_0("Mixup is activated!", local_rank)
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.num_classes)

    # ------------------------- Build Optimizer -------------------------
    optimizer = build_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    # ------------------------- Build Lr Scheduler -------------------------
    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
    lr_scheduler = build_lr_scheduler(args, optimizer)

    # ------------------------- Build Criterion -------------------------
    if mixup_fn is not None:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    load_model(args=args, model_without_ddp=model_without_ddp,
               optimizer=optimizer, lr_scheduler=lr_scheduler, loss_scaler=loss_scaler)

    # ------------------------- Build Model-EMA -------------------------
    if args.ema:
        print("Build model ema for {}".format(args.model))
        updates = args.start_epoch * len(train_dataloader) // args.grad_accumulate
        print("Initial updates of ModelEMA: {}".format(updates))
        model_ema = ModelEMA(model_without_ddp, ema_decay=0.999, ema_tau=2000., updates=updates)
    else:
        model_ema = None

    # ------------------------- Eval before Train Pipeline -------------------------
    if args.eval:
        print('evaluating ...')
        test_stats = evaluate(val_dataloader, model_without_ddp, device, local_rank)
        print('Eval Results: [loss: %.2f][acc1: %.2f][acc5: %.2f]' %
              (test_stats['loss'], test_stats['acc1'], test_stats['acc5']),
              flush=True)
        return

    # ------------------------- Training Pipeline -------------------------
    start_time = time.time()
    max_accuracy = -1.0
    print_rank_0("=============== Start training for {} epochs ===============".format(args.max_epoch), local_rank)
    for epoch in range(args.start_epoch, args.max_epoch):
        if args.distributed:
            train_dataloader.batch_sampler.sampler.set_epoch(epoch)

        # train one epoch
        train_one_epoch(args, device, model, model_ema, train_dataloader, optimizer, epoch,
                        lr_scheduler_warmup, loss_scaler, criterion, local_rank, tblogger, mixup_fn)

        # LR scheduler
        if (epoch + 1) > args.wp_epoch:
            lr_scheduler.step()

        # Evaluate
        if local_rank <= 0:
            model_eval = model_ema.ema if model_ema is not None else model_without_ddp
            if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
                print_rank_0("Evaluating ...")
                test_stats = evaluate(val_dataloader, model_eval, device, local_rank)
                print_rank_0(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%", local_rank)
                max_accuracy = max(max_accuracy, test_stats["acc1"])
                print_rank_0(f'Max accuracy: {max_accuracy:.2f}%', local_rank)

                # Save model
                print('- saving the model after {} epochs ...'.format(epoch))
                save_model(args=args, model=model_eval, model_without_ddp=model_eval,
                           optimizer=optimizer, lr_scheduler=lr_scheduler,
                           loss_scaler=loss_scaler, epoch=epoch, acc1=max_accuracy)
                if args.distributed:
                    dist.barrier()
                if tblogger is not None:
                    tblogger.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
                    tblogger.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
                    tblogger.add_scalar('perf/test_loss', test_stats['loss'], epoch)
        if args.distributed:
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == "__main__":
    main()
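
# ---------------- Example usage (illustrative sketch) ----------------
# The commands below are assumptions, not taken from the original source; the dataset key,
# data root, and checkpoint path are placeholders and must match what build_dataset(),
# build_model(), and load_model() actually expect in this repo.
#
# Single-GPU training:
#   python train.py --cuda --dataset cifar10 --root /path/to/dataset -m rtcnet_n --batch_size 256
#
# Evaluation only, resuming from a saved checkpoint:
#   python train.py --cuda --eval --dataset cifar10 --root /path/to/dataset -m rtcnet_n --resume /path/to/checkpoint.pth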