engine.py

import torch
import torch.distributed as dist
import time
import os
import math
import numpy as np
import random

from utils import distributed_utils
from utils.vis_tools import vis_data


def refine_targets(targets, min_box_size):
    """Filter out ground-truth boxes whose shorter side is below min_box_size."""
    for tgt in targets:
        boxes = tgt["boxes"].clone()
        labels = tgt["labels"].clone()
        # keep a target only if min(w, h) >= min_box_size
        tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
        min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
        keep = (min_tgt_size >= min_box_size)
        tgt["boxes"] = boxes[keep]
        tgt["labels"] = labels[keep]

    return targets
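
# Example (hypothetical tensors): with min_box_size=8,
#   targets = [{"boxes": torch.tensor([[0., 0., 4., 4.], [10., 10., 74., 74.]]),
#               "labels": torch.tensor([3, 7])}]
#   refine_targets(targets, 8)
# keeps only the 64x64 box (label 7); the 4x4 box falls below the threshold.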


def rescale_image_targets(images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
    """
    Multi-scale training trick: randomly resize the (square) input batch and
    rescale the targets accordingly.
    """
    if isinstance(stride, int):
        max_stride = stride
    elif isinstance(stride, (list, tuple)):
        max_stride = max(stride)
    else:
        raise TypeError('stride must be an int or a list of ints.')

    # During the training phase, the input images are square.
    old_img_size = images.shape[-1]
    new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]),
                                    int(old_img_size * multi_scale_range[1] + max_stride))
    new_img_size = new_img_size // max_stride * max_stride  # round down to a multiple of the stride
    if new_img_size / old_img_size != 1:
        # interpolate
        images = torch.nn.functional.interpolate(
            input=images,
            size=new_img_size,
            mode='bilinear',
            align_corners=False)
    # rescale targets
    for tgt in targets:
        boxes = tgt["boxes"].clone()
        labels = tgt["labels"].clone()
        boxes = torch.clamp(boxes, 0, old_img_size)
        # rescale box coordinates to the new image size
        boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
        boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
        # drop boxes that became too small after rescaling
        tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
        min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
        keep = (min_tgt_size >= min_box_size)
        tgt["boxes"] = boxes[keep]
        tgt["labels"] = labels[keep]

    return images, targets, new_img_size
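
# Worked example of the size arithmetic above (assumed numbers): with
# old_img_size=640, max_stride=32 and multi_scale_range=[0.5, 1.5], new_img_size
# is drawn from [320, 992) and then snapped down to a multiple of the stride,
# e.g. 417 -> 416; box coordinates are scaled by new_img_size / old_img_size.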


def train_one_epoch(epoch,
                    total_epochs,
                    args,
                    device,
                    ema,
                    model,
                    criterion,
                    cfg,
                    dataloader,
                    optimizer,
                    scheduler,
                    lf,
                    scaler,
                    last_opt_step):
    epoch_size = len(dataloader)
    img_size = args.img_size
    t0 = time.time()
    nw = epoch_size * args.wp_epoch  # number of warmup iterations
    accumulate = max(1, round(64 / args.batch_size))
    # train one epoch
    for iter_i, (images, targets) in enumerate(dataloader):
        ni = iter_i + epoch * epoch_size
        # warmup
        if ni <= nw:
            xi = [0, nw]  # x interp
            accumulate = max(1, np.interp(ni, xi, [1, 64 / args.batch_size]).round())
            for j, x in enumerate(optimizer.param_groups):
                # bias lr falls from warmup_bias_lr to lr0; all other lrs rise from 0.0 to lr0
                x['lr'] = np.interp(
                    ni, xi, [cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
                if 'momentum' in x:
                    x['momentum'] = np.interp(ni, xi, [cfg['warmup_momentum'], cfg['momentum']])
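
        # Worked example (assumed numbers): with batch_size=16 the schedule
        # targets an effective batch of 64, so `accumulate` ramps from 1 up to
        # 64/16 = 4 over the warmup window, while each group's lr is linearly
        # interpolated toward initial_lr * lf(epoch).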
        # to device
        images = images.to(device, non_blocking=True).float() / 255.
        # multi-scale trick
        if args.multi_scale:
            images, targets, img_size = rescale_image_targets(
                images, targets, model.stride, args.min_box_size, cfg['multi_scale'])
        else:
            targets = refine_targets(targets, args.min_box_size)
        # visualize train targets (undo the /255. normalization first)
        if args.vis_tgt:
            vis_data(images * 255, targets)
        # inference with mixed precision
        with torch.cuda.amp.autocast(enabled=args.fp16):
            outputs = model(images)
            # loss
            loss_dict = criterion(outputs, targets, epoch)
            losses = loss_dict['losses']
            losses *= images.shape[0]  # loss * batch size
        # reduce losses across devices for logging
        loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
        if args.distributed:
            # gradients are averaged between devices in DDP mode
            losses *= distributed_utils.get_world_size()
        # check loss
        try:
            if torch.isnan(losses):
                print('loss is NaN !!')
                continue
        except Exception:
            print(loss_dict)
        # backward
        scaler.scale(losses).backward()
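
        # Gradients from this backward pass accumulate in the .grad buffers;
        # the optimizer only steps once every `accumulate` iterations (below).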
        # optimize (with gradient accumulation)
        if ni - last_opt_step >= accumulate:
            if cfg['clip_grad'] > 0:
                # unscale gradients
                scaler.unscale_(optimizer)
                # clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg['clip_grad'])
            # optimizer.step
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            # ema
            if ema:
                ema.update(model)
            last_opt_step = ni
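
            # Note: scaler.unscale_ must run before clip_grad_norm_ so that
            # clipping sees gradients at their true (unscaled) magnitudes, and
            # scaler.step skips the update if any inf/NaN gradient was found.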
        # display
        if distributed_utils.is_main_process() and iter_i % 10 == 0:
            t1 = time.time()
            cur_lr = [param_group['lr'] for param_group in optimizer.param_groups]
            # basic info
            log = '[Epoch: {}/{}]'.format(epoch+1, total_epochs)
            log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
            log += '[lr: {:.6f}]'.format(cur_lr[2])
            # loss info
            for k in loss_dict_reduced.keys():
                if k == 'losses' and args.distributed:
                    world_size = distributed_utils.get_world_size()
                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k] / world_size)
                else:
                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
            # other info
            log += '[time: {:.2f}]'.format(t1 - t0)
            log += '[size: {}]'.format(img_size)
            # print log
            print(log, flush=True)
            t0 = time.time()
    scheduler.step()

    return last_opt_step


def val_one_epoch(args,
                  model,
                  evaluator,
                  optimizer,
                  epoch,
                  best_map,
                  path_to_save):
    if distributed_utils.is_main_process():
        # check evaluator
        if evaluator is None:
            print('No evaluator ... save model and go on training.')
            print('Saving state, epoch: {}'.format(epoch + 1))
            weight_name = '{}_no_eval.pth'.format(args.model)
            checkpoint_path = os.path.join(path_to_save, weight_name)
            torch.save({'model': model.state_dict(),
                        'mAP': -1.,
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                        'args': args},
                       checkpoint_path)
        else:
            print('eval ...')
            # set eval mode
            model.trainable = False
            model.eval()
            # evaluate
            evaluator.evaluate(model)
            cur_map = evaluator.map
            if cur_map > best_map:
                # update best mAP
                best_map = cur_map
                # save model
                print('Saving state, epoch:', epoch + 1)
                weight_name = '{}_best.pth'.format(args.model)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model.state_dict(),
                            'mAP': round(best_map*100, 1),
                            'optimizer': optimizer.state_dict(),
                            'epoch': epoch,
                            'args': args},
                           checkpoint_path)
            # set train mode
            model.trainable = True
            model.train()

    if args.distributed:
        # wait for all processes to synchronize
        dist.barrier()

    return best_map
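

# A minimal sketch of how these two functions fit together in a training script
# (args.max_epoch and the construction of model/criterion/evaluator/etc. are
# assumed here, not defined in this file):
#
#   last_opt_step, best_map = -1, -1.
#   for epoch in range(args.max_epoch):
#       last_opt_step = train_one_epoch(
#           epoch, args.max_epoch, args, device, ema, model, criterion, cfg,
#           dataloader, optimizer, scheduler, lf, scaler, last_opt_step)
#       best_map = val_one_epoch(args, model, evaluator, optimizer,
#                                epoch, best_map, path_to_save)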