# engine.py
  1. import torch
  2. import torch.distributed as dist
  3. import time
  4. import os
  5. import math
  6. import numpy as np
  7. import random
  8. from utils import distributed_utils
  9. from utils.vis_tools import vis_data
  10. def rescale_image_targets(images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
  11. """
  12. Deployed for Multi scale trick.
  13. """
  14. if isinstance(stride, int):
  15. max_stride = stride
  16. elif isinstance(stride, list):
  17. max_stride = max(stride)
  18. # During training phase, the shape of input image is square.
  19. old_img_size = images.shape[-1]
  20. new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
  21. new_img_size = new_img_size // max_stride * max_stride # size
  22. if new_img_size / old_img_size != 1:
  23. # interpolate
  24. images = torch.nn.functional.interpolate(
  25. input=images,
  26. size=new_img_size,
  27. mode='bilinear',
  28. align_corners=False)
  29. # rescale targets
  30. for tgt in targets:
  31. boxes = tgt["boxes"].clone()
  32. labels = tgt["labels"].clone()
  33. boxes = torch.clamp(boxes, 0, old_img_size)
  34. # rescale box
  35. boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
  36. boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
  37. # refine tgt
  38. tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
  39. min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
  40. keep = (min_tgt_size >= min_box_size)
  41. tgt["boxes"] = boxes[keep]
  42. tgt["labels"] = labels[keep]
  43. return images, targets, new_img_size
  44. def train_one_epoch(epoch,
  45. total_epochs,
  46. args,
  47. device,
  48. ema,
  49. model,
  50. criterion,
  51. cfg,
  52. dataloader,
  53. optimizer,
  54. scheduler,
  55. lf,
  56. scaler,
  57. last_opt_step):
  58. epoch_size = len(dataloader)
  59. img_size = args.img_size
  60. t0 = time.time()
  61. nw = epoch_size * args.wp_epoch
  62. accumulate = accumulate = max(1, round(64 / args.batch_size))
  63. # train one epoch
  64. for iter_i, (images, targets) in enumerate(dataloader):
  65. ni = iter_i + epoch * epoch_size
  66. # Warmup
  67. if ni <= nw:
  68. xi = [0, nw] # x interp
  69. accumulate = max(1, np.interp(ni, xi, [1, 64 / args.batch_size]).round())
  70. for j, x in enumerate(optimizer.param_groups):
  71. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  72. x['lr'] = np.interp(
  73. ni, xi, [cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
  74. if 'momentum' in x:
  75. x['momentum'] = np.interp(ni, xi, [cfg['warmup_momentum'], cfg['momentum']])
  76. # visualize train targets
  77. if args.vis_tgt:
  78. vis_data(images, targets)
  79. # to device
  80. images = images.to(device, non_blocking=True).float() / 255.
  81. # multi scale
  82. if args.multi_scale:
  83. images, targets, img_size = rescale_image_targets(
  84. images, targets, model.stride, args.min_box_size, cfg['multi_scale'])
  85. # inference
  86. with torch.cuda.amp.autocast(enabled=args.fp16):
  87. outputs = model(images)
  88. # loss
  89. loss_dict = criterion(outputs=outputs, targets=targets)
  90. losses = loss_dict['losses']
  91. losses *= images.shape[0] # loss * bs
  92. # reduce
  93. loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
  94. if args.distributed:
  95. # gradient averaged between devices in DDP mode
  96. losses *= distributed_utils.get_world_size()
  97. # check loss
  98. try:
  99. if torch.isnan(losses):
  100. print('loss is NAN !!')
  101. continue
  102. except:
  103. print(loss_dict)
  104. # backward
  105. scaler.scale(losses).backward()
  106. # Optimize
  107. if ni - last_opt_step >= accumulate:
  108. if cfg['clip_grad'] > 0:
  109. # unscale gradients
  110. scaler.unscale_(optimizer)
  111. # clip gradients
  112. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg['clip_grad'])
  113. # optimizer.step
  114. scaler.step(optimizer)
  115. scaler.update()
  116. optimizer.zero_grad()
  117. # ema
  118. if ema:
  119. ema.update(model)
  120. last_opt_step = ni
  121. # display
  122. if distributed_utils.is_main_process() and iter_i % 10 == 0:
  123. t1 = time.time()
  124. cur_lr = [param_group['lr'] for param_group in optimizer.param_groups]
  125. # basic infor
  126. log = '[Epoch: {}/{}]'.format(epoch+1, total_epochs)
  127. log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
  128. log += '[lr: {:.6f}]'.format(cur_lr[2])
  129. # loss infor
  130. for k in loss_dict_reduced.keys():
  131. if k == 'losses' and args.distributed:
  132. world_size = distributed_utils.get_world_size()
  133. log += '[{}: {:.2f}]'.format(k, loss_dict[k] / world_size)
  134. else:
  135. log += '[{}: {:.2f}]'.format(k, loss_dict[k])
  136. # other infor
  137. log += '[time: {:.2f}]'.format(t1 - t0)
  138. log += '[size: {}]'.format(img_size)
  139. # print log infor
  140. print(log, flush=True)
  141. t0 = time.time()
  142. scheduler.step()
  143. return last_opt_step
  144. def val_one_epoch(args,
  145. model,
  146. evaluator,
  147. optimizer,
  148. epoch,
  149. best_map,
  150. path_to_save):
  151. # check evaluator
  152. if distributed_utils.is_main_process():
  153. if evaluator is None:
  154. print('No evaluator ... save model and go on training.')
  155. print('Saving state, epoch: {}'.format(epoch + 1))
  156. weight_name = '{}_epoch_{}.pth'.format(args.model, epoch + 1)
  157. checkpoint_path = os.path.join(path_to_save, weight_name)
  158. torch.save({'model': model.state_dict(),
  159. 'mAP': -1.,
  160. 'optimizer': optimizer.state_dict(),
  161. 'epoch': epoch,
  162. 'args': args},
  163. checkpoint_path)
  164. else:
  165. print('eval ...')
  166. # set eval mode
  167. model.trainable = False
  168. model.eval()
  169. # evaluate
  170. evaluator.evaluate(model)
  171. cur_map = evaluator.map
  172. if cur_map > best_map:
  173. # update best-map
  174. best_map = cur_map
  175. # save model
  176. print('Saving state, epoch:', epoch + 1)
  177. weight_name = '{}_epoch_{}_{:.2f}.pth'.format(args.model, epoch + 1, best_map*100)
  178. checkpoint_path = os.path.join(path_to_save, weight_name)
  179. torch.save({'model': model.state_dict(),
  180. 'mAP': round(best_map*100, 1),
  181. 'optimizer': optimizer.state_dict(),
  182. 'epoch': epoch,
  183. 'args': args},
  184. checkpoint_path)
  185. # set train mode.
  186. model.trainable = True
  187. model.train()
  188. if args.distributed:
  189. # wait for all processes to synchronize
  190. dist.barrier()
  191. return best_map