engine.py
import torch
import torch.distributed as dist

import time
import os
import random

from utils import distributed_utils
from utils.vis_tools import vis_data


def rescale_image_targets(images, targets, stride, min_box_size):
    """
        Rescale a square batch of images (and their box targets) to a randomly
        chosen resolution. Deployed for the multi-scale training trick.
    """
    if isinstance(stride, int):
        max_stride = stride
    elif isinstance(stride, list):
        max_stride = max(stride)

    # During training phase, the shape of input image is square.
    old_img_size = images.shape[-1]
    # randrange requires integer bounds: sample a size in [0.5 * old_size, old_size + max_stride)
    # and snap it down to a multiple of max_stride.
    new_img_size = random.randrange(int(old_img_size * 0.5), int(old_img_size * 1.0 + max_stride)) // max_stride * max_stride
    if new_img_size / old_img_size != 1:
        # interpolate
        images = torch.nn.functional.interpolate(
            input=images,
            size=new_img_size,
            mode='bilinear',
            align_corners=False)
    # rescale targets
    for tgt in targets:
        boxes = tgt["boxes"].clone()
        labels = tgt["labels"].clone()
        boxes = torch.clamp(boxes, 0, old_img_size)
        # rescale box
        boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
        boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
        # refine tgt: drop boxes whose shorter side falls below min_box_size
        tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
        min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
        keep = (min_tgt_size >= min_box_size)

        tgt["boxes"] = boxes[keep]
        tgt["labels"] = labels[keep]

    return images, targets, new_img_size
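
# Illustrative use of the multi-scale rescaling above (a sketch only; the
# concrete shapes and values are assumptions, not taken from this repo):
#
#   images  : torch.Tensor of shape [B, 3, 640, 640], a square training batch
#   targets : list of B dicts, each with "boxes" [N, 4] in xyxy pixel coords
#             and "labels" [N]
#
#   images, targets, img_size = rescale_image_targets(
#       images, targets, stride=32, min_box_size=8.0)
#
# The new square size is drawn from [0.5 * 640, 640 + 32) and snapped to a
# multiple of 32, so img_size lands in {320, 352, ..., 640}; boxes are scaled
# accordingly and any target whose shorter side drops below min_box_size is removed.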


def train_one_epoch(epoch,
                    total_epochs,
                    args,
                    device,
                    ema,
                    model,
                    criterion,
                    dataloader,
                    optimizer,
                    lr_scheduler,
                    warmup_scheduler,
                    last_opt_step):
    epoch_size = len(dataloader)
    img_size = args.img_size
    t0 = time.time()
    nw = epoch_size * args.wp_epoch
    # accumulate gradients so that the effective batch size is about 64
    # (e.g. batch_size=16 -> step the optimizer every 4 iterations)
    accumulate = max(1, round(64 / args.batch_size))

    # train one epoch
    for iter_i, (images, targets) in enumerate(dataloader):
        ni = iter_i + epoch * epoch_size
        # Warmup
        if ni <= nw:
            warmup_scheduler.warmup(ni, optimizer)

        # visualize train targets
        if args.vis_tgt:
            vis_data(images, targets)

        # to device
        images = images.to(device, non_blocking=True).float()

        # multi scale
        if args.multi_scale and ni % 10 == 0:
            images, targets, img_size = rescale_image_targets(
                images, targets, model.stride, args.min_box_size)

        # inference
        outputs = model(images)

        # loss
        loss_dict = criterion(outputs=outputs, targets=targets)
        losses = loss_dict['losses']

        # reduce
        loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

        if args.distributed:
            # gradient averaged between devices in DDP mode
            losses *= distributed_utils.get_world_size()

        # check loss
        try:
            if torch.isnan(losses):
                print('loss is NaN !!')
                continue
        except Exception:
            print(loss_dict)

        # backward
        losses.backward()

        # Optimize
        if ni - last_opt_step >= accumulate:
            optimizer.step()
            optimizer.zero_grad()
            # EMA
            if ema:
                ema.update(model)
            last_opt_step = ni

        # display
        if distributed_utils.is_main_process() and iter_i % 10 == 0:
            t1 = time.time()
            cur_lr = [param_group['lr'] for param_group in optimizer.param_groups]
            # basic info
            log = '[Epoch: {}/{}]'.format(epoch + 1, total_epochs)
            log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
            log += '[lr: {:.6f}]'.format(cur_lr[0])
            # loss info (undo the DDP world-size scaling of the total loss for display)
            for k in loss_dict_reduced.keys():
                if k == 'losses' and args.distributed:
                    world_size = distributed_utils.get_world_size()
                    log += '[{}: {:.2f}]'.format(k, loss_dict[k] / world_size)
                else:
                    log += '[{}: {:.2f}]'.format(k, loss_dict[k])
            # other info
            log += '[time: {:.2f}]'.format(t1 - t0)
            log += '[size: {}]'.format(img_size)
            # print log info
            print(log, flush=True)
            t0 = time.time()

    lr_scheduler.step()

    return last_opt_step


def val_one_epoch(args,
                  model,
                  evaluator,
                  optimizer,
                  epoch,
                  best_map,
                  path_to_save):
    # check evaluator
    if distributed_utils.is_main_process():
        if evaluator is None:
            print('No evaluator ... save model and go on training.')
            print('Saving state, epoch: {}'.format(epoch + 1))
            weight_name = '{}_epoch_{}.pth'.format(args.version, epoch + 1)
            checkpoint_path = os.path.join(path_to_save, weight_name)
            torch.save({'model': model.state_dict(),
                        'mAP': -1.,
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                        'args': args},
                       checkpoint_path)
        else:
            print('eval ...')
            # set eval mode
            model.trainable = False
            model.eval()

            # evaluate
            evaluator.evaluate(model)

            cur_map = evaluator.map
            if cur_map > best_map:
                # update best-map
                best_map = cur_map
                # save model
                print('Saving state, epoch:', epoch + 1)
                weight_name = '{}_epoch_{}_{:.2f}.pth'.format(args.version, epoch + 1, best_map * 100)
                checkpoint_path = os.path.join(path_to_save, weight_name)
                torch.save({'model': model.state_dict(),
                            'mAP': round(best_map * 100, 1),
                            'optimizer': optimizer.state_dict(),
                            'epoch': epoch,
                            'args': args},
                           checkpoint_path)

            # set train mode.
            model.trainable = True
            model.train()

    if args.distributed:
        # wait for all processes to synchronize
        dist.barrier()

    return best_map
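
# A minimal sketch of how these routines are typically wired together by a
# training script (the driver below is illustrative; the `args` fields, the EMA
# handling and the evaluation cadence are assumptions, not part of this file):
#
#   last_opt_step = 0
#   best_map = -1.
#   for epoch in range(args.max_epoch):
#       last_opt_step = train_one_epoch(
#           epoch, args.max_epoch, args, device, ema, model, criterion,
#           dataloader, optimizer, lr_scheduler, warmup_scheduler, last_opt_step)
#       if (epoch + 1) % args.eval_epoch == 0:
#           best_map = val_one_epoch(
#               args, ema.ema if ema else model, evaluator, optimizer,
#               epoch, best_map, path_to_save)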