vis.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import matplotlib
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from p2ch4.dsets import Ct, LunaDataset
  5. clim=(0.0, 1.3)
  6. def findMalignantSamples(start_ndx=0, limit=10):
  7. ds = LunaDataset()
  8. malignantSample_list = []
  9. for sample_tup in ds.sample_list:
  10. if sample_tup[2]:
  11. print(len(malignantSample_list), sample_tup)
  12. malignantSample_list.append(sample_tup)
  13. if len(malignantSample_list) >= limit:
  14. break
  15. return malignantSample_list
  16. def showNodule(series_uid, batch_ndx=None, **kwargs):
  17. ds = LunaDataset(series_uid=series_uid, **kwargs)
  18. malignant_list = [i for i, x in enumerate(ds.sample_list) if x[2]]
  19. if batch_ndx is None:
  20. if malignant_list:
  21. batch_ndx = malignant_list[0]
  22. else:
  23. print("Warning: no malignant samples found; using first non-malignant sample.")
  24. batch_ndx = 0
  25. ct = Ct(series_uid)
  26. ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
  27. ct_ary = ct_tensor[0].numpy()
  28. fig = plt.figure(figsize=(15, 25))
  29. group_list = [
  30. #[0,1,2],
  31. [3,4,5],
  32. [6,7,8],
  33. [9,10,11],
  34. #[12,13,14],
  35. #[15]
  36. ]
  37. subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
  38. subplot.set_title('index {}'.format(int(center_irc.index)))
  39. plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
  40. subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
  41. subplot.set_title('row {}'.format(int(center_irc.row)))
  42. plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
  43. subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
  44. subplot.set_title('col {}'.format(int(center_irc.col)))
  45. plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
  46. subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
  47. subplot.set_title('index {}'.format(int(center_irc.index)))
  48. plt.imshow(ct_ary[7], clim=clim, cmap='gray')
  49. subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
  50. subplot.set_title('row {}'.format(int(center_irc.row)))
  51. plt.imshow(ct_ary[:,7], clim=clim, cmap='gray')
  52. subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
  53. subplot.set_title('col {}'.format(int(center_irc.col)))
  54. plt.imshow(ct_ary[:,:,7], clim=clim, cmap='gray')
  55. for row, index_list in enumerate(group_list):
  56. for col, index in enumerate(index_list):
  57. subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
  58. subplot.set_title('slice {}'.format(index))
  59. plt.imshow(ct_ary[index], clim=clim, cmap='gray')
  60. print(series_uid, batch_ndx, bool(malignant_tensor[0][0]), malignant_list, ct.vxSize_xyz)