vis.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import matplotlib
  2. # matplotlib.use('nbagg')
  3. import matplotlib.pyplot as plt
  4. import torch
  5. import numpy as np
  6. from p2ch13.dsets import Ct, LunaDataset
  7. from p2ch13.model_seg import SegmentationMask, MaskTuple
  8. clim=(-1000.0, 300)
  9. def findPositiveSamples(start_ndx=0, limit=100):
  10. ds = LunaDataset(sortby_str='label_and_size')
  11. positiveSample_list = []
  12. for sample_tup in ds.candidateInfo_list:
  13. if sample_tup.isNodule_bool:
  14. print(len(positiveSample_list), sample_tup)
  15. positiveSample_list.append(sample_tup)
  16. if len(positiveSample_list) >= limit:
  17. break
  18. return positiveSample_list
  19. def showCandidate(series_uid, batch_ndx=None, **kwargs):
  20. ds = LunaDataset(series_uid=series_uid, **kwargs)
  21. pos_list = [i for i, x in enumerate(ds.candidateInfo_list) if x.isNodule_bool]
  22. if batch_ndx is None:
  23. if pos_list:
  24. batch_ndx = pos_list[0]
  25. else:
  26. print("Warning: no positive samples found; using first negative sample.")
  27. batch_ndx = 0
  28. ct = Ct(series_uid)
  29. ct_t, pos_t, series_uid, center_irc = ds[batch_ndx]
  30. ct_a = ct_t[0].numpy()
  31. fig = plt.figure(figsize=(30, 50))
  32. group_list = [
  33. [9, 11, 13],
  34. [15, 16, 17],
  35. [19, 21, 23],
  36. ]
  37. subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
  38. subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
  39. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  40. label.set_fontsize(20)
  41. plt.imshow(ct.hu_a[int(center_irc.index)], clim=clim, cmap='gray')
  42. subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
  43. subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
  44. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  45. label.set_fontsize(20)
  46. plt.imshow(ct.hu_a[:,int(center_irc.row)], clim=clim, cmap='gray')
  47. plt.gca().invert_yaxis()
  48. subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
  49. subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
  50. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  51. label.set_fontsize(20)
  52. plt.imshow(ct.hu_a[:,:,int(center_irc.col)], clim=clim, cmap='gray')
  53. plt.gca().invert_yaxis()
  54. subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
  55. subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
  56. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  57. label.set_fontsize(20)
  58. plt.imshow(ct_a[ct_a.shape[0]//2], clim=clim, cmap='gray')
  59. subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
  60. subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
  61. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  62. label.set_fontsize(20)
  63. plt.imshow(ct_a[:,ct_a.shape[1]//2], clim=clim, cmap='gray')
  64. plt.gca().invert_yaxis()
  65. subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
  66. subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
  67. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  68. label.set_fontsize(20)
  69. plt.imshow(ct_a[:,:,ct_a.shape[2]//2], clim=clim, cmap='gray')
  70. plt.gca().invert_yaxis()
  71. for row, index_list in enumerate(group_list):
  72. for col, index in enumerate(index_list):
  73. subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
  74. subplot.set_title('slice {}'.format(index), fontsize=30)
  75. for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
  76. label.set_fontsize(20)
  77. plt.imshow(ct_a[index], clim=clim, cmap='gray')
  78. print(series_uid, batch_ndx, bool(pos_t[0]), pos_list)
  79. def build2dLungMask(series_uid, center_ndx):
  80. mask_model = SegmentationMask().to('cuda')
  81. ct = Ct(series_uid)
  82. ct_g = torch.from_numpy(ct.hu_a[center_ndx].astype(np.float32)).unsqueeze(0).unsqueeze(0).to('cuda')
  83. pos_g = torch.from_numpy(ct.positive_mask[center_ndx].astype(np.float32)).unsqueeze(0).unsqueeze(0).to('cuda')
  84. input_g = ct_g / 1000
  85. label_g, neg_g, pos_g, lung_mask, mask_dict = mask_model(input_g, pos_g)
  86. mask_tup = MaskTuple(**mask_dict)
  87. return mask_tup