plot_utils.py

  1. """
  2. Plotting utilities to visualize training logs.
  3. """
  4. import torch
  5. import pandas as pd
  6. import numpy as np
  7. import seaborn as sns
  8. import matplotlib.pyplot as plt
  9. from pathlib import Path, PurePath
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
    '''
    Function to plot specific fields from training log(s). Plots both training and test results.
    :: Inputs - logs = list of Path objects, each pointing to an individual dir with a log file
              - fields = which results to plot from each log file - plots both training and test for each field
              - ewm_col = optional, center of mass (com) for the exponential weighted smoothing of the plots
              - log_name = optional, name of the log file if different from the default 'log.txt'
    :: Outputs - matplotlib plots of the results in fields, color coded per log file
               - solid lines are training results, dashed lines are test results
    '''
    func_name = "plot_utils.py::plot_logs"

    # verify logs is a list of Paths (list[Path]) or a single Path object;
    # convert a single Path to a list to avoid a 'not iterable' error
    if not isinstance(logs, list):
        if isinstance(logs, PurePath):
            logs = [logs]
            print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
        else:
            raise ValueError(f"{func_name} - invalid argument for logs parameter.\n"
                             f"Expect list[Path] or single Path obj, received {type(logs)}")

    # quality checks - verify valid dir(s), that every item in the list is a Path object,
    # and that log_name exists in each dir
    for log_dir in logs:  # 'log_dir' rather than 'dir' to avoid shadowing the builtin
        if not isinstance(log_dir, PurePath):
            raise ValueError(f"{func_name} - non-Path object in logs argument of {type(log_dir)}:\n{log_dir}")
        if not log_dir.exists():
            raise ValueError(f"{func_name} - invalid directory in logs argument:\n{log_dir}")
        # verify log_name exists
        fn = Path(log_dir / log_name)
        if not fn.exists():
            print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
            print(f"--> full path of missing log file: {fn}")
            return
    # load log file(s) and plot
    dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]

    fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
    if len(fields) == 1:
        axs = [axs]  # plt.subplots returns a bare Axes (not an array) when ncols=1
    for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
        for j, field in enumerate(fields):
            if field == 'mAP':
                # column 1 of the stacked COCO eval stats is AP at IoU=0.50
                coco_eval = pd.DataFrame(
                    np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
                ).ewm(com=ewm_col).mean()
                axs[j].plot(coco_eval, c=color)
            else:
                df.interpolate().ewm(com=ewm_col).mean().plot(
                    y=[f'train_{field}', f'test_{field}'],
                    ax=axs[j],
                    color=[color] * 2,
                    style=['-', '--']
                )

    for ax, field in zip(axs, fields):
        ax.legend([Path(p).name for p in logs])
        ax.set_title(field)
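

# A minimal usage sketch for plot_logs (illustrative only; the output
# directories below are hypothetical and assume each one contains the
# log.txt written during training):
#
#   from pathlib import Path
#   plot_logs([Path('outputs/run_a'), Path('outputs/run_b')],
#             fields=('loss', 'class_error', 'mAP'), ewm_col=5)
#   plt.show()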


def plot_precision_recall(files, naming_scheme='iter'):
    if naming_scheme == 'exp_id':
        # name becomes the experiment id (third-to-last path component)
        names = [f.parts[-3] for f in files]
    elif naming_scheme == 'iter':
        names = [f.stem for f in files]
    else:
        raise ValueError(f'naming_scheme {naming_scheme} not supported')
    fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
    for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
        data = torch.load(f)
        # precision is n_iou, n_points, n_cat, n_area, max_det
        precision = data['precision']
        recall = data['params'].recThrs
        scores = data['scores']
        # take precision at the first IoU threshold (0.50), for all classes,
        # all areas and 100 detections, then average over classes
        precision = precision[0, :, :, 0, -1].mean(1)
        scores = scores[0, :, :, 0, -1].mean(1)
        prec = precision.mean()
        rec = data['recall'][0, :, 0, -1].mean()
        print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
              f'score={scores.mean():0.3f}, ' +
              f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
              )
        axs[0].plot(recall, precision, c=color)
        axs[1].plot(recall, scores, c=color)

    axs[0].set_title('Precision / Recall')
    axs[0].legend(names)
    axs[1].set_title('Scores / Recall')
    axs[1].legend(names)
    return fig, axs
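

if __name__ == '__main__':
    # A minimal sketch of driving both helpers, assuming hypothetical output
    # dirs and eval dumps; adjust the paths to match your own training runs.
    log_dirs = [Path('outputs/run_a'), Path('outputs/run_b')]  # hypothetical paths
    if all(d.exists() for d in log_dirs):
        plot_logs(log_dirs, fields=('loss', 'mAP'))

    # .pth files saved during evaluation, e.g. under <output_dir>/eval/
    eval_files = sorted(Path('outputs/run_a/eval').glob('*.pth'))  # hypothetical path
    if eval_files:
        plot_precision_recall(eval_files)
    plt.show()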