import torch
import torch.distributed as dist
import time
import os
import numpy as np
import random

from utils import distributed_utils
from utils.vis_tools import vis_data


def rescale_image_targets(images, targets, stride, min_box_size):
    """
    Randomly rescale the input images and their targets for the multi-scale training trick.
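
    Args:
        images: batched input images of shape [B, C, H, W], assumed square (H == W).
        targets: list of per-image dicts with "boxes" (xyxy, in pixels) and "labels".
        stride: the model's output stride, either an int or a list of ints.
        min_box_size: boxes whose shorter side falls below this after rescaling are dropped.

    Returns:
        The rescaled images, the rescaled targets, and the new image size.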
  13. """
    if isinstance(stride, int):
        max_stride = stride
    elif isinstance(stride, list):
        max_stride = max(stride)

    # During the training phase, the input images are square.
    old_img_size = images.shape[-1]
    # Sample a new size in roughly [0.5x, 1.5x] of the old size, rounded down to a
    # multiple of the max stride. randrange requires ints, so cast the float bounds.
    new_img_size = random.randrange(int(old_img_size * 0.5), int(old_img_size * 1.5) + max_stride) // max_stride * max_stride
    if new_img_size / old_img_size != 1:
        # interpolate
        images = torch.nn.functional.interpolate(
            input=images,
            size=new_img_size,
            mode='bilinear',
            align_corners=False)

    # rescale targets
    for tgt in targets:
        boxes = tgt["boxes"].clone()
        labels = tgt["labels"].clone()
        boxes = torch.clamp(boxes, 0, old_img_size)
        # rescale boxes
        boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
        boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
        # refine targets: drop boxes that became too small after rescaling
        tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
        min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
        keep = (min_tgt_size >= min_box_size)
        tgt["boxes"] = boxes[keep]
        tgt["labels"] = labels[keep]

    return images, targets, new_img_size


def train_one_epoch(epoch,
                    total_epochs,
                    args,
                    device,
                    ema,
                    model,
                    criterion,
                    cfg,
                    dataloader,
                    optimizer,
                    scheduler,
                    lf,
                    scaler,
                    last_opt_step):
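    """
    Train the model for one epoch: linear warmup of lr/momentum, optional
    multi-scale rescaling, AMP forward pass, gradient accumulation up to a
    nominal batch size of 64, gradient clipping, and EMA updates.

    Returns the iteration index of the last optimizer step, so that gradient
    accumulation continues seamlessly across epochs.
    """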
    epoch_size = len(dataloader)
    img_size = args.img_size
    t0 = time.time()
    nw = epoch_size * args.wp_epoch  # number of warmup iterations
    accumulate = max(1, round(64 / args.batch_size))  # accumulate gradients up to a nominal batch size of 64
    # train one epoch
    for iter_i, (images, targets) in enumerate(dataloader):
        ni = iter_i + epoch * epoch_size
        # warmup
        if ni <= nw:
            xi = [0, nw]  # x interp
            accumulate = max(1, int(np.interp(ni, xi, [1, 64 / args.batch_size]).round()))
            for j, x in enumerate(optimizer.param_groups):
                # bias lr falls from warmup_bias_lr to lr0; all other lrs rise from 0.0 to lr0
                x['lr'] = np.interp(
                    ni, xi, [cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
                if 'momentum' in x:
                    x['momentum'] = np.interp(ni, xi, [cfg['warmup_momentum'], cfg['momentum']])
        # visualize train targets
        if args.vis_tgt:
            vis_data(images, targets)

        # to device
        images = images.to(device, non_blocking=True).float() / 255.

        # multi scale
        if args.multi_scale:
            images, targets, img_size = rescale_image_targets(
                images, targets, model.stride, args.min_box_size)

        # inference
        with torch.cuda.amp.autocast(enabled=args.fp16):
            outputs = model(images)
            # loss
            loss_dict = criterion(outputs=outputs, targets=targets)
            losses = loss_dict['losses']
            losses *= images.shape[0]  # loss * batch size

            # reduce
            loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

            if args.distributed:
                # gradient averaged between devices in DDP mode
                losses *= distributed_utils.get_world_size()
        # check loss: skip this batch if the loss is NaN
        try:
            if torch.isnan(losses):
                print('loss is NaN!')
                continue
        except Exception:
            print(loss_dict)
        # backward
        scaler.scale(losses).backward()

        # optimize
        if ni - last_opt_step >= accumulate:
            if cfg['clip_grad'] > 0:
                # unscale gradients
                scaler.unscale_(optimizer)
                # clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg['clip_grad'])
            # optimizer step
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            # ema
            if ema:
                ema.update(model)
            last_opt_step = ni
        # display
        if distributed_utils.is_main_process() and iter_i % 10 == 0:
            t1 = time.time()
            cur_lr = [param_group['lr'] for param_group in optimizer.param_groups]
            # basic info
            log = '[Epoch: {}/{}]'.format(epoch + 1, total_epochs)
            log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
            log += '[lr: {:.6f}]'.format(cur_lr[2])
            # loss info
            for k in loss_dict_reduced.keys():
                if k == 'losses' and args.distributed:
                    world_size = distributed_utils.get_world_size()
                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k] / world_size)
                else:
                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
            # other info
            log += '[time: {:.2f}]'.format(t1 - t0)
            log += '[size: {}]'.format(img_size)
            # print log
            print(log, flush=True)
            t0 = time.time()

    scheduler.step()

    return last_opt_step


def val_one_epoch(args,
                  model,
                  evaluator,
                  optimizer,
                  epoch,
                  best_map,
                  path_to_save):
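    """
    Evaluate the model on the main process and checkpoint it. Without an
    evaluator, the model is saved unconditionally; with one, it is evaluated
    and saved only when its mAP improves on the best seen so far.

    Returns the (possibly updated) best mAP.
    """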
    # check evaluator
    if distributed_utils.is_main_process():
        if evaluator is None:
            print('No evaluator ... save model and go on training.')
            print('Saving state, epoch: {}'.format(epoch + 1))
            weight_name = '{}_epoch_{}.pth'.format(args.model, epoch + 1)
            checkpoint_path = os.path.join(path_to_save, weight_name)
            torch.save({'model': model.state_dict(),
                        'mAP': -1.,
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                        'args': args},
                       checkpoint_path)
        else:
            print('eval ...')
            # set eval mode
            model.trainable = False
            model.eval()

            # evaluate
            evaluator.evaluate(model)

            cur_map = evaluator.map
            if cur_map > best_map:
                # update best mAP
                best_map = cur_map
                # save model
                print('Saving state, epoch:', epoch + 1)
                weight_name = '{}_epoch_{}_{:.2f}.pth'.format(args.model, epoch + 1, best_map * 100)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model.state_dict(),
                            'mAP': round(best_map * 100, 1),
                            'optimizer': optimizer.state_dict(),
                            'epoch': epoch,
                            'args': args},
                           checkpoint_path)

            # set train mode
            model.trainable = True
            model.train()

    if args.distributed:
        # wait for all processes to synchronize
        dist.barrier()

    return best_map