from copy import deepcopy
import os
import time
import math
import argparse
import datetime

# ---------------- Timm components ----------------
from timm.data.mixup import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

# ---------------- Torch components ----------------
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# ---------------- Dataset components ----------------
from data import build_dataset, build_dataloader

# ---------------- Model components ----------------
from models import build_model

# ---------------- Utils components ----------------
from utils import distributed_utils
from utils.ema import ModelEMA
from utils.misc import setup_seed, print_rank_0, load_model, save_model
from utils.misc import NativeScalerWithGradNormCount as NativeScaler
from utils.optimzer import build_optimizer
from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
from utils.com_flops_params import FLOPs_and_Params

# ---------------- Training engine ----------------
from engine import train_one_epoch, evaluate
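# Example launch commands (illustrative sketch: the script filename and dataset path are
# assumptions, but every flag used below is defined in parse_args further down):
#
#   # Single GPU
#   python train.py --cuda --dataset cifar10 --root /path/to/dataset -m rtcnet_n --ema
#
#   # Multi-GPU DDP (assumes a torchrun-style launcher that provides the environment
#   # variables read through the default --dist_url 'env://')
#   torchrun --nproc_per_node=8 train.py --cuda -dist --sybn --dataset cifar10 --root /path/to/dataset -m rtcnet_n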
def parse_args():
    parser = argparse.ArgumentParser()
    # Input
    parser.add_argument('--img_size', type=int, default=224,
                        help='input image size.')
    parser.add_argument('--img_dim', type=int, default=3,
                        help='3 for RGB; 1 for grayscale.')
    parser.add_argument('--num_classes', type=int, default=1000,
                        help='number of classes.')
    # Basic
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed.')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA.')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='total batch size across all GPUs.')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of dataloader workers.')
    parser.add_argument('--path_to_save', type=str, default='weights/',
                        help='path to save the trained model.')
    parser.add_argument('--tfboard', action='store_true', default=False,
                        help='use TensorBoard.')
    parser.add_argument('--eval', action='store_true', default=False,
                        help='only evaluate the model, then exit.')
    # Epoch
    parser.add_argument('--wp_epoch', type=int, default=20,
                        help='number of warmup epochs.')
    parser.add_argument('--start_epoch', type=int, default=0,
                        help='start epoch (useful when resuming training).')
    parser.add_argument('--max_epoch', type=int, default=300,
                        help='max number of training epochs.')
    parser.add_argument('--eval_epoch', type=int, default=10,
                        help='evaluation interval in epochs.')
    # Dataset
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset name.')
    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
                        help='path to the dataset folder.')
    # Model
    parser.add_argument('-m', '--model', type=str, default='rtcnet_n',
                        help='model name.')
    parser.add_argument('--resume', default=None, type=str,
                        help='checkpoint to resume training from.')
    parser.add_argument('--ema', action='store_true', default=False,
                        help='use model EMA.')
    parser.add_argument('--drop_path', type=float, default=0.1,
                        help='drop path rate.')
    # Optimizer
    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
                        help='optimizer: sgd, adam, adamw.')
    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='step',
                        help='lr scheduler: cosine, step.')
    parser.add_argument('-mt', '--momentum', type=float, default=0.9,
                        help='momentum (used by SGD).')
    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
                        help='weight decay.')
    parser.add_argument('--batch_base', type=int, default=256,
                        help='reference batch size used to scale the learning rate.')
    parser.add_argument('--base_lr', type=float, default=1e-3,
                        help='base learning rate.')
    parser.add_argument('--min_lr', type=float, default=1e-6,
                        help='final (minimum) learning rate.')
    parser.add_argument('--grad_accumulate', type=int, default=1,
                        help='number of gradient accumulation steps.')
    parser.add_argument('--max_grad_norm', type=float, default=None,
                        help='clip gradient norm (default: None, no clipping).')
    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
                        help='color jitter factor (enabled only when not using Auto/RandAug).')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='use AutoAugment policy, e.g. "v0", "original", or "rand-m9-mstd0.5-inc1".')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='label smoothing (default: 0.1).')
    # Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='random erase prob (default: 0.25).')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='random erase mode (default: "pixel").')
    parser.add_argument('--recount', type=int, default=1,
                        help='random erase count (default: 1).')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='do not random erase the first (clean) augmentation split.')
    # Mixup params
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0.')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0.')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None).')
    parser.add_argument('--mixup_prob', type=float, default=1.0,
                        help='probability of performing mixup or cutmix when either/both is enabled.')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                        help='probability of switching to cutmix when both mixup and cutmix are enabled.')
    parser.add_argument('--mixup_mode', type=str, default='batch',
                        help='how to apply mixup/cutmix params: per "batch", "pair", or "elem".')
    # DDP
    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
                        help='distributed training.')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training.')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes.')
    parser.add_argument('--sybn', action='store_true', default=False,
                        help='use SyncBatchNorm.')
    parser.add_argument('--local_rank', default=-1, type=int,
                        help='local rank for distributed training.')

    return parser.parse_args()

def main():
    args = parse_args()
    # Set random seed
    setup_seed(args.seed)

    # Path to save the model
    path_to_save = os.path.join(args.path_to_save, args.dataset, args.model)
    os.makedirs(path_to_save, exist_ok=True)
    args.output_dir = path_to_save

    # ------------------------- Build DDP environment -------------------------
    ## LOCAL_RANK here is the global rank of the process, in the range [0, world_size - 1].
    ## LOCAL_PROCESS_RANK is the per-machine (per-node) rank, not the global one.
    local_rank = local_process_rank = -1
    if args.distributed:
        distributed_utils.init_distributed_mode(args)
        print("git:\n {}\n".format(distributed_utils.get_sha()))
        try:
            # Multiple machines & multiple GPUs (world size > 8)
            local_rank = torch.distributed.get_rank()
            local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0'))
        except Exception:
            # Single machine & multiple GPUs (world size <= 8)
            local_rank = local_process_rank = torch.distributed.get_rank()
    print_rank_0(args)
    args.world_size = distributed_utils.get_world_size()
    print('World size: {}'.format(distributed_utils.get_world_size()))
    print("LOCAL RANK: ", local_rank)
    print("LOCAL_PROCESS_RANK: ", local_process_rank)
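    # NOTE: when --distributed is set, distributed_utils.init_distributed_mode(args) above is
    # assumed to populate args.gpu (the local CUDA device index) from the launcher's environment;
    # args.gpu is what DDP(device_ids=[args.gpu]) relies on further below.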
    # ------------------------- Build CUDA -------------------------
    if args.cuda:
        if torch.cuda.is_available():
            cudnn.benchmark = True
            device = torch.device("cuda")
        else:
            print('There is no available GPU.')
            args.cuda = False
            device = torch.device("cpu")
    else:
        device = torch.device("cpu")

    # ------------------------- Build Tensorboard -------------------------
    tblogger = None
    if local_rank <= 0 and args.tfboard:
        print('use tensorboard')
        from torch.utils.tensorboard import SummaryWriter
        time_stamp = time.strftime('%Y-%m-%d_%H:%M:%S', time.localtime(time.time()))
        log_path = os.path.join('log/', args.dataset, time_stamp)
        os.makedirs(log_path, exist_ok=True)
        tblogger = SummaryWriter(log_path)
    # ------------------------- Build Dataset -------------------------
    train_dataset = build_dataset(args, is_train=True)
    val_dataset = build_dataset(args, is_train=False)

    # ------------------------- Build Dataloader -------------------------
    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
    val_dataloader = build_dataloader(args, val_dataset, is_train=False)

    print('=================== Dataset Information ===================')
    print("Dataset: ", args.dataset)
    print('- train dataset size : ', len(train_dataset))
    print('- val dataset size   : ', len(val_dataset))

    # ------------------------- Build Model -------------------------
    model = build_model(args)
    model.train().to(device)
    print(model)
    if local_rank <= 0:
        model_copy = deepcopy(model)
        model_copy.eval()
        FLOPs_and_Params(model_copy, args.img_size, args.img_dim, device)
        model_copy.train()
        del model_copy
    if args.distributed:
        # wait for all processes to synchronize
        dist.barrier()
    # ------------------------- Build DDP Model -------------------------
    model_without_ddp = model
    if args.distributed:
        if args.sybn:
            # Convert BatchNorm to SyncBatchNorm before wrapping with DDP
            print('use SyncBatchNorm ...')
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[args.gpu])
        model_without_ddp = model.module
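    # The unwrapped handle (model_without_ddp) is the one used below for building the optimizer,
    # for checkpoint loading/saving, and for the EMA copy, so state dicts are handled without
    # the DDP "module." prefix.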
    # ------------------------- Mixup augmentation config -------------------------
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print_rank_0("Mixup is activated!", local_rank)
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.num_classes)
    # ------------------------- Build Optimizer -------------------------
    optimizer = build_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    # ------------------------- Build LR Scheduler -------------------------
    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
    lr_scheduler = build_lr_scheduler(args, optimizer)
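    # Two schedulers are used: the linear warmup scheduler (passed into train_one_epoch) handles
    # the per-iteration warmup over wp_epoch * iters_per_epoch steps, while the epoch-wise
    # scheduler built from --lr_scheduler is only stepped once per epoch after the warmup phase
    # ends (see the training loop below).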
    # ------------------------- Build Criterion -------------------------
    if mixup_fn is not None:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    load_model(args=args, model_without_ddp=model_without_ddp,
               optimizer=optimizer, lr_scheduler=lr_scheduler, loss_scaler=loss_scaler)
    # ------------------------- Build Model-EMA -------------------------
    if args.ema:
        print("Build model ema for {}".format(args.model))
        updates = args.start_epoch * len(train_dataloader) // args.grad_accumulate
        print("Initial updates of ModelEMA: {}".format(updates))
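        # 'updates' estimates how many optimizer steps have already been taken when resuming
        # (epochs completed * iterations per epoch / gradient-accumulation steps); the EMA helper
        # is assumed to use it to continue its decay schedule from the right point.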
        model_ema = ModelEMA(model_without_ddp, ema_decay=0.999, ema_tau=2000., updates=updates)
    else:
        model_ema = None
    # ------------------------- Eval before Train Pipeline -------------------------
    if args.eval:
        print('evaluating ...')
        test_stats = evaluate(val_dataloader, model_without_ddp, device, local_rank)
        print('Eval Results: [loss: %.2f][acc1: %.2f][acc5: %.2f]' %
              (test_stats['loss'], test_stats['acc1'], test_stats['acc5']), flush=True)
        return
    # ------------------------- Training Pipeline -------------------------
    start_time = time.time()
    max_accuracy = -1.0
    print_rank_0("=============== Start training for {} epochs ===============".format(args.max_epoch), local_rank)
    for epoch in range(args.start_epoch, args.max_epoch):
        if args.distributed:
            train_dataloader.batch_sampler.sampler.set_epoch(epoch)

        # Train one epoch
        train_one_epoch(args, device, model, model_ema, train_dataloader, optimizer, epoch,
                        lr_scheduler_warmup, loss_scaler, criterion, local_rank, tblogger, mixup_fn)

        # LR scheduler (stepped per epoch once warmup is finished)
        if (epoch + 1) > args.wp_epoch:
            lr_scheduler.step()

        # Evaluate
        if local_rank <= 0:
            model_eval = model_ema.ema if model_ema is not None else model_without_ddp
            if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
                print_rank_0("Evaluating ...")
                test_stats = evaluate(val_dataloader, model_eval, device, local_rank)
                print_rank_0(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%", local_rank)
                max_accuracy = max(max_accuracy, test_stats["acc1"])
                print_rank_0(f'Max accuracy: {max_accuracy:.2f}%', local_rank)

                # Save model
                print('- saving the model after epoch {} ...'.format(epoch))
                save_model(args=args, model=model_eval, model_without_ddp=model_eval,
                           optimizer=optimizer, lr_scheduler=lr_scheduler, loss_scaler=loss_scaler,
                           epoch=epoch, acc1=max_accuracy)

                if tblogger is not None:
                    tblogger.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
                    tblogger.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
                    tblogger.add_scalar('perf/test_loss', test_stats['loss'], epoch)

        if args.distributed:
            # all ranks synchronize at the end of each epoch
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == "__main__":
    main()