main_finetune.py

import os
import time
import argparse
import datetime
# ---------------- Torch components ----------------
import torch
import torch.backends.cudnn as cudnn
# ---------------- Dataset components ----------------
from data import build_dataset, build_dataloader
# ---------------- Model components ----------------
from models import build_model
# ---------------- Utils components ----------------
from utils.misc import setup_seed, load_model, save_model
from utils.optimizer import build_optimizer
from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
# ---------------- Training engine ----------------
from engine_finetune import train_one_epoch, evaluate
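
# Note: data, models, utils, and engine_finetune are project-local modules,
# so run this script from the repository root for these imports to resolve.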


def parse_args():
    parser = argparse.ArgumentParser()
    # Input
    parser.add_argument('--img_dim', type=int, default=3,
                        help='3 for RGB; 1 for grayscale.')
    parser.add_argument('--patch_size', type=int, default=16,
                        help='patch size.')
    # Basic
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed.')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA.')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch size across all GPUs.')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of dataloader workers.')
    parser.add_argument('--path_to_save', type=str, default='weights/',
                        help='path to save the trained model.')
    parser.add_argument('--eval', action='store_true', default=False,
                        help='only evaluate the model, without training.')
    # Epoch
    parser.add_argument('--wp_epoch', type=int, default=5,
                        help='number of warmup epochs.')
    parser.add_argument('--start_epoch', type=int, default=0,
                        help='start epoch.')
    parser.add_argument('--max_epoch', type=int, default=50,
                        help='max epoch.')
    parser.add_argument('--eval_epoch', type=int, default=5,
                        help='evaluate every N epochs.')
    # Dataset
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset name.')
    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
                        help='path to the dataset folder.')
    parser.add_argument('--num_classes', type=int, default=None,
                        help='number of classes.')
    # Model
    parser.add_argument('-m', '--model', type=str, default='vit_t',
                        help='model name.')
    parser.add_argument('--pretrained', default=None, type=str,
                        help='path to pretrained weights to load.')
    parser.add_argument('--resume', default=None, type=str,
                        help='checkpoint to resume training from.')
    parser.add_argument('--drop_path', type=float, default=0.1,
                        help='drop path rate.')
    # Optimizer
    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
                        help='adamw, adam, sgd.')
    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
                        help='weight decay.')
    parser.add_argument('--base_lr', type=float, default=0.001,
                        help='base learning rate.')
    parser.add_argument('--min_lr', type=float, default=0,
                        help='final learning rate.')
    # Lr scheduler
    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='cosine',
                        help='step, cosine.')

    return parser.parse_args()
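
# Example invocation (illustrative only; adjust the dataset root to your setup):
#   python main_finetune.py --cuda --root /path/to/dataset --dataset cifar10 -m vit_t --batch_size 256
# Add --eval (typically together with --resume <checkpoint>) to run evaluation only.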


def main():
    args = parse_args()
    # set random seed
    setup_seed(args.seed)

    # Path to save model
    path_to_save = os.path.join(args.path_to_save, args.dataset, "finetune", args.model)
    os.makedirs(path_to_save, exist_ok=True)
    args.output_dir = path_to_save

    # ------------------------- Build CUDA -------------------------
    if args.cuda:
        if torch.cuda.is_available():
            cudnn.benchmark = True
            device = torch.device("cuda")
        else:
            print('There is no available GPU.')
            args.cuda = False
            device = torch.device("cpu")
    else:
        device = torch.device("cpu")
    # ------------------------- Build Dataset -------------------------
    train_dataset = build_dataset(args, is_train=True)
    val_dataset = build_dataset(args, is_train=False)

    # ------------------------- Build Dataloader -------------------------
    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
    val_dataloader = build_dataloader(args, val_dataset, is_train=False)

    print('=================== Dataset Information ===================')
    print('Train dataset size : ', len(train_dataset))
    print('Val dataset size   : ', len(val_dataset))

    # ------------------------- Build Model -------------------------
    model = build_model(args, model_type='cls')
    model.train().to(device)
    print(model)
    # ------------------------- Build Optimizer -------------------------
    optimizer = build_optimizer(args, model)

    # ------------------------- Build Lr Scheduler -------------------------
    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
    lr_scheduler = build_lr_scheduler(args, optimizer)
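    # The warmup scheduler is applied per-iteration inside train_one_epoch,
    # while lr_scheduler steps once per epoch after warmup ends (see the
    # training loop below). Assuming the usual convention, warmup ramps the
    # learning rate linearly from ~0 up to base_lr over wp_iter steps:
    #     lr = base_lr * min(1.0, cur_iter / wp_iter)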
    # ------------------------- Build Criterion -------------------------
    criterion = torch.nn.CrossEntropyLoss()

    # Load pretrained weights or resume from a checkpoint, if requested
    load_model(args, model, optimizer, lr_scheduler)
    # ------------------------- Eval before Train Pipeline -------------------------
    if args.eval:
        print('evaluating ...')
        test_stats = evaluate(val_dataloader, model, device)
        print('Eval Results: [loss: %.2f][acc1: %.2f][acc5: %.2f]' %
              (test_stats['loss'], test_stats['acc1'], test_stats['acc5']), flush=True)
        return
    # ------------------------- Training Pipeline -------------------------
    start_time = time.time()
    max_accuracy = -1.0
    print("=============== Start training for {} epochs ===============".format(args.max_epoch))
    for epoch in range(args.start_epoch, args.max_epoch):
        # Train one epoch
        train_one_epoch(args, device, model, train_dataloader, optimizer,
                        epoch, lr_scheduler_warmup, criterion)

        # Step the epoch-wise LR scheduler once warmup is over
        if (epoch + 1) > args.wp_epoch:
            lr_scheduler.step()

        # Evaluate periodically and on the final epoch
        if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
            test_stats = evaluate(val_dataloader, model, device)
            print(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%")
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')

            # Save model
            print('- saving the model after {} epochs ...'.format(epoch + 1))
            save_model(args, epoch, model, optimizer, lr_scheduler, acc1=max_accuracy)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == "__main__":
    main()