misc.py

import time
import datetime
import random
from collections import defaultdict, deque
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist

from .distributed_utils import get_world_size, is_main_process, is_dist_avail_and_initialized


# ---------------------- Common functions ----------------------

def all_reduce_mean(x):
    """Average a Python scalar across all distributed processes."""
    world_size = get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x
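
# Usage sketch (assumes the process group was initialized via the helpers in
# distributed_utils; `loss` is a hypothetical per-rank scalar tensor):
#
#   loss_value = all_reduce_mean(loss.item())  # identical on every rank afterwards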

def print_rank_0(msg, rank=None):
    """Print `msg` once: when an explicit rank <= 0 is passed, or on the main process."""
    if rank is not None and rank <= 0:
        print(msg)
    elif is_main_process():
        print(msg)

def setup_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
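
# Note: cudnn.deterministic trades speed for reproducibility. A fully
# reproducible setup would typically also disable cuDNN autotuning:
#
#   torch.backends.cudnn.benchmark = False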

def is_parallel(model):
    # Returns True if model is of type DP or DDP
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
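
# Usage sketch (hypothetical tensors; `logits` is (N, C), `labels` is (N,)):
#
#   acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
#   print(acc1.item(), acc5.item())  # percentages in [0, 100]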

class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
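
# Usage sketch: a window-smoothed loss meter.
#
#   meter = SmoothedValue(window_size=20)
#   meter.update(0.73)
#   print(str(meter))  # median over the window, global average in parentheses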

class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
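
# Usage sketch inside a training loop (`data_loader`, `model`, `criterion`
# are hypothetical):
#
#   logger = MetricLogger(delimiter='  ')
#   for images, targets in logger.log_every(data_loader, print_freq=10, header='Epoch [0]'):
#       loss = criterion(model(images), targets)
#       logger.update(loss=loss.item())
#   logger.synchronize_between_processes()  # merge counts/totals across ranks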

# ---------------------- Optimize functions ----------------------
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device)
                                         for p in parameters]),
                            norm_type)
    return total_norm
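
# Note: for norm_type=2 this matches the total norm that
# torch.nn.utils.clip_grad_norm_ computes and returns, but without modifying
# any gradients.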

class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        # Forward create_graph to backward() so second-order methods work
        # (the original dropped this argument).
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
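
# Usage sketch for one AMP training step (`model`, `optimizer`, `samples`,
# `targets`, `criterion` are hypothetical):
#
#   loss_scaler = NativeScalerWithGradNormCount()
#   with torch.cuda.amp.autocast():
#       loss = criterion(model(samples), targets)
#   optimizer.zero_grad()
#   grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
#                           parameters=model.parameters())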

# ---------------------- Model functions ----------------------
def load_model(args, model_without_ddp, optimizer, lr_scheduler, loss_scaler):
    if args.resume and args.resume.lower() != 'none':
        print("=================== Load checkpoint ===================")
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            print('- Load optimizer from the checkpoint.')
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
        if 'lr_scheduler' in checkpoint:
            print('- Load lr scheduler from the checkpoint.')
            lr_scheduler.load_state_dict(checkpoint.pop("lr_scheduler"))
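
# Note: `args` is expected to provide `resume` (a local path or an https URL)
# and a writable `start_epoch`; an optional `eval` flag skips restoring the
# optimizer state so evaluation runs never advance the epoch counter.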

def save_model(args, epoch, model, model_without_ddp, optimizer, lr_scheduler, loss_scaler, acc1=None):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        if acc1 is not None:
            checkpoint_paths = [output_dir / ('checkpoint-{}-Acc1-{:.2f}.pth'.format(epoch_name, acc1))]
        else:
            checkpoint_paths = [output_dir / ('checkpoint-{}.pth'.format(epoch_name))]
        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }
            torch.save(to_save, checkpoint_path)
    else:
        # No loss scaler: assume a wrapper (e.g. a DeepSpeed engine) that
        # saves its own, possibly sharded, checkpoint state.
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
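
if __name__ == "__main__":
    # Minimal CPU-only self-check (a sketch; because of the relative import at
    # the top, run it as `python -m <package>.misc`, not as a plain script).
    # The top prediction is correct for 3 of the 4 toy samples, so top-1
    # accuracy should print as 75.0%.
    setup_seed(0)
    logits = torch.tensor([[0.9, 0.1],
                           [0.8, 0.2],
                           [0.3, 0.7],
                           [0.6, 0.4]])
    labels = torch.tensor([0, 0, 1, 1])
    (acc1,) = accuracy(logits, labels, topk=(1,))
    print("top-1 accuracy: {:.1f}%".format(acc1.item()))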