# vis.py (3.6 KB)
import matplotlib
# Select the notebook (nbagg) backend before pyplot is imported so that
# pyplot picks it up.
matplotlib.use('nbagg')
import numpy as np
import matplotlib.pyplot as plt
from p2ch10.dsets import Ct, LunaDataset
# Display window passed as ``clim`` (color limits) to the imshow calls below.
# Values are presumably Hounsfield units, matching ``Ct.hu_a`` -- TODO confirm.
clim=(-1000.0, 300)
  7. def findMalignantSamples(start_ndx=0, limit=100):
  8. ds = LunaDataset()
  9. malignantSample_list = []
  10. for sample_tup in ds.noduleInfo_list:
  11. if sample_tup.isMalignant_bool:
  12. print(len(malignantSample_list), sample_tup)
  13. malignantSample_list.append(sample_tup)
  14. if len(malignantSample_list) >= limit:
  15. break
  16. return malignantSample_list
  17. def showNodule(series_uid, batch_ndx=None, **kwargs):
  18. ds = LunaDataset(series_uid=series_uid, **kwargs)
  19. malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x.isMalignant_bool]
  20. if batch_ndx is None:
  21. if malignant_list:
  22. batch_ndx = malignant_list[0]
  23. else:
  24. print("Warning: no malignant samples found; using first non-malignant sample.")
  25. batch_ndx = 0
  26. ct = Ct(series_uid)
  27. ct_t, malignant_t, series_uid, center_irc = ds[batch_ndx]
  28. ct_a = ct_t[0].numpy()
  29. fig = plt.figure(figsize=(30, 50))
  30. group_list = [
  31. [9, 11, 13],
  32. [15, 16, 17],
  33. [19, 21, 23],
  34. ]
  35. subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
  36. subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
  37. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  38. label.set_fontsize(20)
  39. plt.imshow(ct.hu_a[int(center_irc.index)], clim=clim, cmap='gray')
  40. subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
  41. subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
  42. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  43. label.set_fontsize(20)
  44. plt.imshow(ct.hu_a[:,int(center_irc.row)], clim=clim, cmap='gray')
  45. plt.gca().invert_yaxis()
  46. subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
  47. subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
  48. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  49. label.set_fontsize(20)
  50. plt.imshow(ct.hu_a[:,:,int(center_irc.col)], clim=clim, cmap='gray')
  51. plt.gca().invert_yaxis()
  52. subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
  53. subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
  54. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  55. label.set_fontsize(20)
  56. plt.imshow(ct_a[ct_a.shape[0]//2], clim=clim, cmap='gray')
  57. subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
  58. subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
  59. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  60. label.set_fontsize(20)
  61. plt.imshow(ct_a[:,ct_a.shape[1]//2], clim=clim, cmap='gray')
  62. plt.gca().invert_yaxis()
  63. subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
  64. subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
  65. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  66. label.set_fontsize(20)
  67. plt.imshow(ct_a[:,:,ct_a.shape[2]//2], clim=clim, cmap='gray')
  68. plt.gca().invert_yaxis()
  69. for row, index_list in enumerate(group_list):
  70. for col, index in enumerate(index_list):
  71. subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
  72. subplot.set_title('slice {}'.format(index), fontsize=30)
  73. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  74. label.set_fontsize(20)
  75. plt.imshow(ct_a[index], clim=clim, cmap='gray')
  76. print(series_uid, batch_ndx, bool(malignant_t[0]), malignant_list)