engine.py

import torch
import torch.distributed as dist
import time
import os
import math
import numpy as np
import random

from utils import distributed_utils
from utils.vis_tools import vis_data

def rescale_image_targets(images, targets, max_stride, min_box_size):
    """
    Rescale images and targets for the multi-scale training trick.
    """
    # During the training phase, the input images are square.
    old_img_size = images.shape[-1]
    # Sample a new size in [0.5x, 1.5x] and round it down to a multiple of
    # max_stride. randrange requires integer bounds, so cast explicitly.
    new_img_size = random.randrange(
        int(old_img_size * 0.5), int(old_img_size * 1.5 + max_stride)) // max_stride * max_stride
    if new_img_size / old_img_size != 1:
        # interpolate
        images = torch.nn.functional.interpolate(
            input=images,
            size=new_img_size,
            mode='bilinear',
            align_corners=False)
    # rescale targets
    for tgt in targets:
        boxes = tgt["boxes"].clone()
        labels = tgt["labels"].clone()
        boxes = torch.clamp(boxes, 0, old_img_size)
        # rescale boxes
        boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
        boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
        # drop targets that become too small after rescaling
        tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
        min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
        keep = (min_tgt_size >= min_box_size)
        tgt["boxes"] = boxes[keep]
        tgt["labels"] = labels[keep]

    return images, targets, new_img_size
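

# Usage sketch (illustration only, not part of the training pipeline): a quick
# self-contained check of rescale_image_targets on dummy data. The 640x640
# batch, stride of 32, and min_box_size of 8 are assumed values for the demo,
# not settings read from this repo's configs.
def _demo_rescale_image_targets():
    images = torch.rand(2, 3, 640, 640)  # square inputs, as the function expects
    targets = [{'boxes': torch.tensor([[10., 10., 120., 90.]]),
                'labels': torch.tensor([1])} for _ in range(2)]
    images, targets, new_size = rescale_image_targets(
        images, targets, max_stride=32, min_box_size=8)
    # new_size is a multiple of 32 in roughly [320, 960]; boxes scale with it
    print(images.shape, new_size, targets[0]['boxes'])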

def train_one_epoch(epoch,
                    total_epochs,
                    args,
                    device,
                    ema,
                    model,
                    criterion,
                    cfg,
                    dataloader,
                    optimizer,
                    scheduler,
                    lf,
                    scaler,
                    last_opt_step):
    epoch_size = len(dataloader)
    img_size = args.img_size
    t0 = time.time()
    nw = epoch_size * args.wp_epoch
    # accumulate gradients until the effective batch size reaches 64
    accumulate = max(1, round(64 / args.batch_size))

    # train one epoch
    for iter_i, (images, targets) in enumerate(dataloader):
        ni = iter_i + epoch * epoch_size
        # warmup
        if ni <= nw:
            xi = [0, nw]  # x interp
            accumulate = max(1, np.interp(ni, xi, [1, 64 / args.batch_size]).round())
            for j, x in enumerate(optimizer.param_groups):
                # bias lr falls from warmup_bias_lr to lr0; all other lrs rise from 0.0 to lr0
                x['lr'] = np.interp(
                    ni, xi, [cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
                if 'momentum' in x:
                    x['momentum'] = np.interp(ni, xi, [cfg['warmup_momentum'], cfg['momentum']])

        # visualize train targets
        if args.vis_tgt:
            vis_data(images, targets)

        # to device
        images = images.to(device, non_blocking=True).float() / 255.

        # multi scale
        if args.multi_scale:
            images, targets, img_size = rescale_image_targets(
                images, targets, max(model.stride), args.min_box_size)

        # inference
        with torch.cuda.amp.autocast(enabled=args.fp16):
            outputs = model(images)
            # loss
            loss_dict = criterion(outputs=outputs, targets=targets)
            losses = loss_dict['losses']
            losses *= images.shape[0]  # loss * bs

            # reduce
            loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

            if args.distributed:
                # gradients are averaged between devices in DDP mode
                losses *= distributed_utils.get_world_size()

        # check loss: skip the batch if the loss is NaN
        try:
            if torch.isnan(losses):
                print('loss is NaN !!')
                continue
        except Exception:
            print(loss_dict)

        # backward
        scaler.scale(losses).backward()

        # optimize
        if ni - last_opt_step >= accumulate:
            if cfg['clip_grad'] > 0:
                # unscale gradients
                scaler.unscale_(optimizer)
                # clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg['clip_grad'])
            # optimizer step
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            # ema
            if ema:
                ema.update(model)
            last_opt_step = ni

        # display
        if distributed_utils.is_main_process() and iter_i % 10 == 0:
            t1 = time.time()
            cur_lr = [param_group['lr'] for param_group in optimizer.param_groups]
            # basic info (group 2 is assumed to hold the main weight lr)
            log = '[Epoch: {}/{}]'.format(epoch + 1, total_epochs)
            log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
            log += '[lr: {:.6f}]'.format(cur_lr[2])
            # loss info: use the reduced dict so the log agrees across ranks
            for k in loss_dict_reduced.keys():
                if k == 'losses' and args.distributed:
                    world_size = distributed_utils.get_world_size()
                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k] / world_size)
                else:
                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
            # other info
            log += '[time: {:.2f}]'.format(t1 - t0)
            log += '[size: {}]'.format(img_size)
            # print log
            print(log, flush=True)
            t0 = time.time()

    scheduler.step()
    return last_opt_step
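

# Sketch of the inputs train_one_epoch assumes (hedged, modeled on common
# YOLO-style repos rather than taken verbatim from this project): an SGD
# optimizer whose param group 0 holds biases (the warmup above special-cases
# j == 0, and the log reads cur_lr[2]), plus a per-epoch schedule function
# lf(epoch). The cfg keys 'lr0', 'lrf', and 'weight_decay' are assumptions.
def _demo_build_optimizer_and_lf(model, cfg, total_epochs):
    bias_p, bn_p, weight_p = [], [], []
    for m in model.modules():
        if hasattr(m, 'bias') and isinstance(m.bias, torch.nn.Parameter):
            bias_p.append(m.bias)
        if isinstance(m, torch.nn.BatchNorm2d):
            bn_p.append(m.weight)
        elif hasattr(m, 'weight') and isinstance(m.weight, torch.nn.Parameter):
            weight_p.append(m.weight)
    optimizer = torch.optim.SGD(bias_p, lr=cfg['lr0'], momentum=cfg['momentum'], nesterov=True)
    optimizer.add_param_group({'params': bn_p})  # no weight decay on BN weights
    optimizer.add_param_group({'params': weight_p, 'weight_decay': cfg['weight_decay']})
    # cosine decay from 1.0 down to cfg['lrf'] over the run
    lf = lambda x: ((1 - math.cos(x * math.pi / total_epochs)) / 2) * (cfg['lrf'] - 1) + 1
    # LambdaLR records 'initial_lr' in each group, which the warmup reads
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    return optimizer, scheduler, lf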

def val_one_epoch(args,
                  model,
                  evaluator,
                  optimizer,
                  epoch,
                  best_map,
                  path_to_save):
    # check evaluator
    if distributed_utils.is_main_process():
        if evaluator is None:
            print('No evaluator ... save model and go on training.')
            print('Saving state, epoch: {}'.format(epoch + 1))
            weight_name = '{}_epoch_{}.pth'.format(args.version, epoch + 1)
            checkpoint_path = os.path.join(path_to_save, weight_name)
            torch.save({'model': model.state_dict(),
                        'mAP': -1.,
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                        'args': args},
                       checkpoint_path)
        else:
            print('eval ...')
            # set eval mode
            model.trainable = False
            model.eval()

            # evaluate
            evaluator.evaluate(model)

            cur_map = evaluator.map
            if cur_map > best_map:
                # update best mAP
                best_map = cur_map
                # save model
                print('Saving state, epoch:', epoch + 1)
                weight_name = '{}_epoch_{}_{:.2f}.pth'.format(args.version, epoch + 1, best_map * 100)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model.state_dict(),
                            'mAP': round(best_map * 100, 1),
                            'optimizer': optimizer.state_dict(),
                            'epoch': epoch,
                            'args': args},
                           checkpoint_path)

            # set train mode
            model.trainable = True
            model.train()

    if args.distributed:
        # wait for all processes to synchronize
        dist.barrier()

    return best_map
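

# Driver sketch (hypothetical wiring: args.eval_epoch and the ema.ema
# attribute follow common conventions and may not match this repo exactly).
# It shows one way to chain the two entry points above per epoch.
def _demo_training_loop(args, cfg, device, model, ema, criterion, dataloader,
                        optimizer, scheduler, lf, evaluator, total_epochs,
                        path_to_save):
    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
    best_map, last_opt_step = -1., -1
    for epoch in range(total_epochs):
        last_opt_step = train_one_epoch(
            epoch, total_epochs, args, device, ema, model, criterion, cfg,
            dataloader, optimizer, scheduler, lf, scaler, last_opt_step)
        # evaluate periodically; prefer the EMA weights when available
        if (epoch + 1) % args.eval_epoch == 0 or (epoch + 1) == total_epochs:
            model_eval = ema.ema if ema else model
            best_map = val_one_epoch(args, model_eval, evaluator, optimizer,
                                     epoch, best_map, path_to_save)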