| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- import os
- import time
- import argparse
- import datetime
- # ---------------- Torch compoments ----------------
- import torch
- import torch.backends.cudnn as cudnn
- # ---------------- Dataset compoments ----------------
- from data import build_dataset, build_dataloader
- # ---------------- Model compoments ----------------
- from models import build_model
- # ---------------- Utils compoments ----------------
- from utils.misc import setup_seed, load_model, save_model
- from utils.optimizer import build_optimizer
- from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
- # ---------------- Training engine ----------------
- from engine_finetune import train_one_epoch, evaluate
def parse_args(argv=None):
    """Build and run the command-line parser for fine-tuning.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, preserving the original behavior;
            passing a list makes the parser usable from tests or other code.

    Returns:
        argparse.Namespace with all fine-tuning options.
    """
    parser = argparse.ArgumentParser()
    # Input
    parser.add_argument('--img_dim', type=int, default=3,
                        help='3 for RGB; 1 for Gray.')
    parser.add_argument('--patch_size', type=int, default=16,
                        help='patch_size.')
    # Basic
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed.')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch size on all GPUs')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers')
    parser.add_argument('--path_to_save', type=str, default='weights/',
                        help='path to save trained model.')
    parser.add_argument('--eval', action='store_true', default=False,
                        help='evaluate model.')
    # Epoch
    parser.add_argument('--wp_epoch', type=int, default=5,
                        help='warmup epoch')
    parser.add_argument('--start_epoch', type=int, default=0,
                        help='start epoch')
    parser.add_argument('--max_epoch', type=int, default=50,
                        help='max epoch')
    # Fixed copy-paste help text: this is the evaluation interval, not the max epoch.
    parser.add_argument('--eval_epoch', type=int, default=5,
                        help='evaluate every N epochs')
    # Dataset
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset name')
    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
                        help='path to dataset folder')
    parser.add_argument('--num_classes', type=int, default=None,
                        help='number of classes.')
    # Model
    parser.add_argument('-m', '--model', type=str, default='vit_t',
                        help='model name')
    parser.add_argument('--pretrained', default=None, type=str,
                        help='load pretrained weight.')
    parser.add_argument('--resume', default=None, type=str,
                        help='keep training')
    parser.add_argument('--drop_path', type=float, default=0.1,
                        help='drop_path')
    # Optimizer
    # Help text now mentions the actual default ('adamw').
    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
                        help='sgd, adam, adamw')
    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
                        help='weight decay')
    parser.add_argument('--base_lr', type=float, default=0.001,
                        help='learning rate for training model')
    parser.add_argument('--min_lr', type=float, default=0,
                        help='the final lr')
    # Lr scheduler
    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='cosine',
                        help='step, cosine')
    return parser.parse_args(argv)
-
def main():
    """Fine-tune a classification model end-to-end.

    Pipeline: parse CLI args -> seed RNGs -> build dataset/dataloader/model/
    optimizer/schedulers -> (optionally just evaluate) -> train for
    ``max_epoch`` epochs with periodic evaluation and checkpointing.
    Relies on project helpers (build_dataset, build_model, train_one_epoch,
    evaluate, save_model, ...) imported at the top of the file.
    """
    args = parse_args()
    # set random seed
    setup_seed(args.seed)
    # Path to save model: <path_to_save>/<dataset>/finetune/<model>
    path_to_save = os.path.join(args.path_to_save, args.dataset, "finetune", args.model)
    os.makedirs(path_to_save, exist_ok=True)
    # stashed on args so downstream helpers (e.g. save_model) can find it
    args.output_dir = path_to_save
    # ------------------------- Build CUDA -------------------------
    if args.cuda:
        if torch.cuda.is_available():
            # input sizes are fixed per run, so cudnn autotuning pays off
            cudnn.benchmark = True
            device = torch.device("cuda")
        else:
            # Fall back to CPU (and record that in args) when no GPU exists.
            print('There is no available GPU.')
            args.cuda = False
            device = torch.device("cpu")
    else:
        device = torch.device("cpu")
    # ------------------------- Build Dataset -------------------------
    train_dataset = build_dataset(args, is_train=True)
    val_dataset = build_dataset(args, is_train=False)
    # ------------------------- Build Dataloader -------------------------
    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
    val_dataloader = build_dataloader(args, val_dataset, is_train=False)
    print('=================== Dataset Information ===================')
    print('Train dataset size : ', len(train_dataset))
    print('Val dataset size   : ', len(val_dataset))
    # ------------------------- Build Model -------------------------
    model = build_model(args, model_type='cls')
    model.train().to(device)
    print(model)
    # ------------------------- Build Optimizer -------------------------
    optimizer = build_optimizer(args, model)
    # ------------------------- Build Lr Scheduler -------------------------
    # Warmup is stepped per-iteration (inside train_one_epoch, presumably),
    # hence wp_iter = warmup epochs * iterations per epoch.
    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
    lr_scheduler = build_lr_scheduler(args, optimizer)
    # ------------------------- Build Criterion -------------------------
    criterion = torch.nn.CrossEntropyLoss()
    # Restores model/optimizer/scheduler state when --resume/--pretrained is set.
    load_model(args, model, optimizer, lr_scheduler)
    # ------------------------- Eval before Train Pipeline -------------------------
    if args.eval:
        # Evaluation-only mode: report metrics and exit without training.
        print('evaluating ...')
        test_stats = evaluate(val_dataloader, model, device)
        print('Eval Results: [loss: %.2f][acc1: %.2f][acc5 : %.2f]' %
              (test_stats['loss'], test_stats['acc1'], test_stats['acc5']), flush=True)
        return
    # ------------------------- Training Pipeline -------------------------
    start_time = time.time()
    max_accuracy = -1.0
    print("=============== Start training for {} epochs ===============".format(args.max_epoch))
    for epoch in range(args.start_epoch, args.max_epoch):
        # Train one epoch
        train_one_epoch(args, device, model, train_dataloader, optimizer,
                        epoch, lr_scheduler_warmup, criterion)
        # LR scheduler: the epoch-level scheduler only kicks in after warmup.
        if (epoch + 1) > args.wp_epoch:
            lr_scheduler.step()
        # Evaluate every eval_epoch epochs, and always on the final epoch.
        if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
            test_stats = evaluate(val_dataloader, model, device)
            print(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%")
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')
            # Save model: checkpoints are written only on evaluation epochs.
            print('- saving the model after {} epochs ...'.format(epoch))
            save_model(args, epoch, model, optimizer, lr_scheduler, acc1=max_accuracy)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == "__main__":
    main()
|