# engine.py

import os
import random
import time

import torch
import torch.distributed as dist

from utils import distributed_utils
from utils.vis_tools import vis_data

def rescale_image_targets(images, targets, max_stride, min_box_size):
    """
    Randomly rescale the input images and their box targets (multi-scale training trick).
    """
    # During the training phase, the input image is square.
    old_img_size = images.shape[-1]
    # Sample a new size in [0.5x, 1.0x] of the old one, rounded down to a multiple of max_stride.
    new_img_size = random.randrange(int(old_img_size * 0.5), int(old_img_size * 1.0 + max_stride)) // max_stride * max_stride
    if new_img_size / old_img_size != 1:
        # interpolate
        images = torch.nn.functional.interpolate(
            input=images,
            size=new_img_size,
            mode='bilinear',
            align_corners=False)
    # rescale targets
    for tgt in targets:
        boxes = tgt["boxes"].clone()
        labels = tgt["labels"].clone()
        boxes = torch.clamp(boxes, 0, old_img_size)
        # rescale box coordinates to the new image size
        boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
        boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
        # drop targets that become too small after rescaling
        tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
        min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
        keep = (min_tgt_size >= min_box_size)
        tgt["boxes"] = boxes[keep]
        tgt["labels"] = labels[keep]

    return images, targets, new_img_size
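
# A minimal usage sketch of the multi-scale rescaling above (illustrative only; this helper,
# the 640x640 dummy batch, and the max_stride=32 / min_box_size=8 values are assumptions,
# not part of the training pipeline).
def _demo_rescale_image_targets():
    dummy_images = torch.zeros(2, 3, 640, 640)
    dummy_targets = [
        {"boxes": torch.tensor([[10., 10., 100., 200.]]), "labels": torch.tensor([1])},
        {"boxes": torch.tensor([[50., 60., 70., 80.]]), "labels": torch.tensor([3])},
    ]
    images, targets, new_size = rescale_image_targets(
        dummy_images, dummy_targets, max_stride=32, min_box_size=8)
    # new_size is a multiple of 32 in [320, 640]; boxes were scaled by new_size / 640
    # and boxes smaller than min_box_size were dropped.
    print(images.shape, new_size, targets[0]["boxes"])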

def train_one_epoch(epoch,
                    total_epochs,
                    args,
                    device,
                    ema,
                    model,
                    criterion,
                    dataloader,
                    optimizer,
                    lr_scheduler,
                    warmup_scheduler,
                    last_opt_step):
    epoch_size = len(dataloader)
    img_size = args.img_size
    t0 = time.time()
    # number of warmup iterations
    nw = epoch_size * args.wp_epoch
    # number of iterations over which gradients are accumulated (effective batch size ~64)
    accumulate = max(1, round(64 / args.batch_size))

    # train one epoch
    for iter_i, (images, targets) in enumerate(dataloader):
        ni = iter_i + epoch * epoch_size
        # warmup
        if ni <= nw:
            warmup_scheduler.warmup(ni, optimizer)

        # visualize train targets
        if args.vis_tgt:
            vis_data(images, targets)

        # to device
        images = images.to(device, non_blocking=True).float()

        # multi scale
        if args.multi_scale:
            images, targets, img_size = rescale_image_targets(
                images, targets, max(model.stride), args.min_box_size)

        # inference
        outputs = model(images)

        # loss
        loss_dict = criterion(outputs=outputs, targets=targets)
        losses = loss_dict['losses']

        # reduce losses over all processes for logging
        loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

        if args.distributed:
            # gradient averaged between devices in DDP mode
            losses *= distributed_utils.get_world_size()

        # check loss
        try:
            if torch.isnan(losses):
                print('loss is NaN !!')
                continue
        except Exception:
            print(loss_dict)

        # backward
        losses.backward()

        # optimize every `accumulate` iterations
        if ni - last_opt_step >= accumulate:
            optimizer.step()
            optimizer.zero_grad()
            # EMA
            if ema:
                ema.update(model)
            last_opt_step = ni

        # display
        if distributed_utils.is_main_process() and iter_i % 10 == 0:
            t1 = time.time()
            cur_lr = [param_group['lr'] for param_group in optimizer.param_groups]
            # basic info
            log = '[Epoch: {}/{}]'.format(epoch + 1, total_epochs)
            log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
            log += '[lr: {:.6f}]'.format(cur_lr[0])
            # loss info
            for k in loss_dict_reduced.keys():
                if k == 'losses' and args.distributed:
                    world_size = distributed_utils.get_world_size()
                    log += '[{}: {:.2f}]'.format(k, loss_dict[k] / world_size)
                else:
                    log += '[{}: {:.2f}]'.format(k, loss_dict[k])
            # other info
            log += '[time: {:.2f}]'.format(t1 - t0)
            log += '[size: {}]'.format(img_size)
            # print log info
            print(log, flush=True)
            t0 = time.time()

    lr_scheduler.step()

    return last_opt_step
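
# A self-contained sketch of the gradient-accumulation schedule used in train_one_epoch:
# with batch_size=16, accumulate = round(64 / 16) = 4, so the optimizer steps once every
# 4 iterations and the effective batch size stays near 64. The toy model, data and names
# below are illustrative assumptions, not part of the detector training code.
def _demo_grad_accumulation(batch_size=16, num_iters=12):
    toy_model = torch.nn.Linear(4, 1)
    toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.01)
    accumulate = max(1, round(64 / batch_size))
    last_opt_step = 0
    for ni in range(num_iters):
        x = torch.randn(batch_size, 4)
        y = torch.randn(batch_size, 1)
        loss = torch.nn.functional.mse_loss(toy_model(x), y)
        # gradients keep accumulating until the next optimizer step
        loss.backward()
        if ni - last_opt_step >= accumulate:
            toy_optimizer.step()
            toy_optimizer.zero_grad()
            last_opt_step = ni
            print('optimizer step at iteration', ni)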

def val_one_epoch(args,
                  model,
                  evaluator,
                  optimizer,
                  epoch,
                  best_map,
                  path_to_save):
    # check evaluator
    if distributed_utils.is_main_process():
        if evaluator is None:
            print('No evaluator ... save model and go on training.')
            print('Saving state, epoch: {}'.format(epoch + 1))
            weight_name = '{}_epoch_{}.pth'.format(args.version, epoch + 1)
            checkpoint_path = os.path.join(path_to_save, weight_name)
            torch.save({'model': model.state_dict(),
                        'mAP': -1.,
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                        'args': args},
                       checkpoint_path)
        else:
            print('eval ...')
            # set eval mode
            model.trainable = False
            model.eval()

            # evaluate
            evaluator.evaluate(model)

            cur_map = evaluator.map
            if cur_map > best_map:
                # update best mAP
                best_map = cur_map
                # save model
                print('Saving state, epoch: {}'.format(epoch + 1))
                weight_name = '{}_epoch_{}_{:.2f}.pth'.format(args.version, epoch + 1, best_map * 100)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model.state_dict(),
                            'mAP': round(best_map * 100, 1),
                            'optimizer': optimizer.state_dict(),
                            'epoch': epoch,
                            'args': args},
                           checkpoint_path)

            # set train mode.
            model.trainable = True
            model.train()

    if args.distributed:
        # wait for all processes to synchronize
        dist.barrier()

    return best_map
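
# A minimal sketch of reading back the checkpoints written by val_one_epoch; the dict keys
# match the torch.save calls above, while the helper name and the map_location choice are
# assumptions for illustration.
def _demo_load_checkpoint(model, optimizer, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1
    print('Loaded checkpoint (mAP: {}), resuming at epoch {}'.format(checkpoint['mAP'], start_epoch))
    return start_epoch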