util.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. import collections
  2. import copy
  3. import datetime
  4. import gc
  5. import time
  6. # import torch
  7. import numpy as np
  8. from util.logconf import logging
# Module-level logger named after this module.
log = logging.getLogger(__name__)
# Alternative verbosity levels, kept commented out for quick switching:
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
  13. IrcTuple = collections.namedtuple('IrcTuple', ['index', 'row', 'col'])
  14. XyzTuple = collections.namedtuple('XyzTuple', ['x', 'y', 'z'])
  15. def xyz2irc(coord_xyz, origin_xyz, vxSize_xyz, direction_tup):
  16. # Note: _cri means Col,Row,Index
  17. if direction_tup == (1, 0, 0, 0, 1, 0, 0, 0, 1):
  18. direction_ary = np.ones((3,))
  19. elif direction_tup == (-1, 0, 0, 0, -1, 0, 0, 0, 1):
  20. direction_ary = np.array((-1, -1, 1))
  21. else:
  22. raise Exception("Unsupported direction_tup: {}".format(direction_tup))
  23. coord_cri = (np.array(coord_xyz) - np.array(origin_xyz)) / np.array(vxSize_xyz)
  24. coord_cri *= direction_ary
  25. return IrcTuple(*list(reversed(coord_cri.tolist())))
  26. def irc2xyz(coord_irc, origin_xyz, vxSize_xyz, direction_tup):
  27. # Note: _cri means Col,Row,Index
  28. coord_cri = np.array(list(reversed(coord_irc)))
  29. coord_xyz = coord_cri * np.array(vxSize_xyz) + np.array(origin_xyz)
  30. return XyzTuple(*coord_xyz.tolist())
  31. def importstr(module_str, from_=None):
  32. """
  33. >>> importstr('os')
  34. <module 'os' from '.../os.pyc'>
  35. >>> importstr('math', 'fabs')
  36. <built-in function fabs>
  37. """
  38. if from_ is None and ':' in module_str:
  39. module_str, from_ = module_str.rsplit(':')
  40. module = __import__(module_str)
  41. for sub_str in module_str.split('.')[1:]:
  42. module = getattr(module, sub_str)
  43. if from_:
  44. try:
  45. return getattr(module, from_)
  46. except:
  47. raise ImportError('{}.{}'.format(module_str, from_))
  48. return module
  49. # class dotdict(dict):
  50. # '''dict whose keys can also be accessed as attributes: d.key -> d[key]'''
  51. # @classmethod
  52. # def deep(cls, dic_obj):
  53. # '''Initialize from dict with deep conversion'''
  54. # return cls(dic_obj).deepConvert()
  55. #
  56. # def __getattr__(self, attr):
  57. # if attr in self:
  58. # return self[attr]
  59. # log.error(sorted(self.keys()))
  60. # raise AttributeError(attr)
  61. # #return self.get(attr, None)
  62. # __setattr__= dict.__setitem__
  63. # __delattr__= dict.__delitem__
  64. #
  65. #
  66. # def __copy__(self):
  67. # return dotdict(self)
  68. #
  69. # def __deepcopy__(self, memo):
  70. # new_dict = dotdict()
  71. # for k, v in self.items():
  72. # new_dict[k] = copy.deepcopy(v, memo)
  73. # return new_dict
  74. #
  75. # # pylint: disable=multiple-statements
  76. # def __getstate__(self): return self.__dict__
  77. # def __setstate__(self, d): self.__dict__.update(d)
  78. #
  79. # def deepConvert(self):
  80. # '''Convert all dicts at all tree levels into dotdict'''
  81. # for k, v in self.items():
  82. # if type(v) is dict: # pylint: disable=unidiomatic-typecheck
  83. # self[k] = dotdict(v)
  84. # self[k].deepConvert()
  85. # try: # try enumerable types
  86. # for m, x in enumerate(v):
  87. # if type(x) is dict: # pylint: disable=unidiomatic-typecheck
  88. # x = dotdict(x)
  89. # x.deepConvert()
  90. # v[m] = x#
  91. # except TypeError:
  92. # pass
  93. # return self
  94. #
  95. # def copy(self):
  96. # # override dict.copy()
  97. # return dotdict(self)
  98. def prhist(ary, prefix_str=None, **kwargs):
  99. if prefix_str is None:
  100. prefix_str = ''
  101. else:
  102. prefix_str += ' '
  103. count_ary, bins_ary = np.histogram(ary, **kwargs)
  104. for i in range(count_ary.shape[0]):
  105. print("{}{:-8.2f}".format(prefix_str, bins_ary[i]), "{:-10}".format(count_ary[i]))
  106. print("{}{:-8.2f}".format(prefix_str, bins_ary[-1]))
  107. # def dumpCuda():
  108. # # small_count = 0
  109. # total_bytes = 0
  110. # size2count_dict = collections.defaultdict(int)
  111. # size2bytes_dict = {}
  112. # for obj in gc.get_objects():
  113. # if isinstance(obj, torch.cuda._CudaBase):
  114. # nbytes = 4
  115. # for n in obj.size():
  116. # nbytes *= n
  117. #
  118. # size2count_dict[tuple([obj.get_device()] + list(obj.size()))] += 1
  119. # size2bytes_dict[tuple([obj.get_device()] + list(obj.size()))] = nbytes
  120. #
  121. # total_bytes += nbytes
  122. #
  123. # # print(small_count, "tensors equal to or less than 16 bytes")
  124. # for size, count in sorted(size2count_dict.items(), key=lambda sc: (size2bytes_dict[sc[0]] * sc[1], sc[1], sc[0])):
  125. # print('{:4}x'.format(count), '{:10,}'.format(size2bytes_dict[size]), size)
  126. # print('{:10,}'.format(total_bytes), "total bytes")
  127. def enumerateWithEstimate(iter, desc_str, start_ndx=0, print_ndx=4, backoff=2, iter_len=None):
  128. if iter_len is None:
  129. iter_len = len(iter)
  130. assert backoff >= 2
  131. while print_ndx < start_ndx * backoff:
  132. print_ndx *= backoff
  133. log.warning("{} ----/{}, starting".format(
  134. desc_str,
  135. iter_len,
  136. ))
  137. start_ts = time.time()
  138. for (current_ndx, item) in enumerate(iter):
  139. yield (current_ndx, item)
  140. if current_ndx == print_ndx:
  141. duration_sec = ((time.time() - start_ts)
  142. / (current_ndx - start_ndx + 1)
  143. * (iter_len-start_ndx)
  144. )
  145. done_dt = datetime.datetime.fromtimestamp(start_ts + duration_sec)
  146. done_td = datetime.timedelta(seconds=duration_sec)
  147. log.warning("{} {:-4}/{}, done at {}, {}".format(
  148. desc_str,
  149. current_ndx,
  150. iter_len,
  151. str(done_dt).rsplit('.', 1)[0],
  152. str(done_td).rsplit('.', 1)[0],
  153. ))
  154. print_ndx *= backoff
  155. if current_ndx + 1 == start_ndx:
  156. start_ts = time.time()
  157. log.warning("{} ----/{}, done at {}".format(
  158. desc_str,
  159. iter_len,
  160. str(datetime.datetime.now()).rsplit('.', 1)[0],
  161. ))
  162. try:
  163. import matplotlib
  164. matplotlib.use('agg', warn=False)
  165. import matplotlib.pyplot as plt
  166. # matplotlib color maps
  167. cdict = {'red': ((0.0, 1.0, 1.0),
  168. # (0.5, 1.0, 1.0),
  169. (1.0, 1.0, 1.0)),
  170. 'green': ((0.0, 0.0, 0.0),
  171. (0.5, 0.0, 0.0),
  172. (1.0, 0.5, 0.5)),
  173. 'blue': ((0.0, 0.0, 0.0),
  174. # (0.5, 0.5, 0.5),
  175. # (0.75, 0.0, 0.0),
  176. (1.0, 0.0, 0.0)),
  177. 'alpha': ((0.0, 0.0, 0.0),
  178. (0.75, 0.5, 0.5),
  179. (1.0, 0.5, 0.5))}
  180. plt.register_cmap(name='mask', data=cdict)
  181. cdict = {'red': ((0.0, 0.0, 0.0),
  182. (0.25, 1.0, 1.0),
  183. (1.0, 1.0, 1.0)),
  184. 'green': ((0.0, 1.0, 1.0),
  185. (0.25, 1.0, 1.0),
  186. (0.5, 0.0, 0.0),
  187. (1.0, 0.0, 0.0)),
  188. 'blue': ((0.0, 0.0, 0.0),
  189. # (0.5, 0.5, 0.5),
  190. # (0.75, 0.0, 0.0),
  191. (1.0, 0.0, 0.0)),
  192. 'alpha': ((0.0, 0.15, 0.15),
  193. (0.5, 0.3, 0.3),
  194. (0.8, 0.0, 0.0),
  195. (1.0, 0.0, 0.0))}
  196. plt.register_cmap(name='maskinvert', data=cdict)
  197. except ImportError:
  198. pass