train.py

from copy import deepcopy
import os
import time
import math
import argparse
import datetime

# ---------------- Timm components ----------------
from timm.data.mixup import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

# ---------------- Torch components ----------------
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# ---------------- Dataset components ----------------
from data import build_dataset, build_dataloader

# ---------------- Model components ----------------
from models import build_model

# ---------------- Utils components ----------------
from utils import distributed_utils
from utils.ema import ModelEMA
from utils.misc import setup_seed, print_rank_0, load_model, save_model
from utils.misc import NativeScalerWithGradNormCount as NativeScaler
from utils.optimzer import build_optimizer
from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
from utils.com_flops_params import FLOPs_and_Params

# ---------------- Training engine ----------------
from engine import train_one_epoch, evaluate


def parse_args():
    parser = argparse.ArgumentParser()
    # Input
    parser.add_argument('--img_size', type=int, default=224,
                        help='input image size.')
    parser.add_argument('--img_dim', type=int, default=3,
                        help='3 for RGB; 1 for Gray.')
    parser.add_argument('--num_classes', type=int, default=1000,
                        help='number of classes.')
    # Basic
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed.')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda.')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch size on all GPUs.')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of dataloader workers.')
    parser.add_argument('--path_to_save', type=str, default='weights/',
                        help='path to save the trained model.')
    parser.add_argument('--tfboard', action='store_true', default=False,
                        help='use tensorboard.')
    parser.add_argument('--eval', action='store_true', default=False,
                        help='only evaluate the model, then exit.')
    # Epoch
    parser.add_argument('--wp_epoch', type=int, default=20,
                        help='number of warmup epochs.')
    parser.add_argument('--start_epoch', type=int, default=0,
                        help='start epoch (useful when resuming).')
    parser.add_argument('--max_epoch', type=int, default=300,
                        help='max epoch.')
    parser.add_argument('--eval_epoch', type=int, default=10,
                        help='interval (in epochs) between evaluations.')
    # Dataset
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset name.')
    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
                        help='path to the dataset folder.')
    # Model
    parser.add_argument('-m', '--model', type=str, default='rtcnet_n',
                        help='model name.')
    parser.add_argument('--resume', default=None, type=str,
                        help='checkpoint to resume training from.')
    parser.add_argument('--ema', action='store_true', default=False,
                        help='use model EMA.')
    parser.add_argument('--drop_path', type=float, default=0.1,
                        help='drop path rate.')
    # Optimizer
    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
                        help='optimizer: sgd, adam, adamw.')
    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='step',
                        help='lr scheduler: cosine, step.')
    parser.add_argument('-mt', '--momentum', type=float, default=0.9,
                        help='momentum (for sgd).')
    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
                        help='weight decay.')
    parser.add_argument('--batch_base', type=int, default=256,
                        help='base batch size used to scale the learning rate.')
    parser.add_argument('--base_lr', type=float, default=1e-3,
                        help='base learning rate.')
    parser.add_argument('--min_lr', type=float, default=1e-6,
                        help='final (minimum) learning rate.')
    parser.add_argument('--grad_accumulate', type=int, default=1,
                        help='number of gradient accumulation steps.')
    parser.add_argument('--max_grad_norm', type=float, default=None,
                        help='clip gradient norm (default: None, no clipping).')
    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
                        help='color jitter factor (enabled only when not using Auto/RandAug).')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='use AutoAugment policy, e.g. "v0", "original", or "rand-m9-mstd0.5-inc1" (default: None).')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='label smoothing (default: 0.1).')
    # Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='random erase prob (default: 0.25).')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='random erase mode (default: "pixel").')
    parser.add_argument('--recount', type=int, default=1,
                        help='random erase count (default: 1).')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='do not random erase the first (clean) augmentation split.')
    # Mixup params
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0.')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0.')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None).')
    parser.add_argument('--mixup_prob', type=float, default=1.0,
                        help='probability of performing mixup or cutmix when either/both is enabled.')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                        help='probability of switching to cutmix when both mixup and cutmix are enabled.')
    parser.add_argument('--mixup_mode', type=str, default='batch',
                        help='how to apply mixup/cutmix params: "batch", "pair", or "elem".')
    # DDP
    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
                        help='distributed training.')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training.')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes.')
    parser.add_argument('--sybn', action='store_true', default=False,
                        help='use SyncBatchNorm.')
    parser.add_argument('--local_rank', default=-1, type=int,
                        help='local rank of the process.')

    return parser.parse_args()

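# ----------------------------------------------------------------------------
# Illustrative launch examples (a sketch only: the dataset root is a placeholder
# and the torchrun launcher is an assumption based on the 'env://' default of
# --dist_url; neither is taken from this repository's docs).
#
#   Single GPU:
#       python train.py --cuda --dataset cifar10 --root /path/to/data -m rtcnet_n
#
#   Multi-GPU DDP on one machine:
#       torchrun --nproc_per_node=8 train.py --cuda -dist \
#           --dataset cifar10 --root /path/to/data -m rtcnet_n --batch_size 256
# ----------------------------------------------------------------------------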


def main():
    args = parse_args()
    # Set the random seed
    setup_seed(args.seed)

    # Path to save the trained model
    path_to_save = os.path.join(args.path_to_save, args.dataset, args.model)
    os.makedirs(path_to_save, exist_ok=True)
    args.output_dir = path_to_save

    # ------------------------- Build DDP environment -------------------------
    ## local_rank is the global rank of the process, in the range [0, world_size - 1].
    ## local_process_rank is the rank of the process on its own machine, not the global rank.
    local_rank = local_process_rank = -1
    if args.distributed:
        distributed_utils.init_distributed_mode(args)
        print("git:\n {}\n".format(distributed_utils.get_sha()))
        try:
            # Multiple machines & multiple GPUs (world size > 8)
            local_rank = torch.distributed.get_rank()
            local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0'))
        except:
            # Single machine & multiple GPUs (world size <= 8)
            local_rank = local_process_rank = torch.distributed.get_rank()
    print_rank_0(args)
    args.world_size = distributed_utils.get_world_size()
    print('World size: {}'.format(distributed_utils.get_world_size()))
    print("LOCAL RANK: ", local_rank)
    print("LOCAL_PROCESS_RANK: ", local_process_rank)

    # ------------------------- Build CUDA -------------------------
    if args.cuda:
        if torch.cuda.is_available():
            cudnn.benchmark = True
            device = torch.device("cuda")
        else:
            print('There is no available GPU.')
            args.cuda = False
            device = torch.device("cpu")
    else:
        device = torch.device("cpu")

    # ------------------------- Build Tensorboard -------------------------
    tblogger = None
    if local_rank <= 0 and args.tfboard:
        print('use tensorboard')
        from torch.utils.tensorboard import SummaryWriter
        time_stamp = time.strftime('%Y-%m-%d_%H:%M:%S', time.localtime(time.time()))
        log_path = os.path.join('log/', args.dataset, time_stamp)
        os.makedirs(log_path, exist_ok=True)
        tblogger = SummaryWriter(log_path)

    # ------------------------- Build Dataset -------------------------
    train_dataset = build_dataset(args, is_train=True)
    val_dataset = build_dataset(args, is_train=False)

    # ------------------------- Build Dataloader -------------------------
    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
    val_dataloader = build_dataloader(args, val_dataset, is_train=False)

    print('=================== Dataset Information ===================')
    print("Dataset: ", args.dataset)
    print('- train dataset size : ', len(train_dataset))
    print('- val dataset size : ', len(val_dataset))
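
    # NOTE (assumption, not verified here): build_dataloader is expected to wrap
    # the training set in a DistributedSampler when args.distributed is set; the
    # epoch loop below calls set_epoch() on train_dataloader.batch_sampler.sampler,
    # which relies on that.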

    # ------------------------- Build Model -------------------------
    model = build_model(args)
    model.train().to(device)
    print(model)
    if local_rank <= 0:
        # Measure FLOPs & Params on a copy so the training model is left untouched
        model_copy = deepcopy(model)
        model_copy.eval()
        FLOPs_and_Params(model_copy, args.img_size, args.img_dim, device)
        model_copy.train()
        del model_copy
    if args.distributed:
        # wait for all processes to synchronize
        dist.barrier()

    # ------------------------- Build DDP Model -------------------------
    model_without_ddp = model
    if args.distributed:
        if args.sybn:
            # Convert BN layers to SyncBatchNorm before wrapping with DDP
            print('use SyncBatchNorm ...')
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # ------------------------- Mixup augmentation config -------------------------
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print_rank_0("Mixup is activated!", local_rank)
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.num_classes)

    # ------------------------- Build Optimizer -------------------------
    optimizer = build_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
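    # NOTE (assumption, not verified against utils/optimzer.py): given that both
    # --base_lr and --batch_base exist, build_optimizer presumably applies linear
    # LR scaling, roughly lr = base_lr * batch_size * grad_accumulate / batch_base.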

    # ------------------------- Build LR Scheduler -------------------------
    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
    lr_scheduler = build_lr_scheduler(args, optimizer)
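    # The warmup scheduler runs per iteration inside train_one_epoch (its length
    # is wp_epoch * len(train_dataloader) iterations); the main lr_scheduler is
    # stepped once per epoch in the loop below, after the warmup epochs finish.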

    # ------------------------- Build Criterion -------------------------
    if mixup_fn is not None:
        # Smoothing is handled with the mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Resume from a checkpoint if args.resume is set
    load_model(args=args, model_without_ddp=model_without_ddp,
               optimizer=optimizer, lr_scheduler=lr_scheduler, loss_scaler=loss_scaler)

    # ------------------------- Build Model-EMA -------------------------
    if args.ema:
        print("Build model ema for {}".format(args.model))
        updates = args.start_epoch * len(train_dataloader) // args.grad_accumulate
        print("Initial updates of ModelEMA: {}".format(updates))
        model_ema = ModelEMA(model_without_ddp, ema_decay=0.999, ema_tau=2000., updates=updates)
    else:
        model_ema = None
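
    # NOTE: when --ema is enabled, the EMA weights (model_ema.ema) are what gets
    # evaluated and checkpointed in the training loop below.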

    # ------------------------- Eval before Train Pipeline -------------------------
    if args.eval:
        print('evaluating ...')
        test_stats = evaluate(val_dataloader, model_without_ddp, device, local_rank)
        print('Eval Results: [loss: %.2f][acc1: %.2f][acc5: %.2f]' %
              (test_stats['loss'], test_stats['acc1'], test_stats['acc5']), flush=True)
        return

    # ------------------------- Training Pipeline -------------------------
    start_time = time.time()
    max_accuracy = -1.0
    print_rank_0("=============== Start training for {} epochs ===============".format(args.max_epoch), local_rank)
    for epoch in range(args.start_epoch, args.max_epoch):
        if args.distributed:
            train_dataloader.batch_sampler.sampler.set_epoch(epoch)

        # Train one epoch
        train_one_epoch(args, device, model, model_ema, train_dataloader, optimizer, epoch,
                        lr_scheduler_warmup, loss_scaler, criterion, local_rank, tblogger, mixup_fn)

        # LR scheduler (per-epoch step once warmup is over)
        if (epoch + 1) > args.wp_epoch:
            lr_scheduler.step()

        # Evaluate on rank 0
        if local_rank <= 0:
            model_eval = model_ema.ema if model_ema is not None else model_without_ddp
            if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
                print_rank_0("Evaluating ...")
                test_stats = evaluate(val_dataloader, model_eval, device, local_rank)
                print_rank_0(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%", local_rank)
                max_accuracy = max(max_accuracy, test_stats["acc1"])
                print_rank_0(f'Max accuracy: {max_accuracy:.2f}%', local_rank)

                # Save model
                print('- saving the model after {} epochs ...'.format(epoch))
                save_model(args=args, model=model_eval, model_without_ddp=model_eval,
                           optimizer=optimizer, lr_scheduler=lr_scheduler, loss_scaler=loss_scaler,
                           epoch=epoch, acc1=max_accuracy)
        if args.distributed:
            dist.barrier()

        if tblogger is not None:
            tblogger.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
            tblogger.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
            tblogger.add_scalar('perf/test_loss', test_stats['loss'], epoch)
        if args.distributed:
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == "__main__":
    main()