"""main_pretrain.py -- MAE pre-training / reconstruction-visualization entry point."""
  1. import os
  2. import cv2
  3. import time
  4. import datetime
  5. import argparse
  6. import numpy as np
  7. # ---------------- Torch compoments ----------------
  8. import torch
  9. import torch.backends.cudnn as cudnn
  10. # ---------------- Dataset compoments ----------------
  11. from data import build_dataset, build_dataloader
  12. from models import build_model
  13. # ---------------- Utils compoments ----------------
  14. from utils.misc import setup_seed
  15. from utils.misc import load_model, save_model, unpatchify
  16. from utils.optimizer import build_optimizer
  17. from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
  18. # ---------------- Training engine ----------------
  19. from engine_pretrain import train_one_epoch
  20. def parse_args():
  21. parser = argparse.ArgumentParser()
  22. # Basic
  23. parser.add_argument('--seed', type=int, default=42,
  24. help='random seed.')
  25. parser.add_argument('--cuda', action='store_true', default=False,
  26. help='use cuda')
  27. parser.add_argument('--batch_size', type=int, default=256,
  28. help='batch size on all GPUs')
  29. parser.add_argument('--num_workers', type=int, default=4,
  30. help='number of workers')
  31. parser.add_argument('--path_to_save', type=str, default='weights/',
  32. help='path to save trained model.')
  33. parser.add_argument('--eval', action='store_true', default=False,
  34. help='evaluate model.')
  35. # Epoch
  36. parser.add_argument('--wp_epoch', type=int, default=20,
  37. help='warmup epoch for finetune with MAE pretrained')
  38. parser.add_argument('--start_epoch', type=int, default=0,
  39. help='start epoch for finetune with MAE pretrained')
  40. parser.add_argument('--eval_epoch', type=int, default=10,
  41. help='warmup epoch for finetune with MAE pretrained')
  42. parser.add_argument('--max_epoch', type=int, default=200,
  43. help='max epoch')
  44. # Dataset
  45. parser.add_argument('--dataset', type=str, default='cifar10',
  46. help='dataset name')
  47. parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
  48. help='path to dataset folder')
  49. parser.add_argument('--num_classes', type=int, default=None,
  50. help='number of classes.')
  51. # Model
  52. parser.add_argument('-m', '--model', type=str, default='vit_t',
  53. help='model name')
  54. parser.add_argument('--resume', default=None, type=str,
  55. help='keep training')
  56. parser.add_argument('--drop_path', type=float, default=0.,
  57. help='drop_path')
  58. parser.add_argument('--mask_ratio', type=float, default=0.75,
  59. help='mask ratio.')
  60. # Optimizer
  61. parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
  62. help='sgd, adam')
  63. parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
  64. help='weight decay')
  65. parser.add_argument('--base_lr', type=float, default=0.00015,
  66. help='learning rate for training model')
  67. parser.add_argument('--min_lr', type=float, default=0,
  68. help='the final lr')
  69. # Optimizer
  70. parser.add_argument('-lrs', '--lr_scheduler', type=str, default='cosine',
  71. help='step, cosine')
  72. return parser.parse_args()
  73. def main():
  74. args = parse_args()
  75. # set random seed
  76. setup_seed(args.seed)
  77. # Path to save model
  78. path_to_save = os.path.join(args.path_to_save, args.dataset, "pretrained", args.model)
  79. os.makedirs(path_to_save, exist_ok=True)
  80. args.output_dir = path_to_save
  81. # ------------------------- Build CUDA -------------------------
  82. if args.cuda:
  83. if torch.cuda.is_available():
  84. cudnn.benchmark = True
  85. device = torch.device("cuda")
  86. else:
  87. print('There is no available GPU.')
  88. args.cuda = False
  89. device = torch.device("cpu")
  90. else:
  91. device = torch.device("cpu")
  92. # ------------------------- Build Dataset -------------------------
  93. train_dataset = build_dataset(args, is_train=True)
  94. # ------------------------- Build Dataloader -------------------------
  95. train_dataloader = build_dataloader(args, train_dataset, is_train=True)
  96. print('=================== Dataset Information ===================')
  97. print('Train dataset size : {}'.format(len(train_dataset)))
  98. # ------------------------- Build Model -------------------------
  99. model = build_model(args, model_type='mae')
  100. model.train().to(device)
  101. print(model)
  102. # ------------------------- Build Optimzier -------------------------
  103. optimizer = build_optimizer(args, model)
  104. # ------------------------- Build Lr Scheduler -------------------------
  105. lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
  106. lr_scheduler = build_lr_scheduler(args, optimizer)
  107. # ------------------------- Build checkpoint -------------------------
  108. load_model(args, model, optimizer, lr_scheduler)
  109. # ------------------------- Eval before Train Pipeline -------------------------
  110. if args.eval:
  111. print('visualizing ...')
  112. visualize(args, device, model)
  113. return
  114. # ------------------------- Training Pipeline -------------------------
  115. start_time = time.time()
  116. print("=================== Start training for {} epochs ===================".format(args.max_epoch))
  117. for epoch in range(args.start_epoch, args.max_epoch):
  118. # Train one epoch
  119. train_one_epoch(args, device, model, train_dataloader,
  120. optimizer, epoch, lr_scheduler_warmup)
  121. # LR scheduler
  122. if (epoch + 1) > args.wp_epoch:
  123. lr_scheduler.step()
  124. # Evaluate
  125. if epoch % args.eval_epoch == 0 or epoch + 1 == args.max_epoch:
  126. print('- saving the model after {} epochs ...'.format(epoch))
  127. save_model(args, epoch, model, optimizer, lr_scheduler, mae_task=True)
  128. total_time = time.time() - start_time
  129. total_time_str = str(datetime.timedelta(seconds=int(total_time)))
  130. print('Training time {}'.format(total_time_str))
  131. def visualize(args, device, model):
  132. # test dataset
  133. val_dataset = build_dataset(args, is_train=False)
  134. val_dataloader = build_dataloader(args, val_dataset, is_train=False)
  135. # save path
  136. save_path = "vis_results/{}/{}".format(args.dataset, args.model)
  137. os.makedirs(save_path, exist_ok=True)
  138. # switch to evaluate mode
  139. model.eval()
  140. patch_size = args.patch_size
  141. pixel_mean = val_dataloader.dataset.pixel_mean
  142. pixel_std = val_dataloader.dataset.pixel_std
  143. with torch.no_grad():
  144. for i, (images, target) in enumerate(val_dataloader):
  145. # To device
  146. images = images.to(device, non_blocking=True)
  147. target = target.to(device, non_blocking=True)
  148. # Inference
  149. output = model(images)
  150. # Denormalize input image
  151. org_img = images[0].permute(1, 2, 0).cpu().numpy()
  152. org_img = (org_img * pixel_std + pixel_mean) * 255.
  153. org_img = org_img.astype(np.uint8)
  154. # 调整mask的格式:[B, H*W] -> [B, H*W, p*p*3]
  155. mask = output['mask'].unsqueeze(-1).repeat(1, 1, patch_size**2 *3) # [B, H*W] -> [B, H*W, p*p*3]
  156. # 将序列格式的mask逆转回二维图像格式
  157. mask = unpatchify(mask, patch_size)
  158. mask = mask[0].permute(1, 2, 0).cpu().numpy()
  159. # 掩盖图像中被遮掩的图像patch区域
  160. masked_img = org_img * (1 - mask) # 1 is removing, 0 is keeping
  161. masked_img = masked_img.astype(np.uint8)
  162. # 将序列格式的重构图像逆转回二维图像格式
  163. pred_img = unpatchify(output['x_pred'], patch_size)
  164. pred_img = pred_img[0].permute(1, 2, 0).cpu().numpy()
  165. pred_img = (pred_img * pixel_std + pixel_mean) * 255.
  166. # 将原图中被保留的图像patch和网络预测的重构的图像patch拼在一起
  167. pred_img = org_img * (1 - mask) + pred_img * mask
  168. pred_img = pred_img.astype(np.uint8)
  169. # visualize
  170. vis_image = np.concatenate([masked_img, org_img, pred_img], axis=1)
  171. vis_image = vis_image[..., (2, 1, 0)]
  172. cv2.imshow('masked | origin | reconstruct ', vis_image)
  173. cv2.waitKey(0)
  174. # save
  175. cv2.imwrite('{}/{:06}.png'.format(save_path, i), vis_image)
# Script entry point: kick off pre-training (or --eval visualization).
if __name__ == "__main__":
    main()