# main.py — training / evaluation entry point
  1. import os
  2. import time
  3. import matplotlib.pyplot as plt
  4. import argparse
  5. import datetime
  6. # ---------------- Torch compoments ----------------
  7. import torch
  8. import torch.backends.cudnn as cudnn
  9. # ---------------- Dataset compoments ----------------
  10. from data import build_dataset, build_dataloader
  11. # ---------------- Model compoments ----------------
  12. from models import build_model
  13. # ---------------- Utils compoments ----------------
  14. from utils.misc import setup_seed, load_model, save_model
  15. from utils.optimzer import build_optimizer
  16. from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
  17. # ---------------- Training engine ----------------
  18. from engine import train_one_epoch, evaluate
  19. def parse_args():
  20. parser = argparse.ArgumentParser()
  21. # Basic
  22. parser.add_argument('--seed', type=int, default=42,
  23. help='random seed.')
  24. parser.add_argument('--cuda', action='store_true', default=False,
  25. help='use cuda')
  26. parser.add_argument('--batch_size', type=int, default=256,
  27. help='batch size on all GPUs')
  28. parser.add_argument('--num_workers', type=int, default=4,
  29. help='number of workers')
  30. parser.add_argument('--path_to_save', type=str, default='weights/',
  31. help='path to save trained model.')
  32. parser.add_argument('--eval', action='store_true', default=False,
  33. help='evaluate model.')
  34. # Epoch
  35. parser.add_argument('--wp_epoch', type=int, default=1,
  36. help='warmup epoch for finetune with MAE pretrained')
  37. parser.add_argument('--start_epoch', type=int, default=0,
  38. help='start epoch for finetune with MAE pretrained')
  39. parser.add_argument('--max_epoch', type=int, default=50,
  40. help='max epoch')
  41. parser.add_argument('--eval_epoch', type=int, default=5,
  42. help='max epoch')
  43. # Dataset
  44. parser.add_argument('--dataset', type=str, default='cifar10',
  45. help='dataset name')
  46. parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
  47. help='path to dataset folder')
  48. parser.add_argument('--img_dim', type=int, default=3,
  49. help='input image dimension')
  50. parser.add_argument('--num_classes', type=int, default=1000,
  51. help='number of the classes')
  52. # Model
  53. parser.add_argument('-m', '--model', type=str, default='mlp4',
  54. help='model name')
  55. parser.add_argument('--resume', default=None, type=str,
  56. help='keep training')
  57. # Optimizer
  58. parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
  59. help='sgd, adam')
  60. parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
  61. help='weight decay')
  62. parser.add_argument('--base_lr', type=float, default=1e-3,
  63. help='learning rate for training model')
  64. parser.add_argument('--min_lr', type=float, default=1e-6,
  65. help='the final lr')
  66. # Lr scheduler
  67. parser.add_argument('-lrs', '--lr_scheduler', type=str, default='step',
  68. help='lr scheduler: cosine, step')
  69. return parser.parse_args()
  70. def main():
  71. args = parse_args()
  72. print(args)
  73. # set random seed
  74. setup_seed(args.seed)
  75. # Path to save model
  76. path_to_save = os.path.join(args.path_to_save, args.dataset, args.model)
  77. os.makedirs(path_to_save, exist_ok=True)
  78. args.output_dir = path_to_save
  79. # ------------------------- Build CUDA -------------------------
  80. if args.cuda:
  81. if torch.cuda.is_available():
  82. cudnn.benchmark = True
  83. device = torch.device("cuda")
  84. else:
  85. print('There is no available GPU.')
  86. args.cuda = False
  87. device = torch.device("cpu")
  88. else:
  89. device = torch.device("cpu")
  90. # ------------------------- Build Dataset -------------------------
  91. train_dataset = build_dataset(args, is_train=True)
  92. val_dataset = build_dataset(args, is_train=False)
  93. # ------------------------- Build Dataloader -------------------------
  94. train_dataloader = build_dataloader(args, train_dataset, is_train=True)
  95. val_dataloader = build_dataloader(args, val_dataset, is_train=False)
  96. print('=================== Dataset Information ===================')
  97. print("Dataset: ", args.dataset)
  98. print('- train dataset size : ', len(train_dataset))
  99. print('- val dataset size : ', len(val_dataset))
  100. # ------------------------- Build Model -------------------------
  101. model = build_model(args)
  102. model.train().to(device)
  103. print(model)
  104. # ------------------------- Build Criterion -------------------------
  105. criterion = torch.nn.CrossEntropyLoss()
  106. # ------------------------- Build Optimzier -------------------------
  107. optimizer = build_optimizer(args, model)
  108. # ------------------------- Build Lr Scheduler -------------------------
  109. lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
  110. lr_scheduler = build_lr_scheduler(args, optimizer)
  111. # ------------------------- Build Criterion -------------------------
  112. load_model(args, model, optimizer, lr_scheduler)
  113. # ------------------------- Eval before Train Pipeline -------------------------
  114. if args.eval:
  115. print('evaluating ...')
  116. test_stats = evaluate(val_dataloader, model, device)
  117. print('Eval Results: [loss: %.2f][acc1: %.2f][acc5 : %.2f]' %
  118. (test_stats['loss'], test_stats['acc1'], test_stats['acc5']), flush=True)
  119. return
  120. # ------------------------- Training Pipeline -------------------------
  121. start_time = time.time()
  122. max_accuracy = -1.0
  123. print("=============== Start training for {} epochs ===============".format(args.max_epoch))
  124. train_loss_logs = []
  125. valid_loss_logs = []
  126. valid_acc1_logs = []
  127. for epoch in range(args.start_epoch, args.max_epoch):
  128. # train one epoch
  129. train_stats = train_one_epoch(args, device, model, train_dataloader, optimizer,
  130. epoch, lr_scheduler_warmup, criterion)
  131. # LR scheduler
  132. if (epoch + 1) > args.wp_epoch:
  133. lr_scheduler.step()
  134. train_loss_logs.append((epoch, train_stats["loss"]))
  135. # Evaluate
  136. if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
  137. print("Evaluating ...")
  138. test_stats = evaluate(val_dataloader, model, device)
  139. print(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%")
  140. max_accuracy = max(max_accuracy, test_stats["acc1"])
  141. print(f'Max accuracy: {max_accuracy:.2f}%')
  142. # Save model
  143. print('- saving the model after {} epochs ...'.format(epoch))
  144. save_model(args, epoch, model, optimizer, lr_scheduler, test_stats["acc1"])
  145. valid_acc1_logs.append((epoch, test_stats["acc1"]))
  146. valid_loss_logs.append((epoch, test_stats["loss"]))
  147. total_time = time.time() - start_time
  148. total_time_str = str(datetime.timedelta(seconds=int(total_time)))
  149. print('Training time {}'.format(total_time_str))
  150. # --------------- Plot log curve ---------------
  151. ## Training loss
  152. epochs = [sample[0] for sample in train_loss_logs]
  153. tloss = [sample[1] for sample in train_loss_logs]
  154. plt.plot(epochs, tloss, c='r', label='training loss')
  155. plt.xlabel('epoch')
  156. plt.ylabel('loss')
  157. plt.title('Training & Validation loss curve')
  158. ## Valid loss
  159. epochs = [sample[0] for sample in valid_loss_logs]
  160. vloss = [sample[1] for sample in valid_loss_logs]
  161. plt.plot(epochs, vloss, c='b', label='validation loss')
  162. plt.show()
  163. ## Valid acc1
  164. epochs = [sample[0] for sample in valid_acc1_logs]
  165. acc1 = [sample[1] for sample in valid_acc1_logs]
  166. plt.plot(epochs, acc1, label='validation loss')
  167. plt.xlabel('epoch')
  168. plt.ylabel('top1 accuracy')
  169. plt.title('Validation top-1 accuracy curve')
  170. plt.show()
  171. if __name__ == "__main__":
  172. main()