"""Fine-tuning entry point: trains a classification model (e.g. ViT) on a
labeled dataset, with optional warmup, LR scheduling, periodic evaluation,
and checkpointing of the best/last model."""

import os
import time
import argparse
import datetime

# ---------------- Torch components ----------------
import torch
import torch.backends.cudnn as cudnn

# ---------------- Dataset components ----------------
from data import build_dataset, build_dataloader

# ---------------- Model components ----------------
from models import build_model

# ---------------- Utils components ----------------
from utils.misc import setup_seed, load_model, save_model
from utils.optimizer import build_optimizer
from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler

# ---------------- Training engine ----------------
from engine_finetune import train_one_epoch, evaluate


def parse_args():
    """Parse and return the command-line arguments for fine-tuning."""
    parser = argparse.ArgumentParser()
    # Input
    parser.add_argument('--img_dim', type=int, default=3,
                        help='3 for RGB; 1 for Gray.')
    parser.add_argument('--patch_size', type=int, default=16,
                        help='patch_size.')
    # Basic
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed.')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch size on all GPUs')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers')
    parser.add_argument('--path_to_save', type=str, default='weights/',
                        help='path to save trained model.')
    parser.add_argument('--eval', action='store_true', default=False,
                        help='evaluate model.')
    # Epoch
    parser.add_argument('--wp_epoch', type=int, default=5,
                        help='warmup epoch')
    parser.add_argument('--start_epoch', type=int, default=0,
                        help='start epoch')
    parser.add_argument('--max_epoch', type=int, default=50,
                        help='max epoch')
    # BUGFIX: help text previously said 'max epoch' (copy-paste error);
    # this flag is the interval (in epochs) between evaluations.
    parser.add_argument('--eval_epoch', type=int, default=5,
                        help='evaluation interval (in epochs)')
    # Dataset
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset name')
    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
                        help='path to dataset folder')
    parser.add_argument('--num_classes', type=int, default=None,
                        help='number of classes.')
    # Model
    parser.add_argument('-m', '--model', type=str, default='vit_t',
                        help='model name')
    parser.add_argument('--pretrained', default=None, type=str,
                        help='load pretrained weight.')
    parser.add_argument('--resume', default=None, type=str,
                        help='keep training')
    parser.add_argument('--drop_path', type=float, default=0.1,
                        help='drop_path')
    # Optimizer
    # BUGFIX: help text previously omitted the default 'adamw'.
    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
                        help='sgd, adam, adamw')
    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
                        help='weight decay')
    parser.add_argument('--base_lr', type=float, default=0.001,
                        help='learning rate for training model')
    parser.add_argument('--min_lr', type=float, default=0,
                        help='the final lr')
    # Lr scheduler
    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='cosine',
                        help='step, cosine')

    return parser.parse_args()


def main():
    """Build the dataset, model, optimizer, and schedulers, then run the
    fine-tuning loop (or a single evaluation pass when ``--eval`` is set)."""
    args = parse_args()
    # set random seed
    setup_seed(args.seed)

    # Path to save model: <path_to_save>/<dataset>/finetune/<model>
    path_to_save = os.path.join(args.path_to_save, args.dataset, "finetune", args.model)
    os.makedirs(path_to_save, exist_ok=True)
    args.output_dir = path_to_save

    # ------------------------- Build CUDA -------------------------
    if args.cuda:
        if torch.cuda.is_available():
            # Fixed input sizes -> let cuDNN autotune kernels for speed.
            cudnn.benchmark = True
            device = torch.device("cuda")
        else:
            print('There is no available GPU.')
            args.cuda = False
            device = torch.device("cpu")
    else:
        device = torch.device("cpu")

    # ------------------------- Build Dataset -------------------------
    train_dataset = build_dataset(args, is_train=True)
    val_dataset = build_dataset(args, is_train=False)

    # ------------------------- Build Dataloader -------------------------
    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
    val_dataloader = build_dataloader(args, val_dataset, is_train=False)

    print('=================== Dataset Information ===================')
    print('Train dataset size : ', len(train_dataset))
    print('Val dataset size   : ', len(val_dataset))

    # ------------------------- Build Model -------------------------
    model = build_model(args, model_type='cls')
    model.train().to(device)
    print(model)

    # ------------------------- Build Optimizer -------------------------
    optimizer = build_optimizer(args, model)

    # ------------------------- Build Lr Scheduler -------------------------
    # Linear warmup runs per-iteration for the first `wp_epoch` epochs;
    # the main scheduler then steps per-epoch (see the loop below).
    lr_scheduler_warmup = LinearWarmUpLrScheduler(
        args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
    lr_scheduler = build_lr_scheduler(args, optimizer)

    # ------------------------- Build Criterion -------------------------
    criterion = torch.nn.CrossEntropyLoss()

    # Optionally restore model/optimizer/scheduler state from --resume.
    load_model(args, model, optimizer, lr_scheduler)

    # ------------------------- Eval before Train Pipeline -------------------------
    if args.eval:
        print('evaluating ...')
        test_stats = evaluate(val_dataloader, model, device)
        print('Eval Results: [loss: %.2f][acc1: %.2f][acc5 : %.2f]' %
              (test_stats['loss'], test_stats['acc1'], test_stats['acc5']),
              flush=True)
        return

    # ------------------------- Training Pipeline -------------------------
    start_time = time.time()
    max_accuracy = -1.0
    print("=============== Start training for {} epochs ===============".format(args.max_epoch))
    for epoch in range(args.start_epoch, args.max_epoch):
        # Train one epoch
        train_one_epoch(args, device, model, train_dataloader,
                        optimizer, epoch, lr_scheduler_warmup, criterion)

        # LR scheduler: the per-epoch scheduler only steps once warmup is over.
        if (epoch + 1) > args.wp_epoch:
            lr_scheduler.step()

        # Evaluate every `eval_epoch` epochs and on the final epoch.
        if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
            test_stats = evaluate(val_dataloader, model, device)
            print(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%")
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')

            # Save model (checkpoint carries the best accuracy seen so far).
            print('- saving the model after {} epochs ...'.format(epoch))
            save_model(args, epoch, model, optimizer, lr_scheduler, acc1=max_accuracy)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == "__main__":
    main()