util.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. import collections
  2. import copy
  3. import datetime
  4. import gc
  5. import time
  6. # import torch
  7. import numpy as np
  8. from util.logconf import logging
  # Module-level logger for this file. Its level is forced to DEBUG here;
  # swap in logging.INFO or logging.WARN to make this module quieter.
  log = logging.getLogger(__name__)
  log.setLevel(logging.DEBUG)
  IrcTuple = collections.namedtuple('IrcTuple', ['index', 'row', 'col'])
  XyzTuple = collections.namedtuple('XyzTuple', ['x', 'y', 'z'])


  def _direction_ary(direction_tup):
      """Return the per-axis +1/-1 multiplier for a supported direction matrix.

      ``direction_tup`` is a 3x3 direction matrix given as 9 values; a tuple,
      list or NumPy array holding them (flat or nested 3x3) is accepted.
      Only two matrices are supported, both diagonal, so each reduces to a
      sign vector in (X, Y, Z) order:

      * identity            -> ``np.ones((3,))``
      * ``diag(-1, -1, 1)`` -> ``np.array((-1, -1, 1))``

      :raises ValueError: for any other matrix (ValueError subclasses
          Exception, so handlers written for the old bare Exception still work).
      """
      # Normalize to a flat tuple of floats: a plain list never compares equal
      # to a tuple, and ``ndarray == tuple`` yields an array whose truth value
      # is ambiguous, so neither could be passed in before.
      flat_tup = tuple(np.asarray(direction_tup, dtype=float).ravel().tolist())
      if flat_tup == (1, 0, 0, 0, 1, 0, 0, 0, 1):
          return np.ones((3,))
      if flat_tup == (-1, 0, 0, 0, -1, 0, 0, 0, 1):
          return np.array((-1, -1, 1))
      raise ValueError("Unsupported direction_tup: {}".format(direction_tup))


  def xyz2irc(coord_xyz, origin_xyz, vxSize_xyz, direction_tup):
      """Convert an (X, Y, Z) coordinate into an (Index, Row, Col) coordinate.

      Computes ``(coord_xyz - origin_xyz) / vxSize_xyz``, multiplies by the
      sign vector for ``direction_tup``, then reverses the axis order.
      The result is NOT rounded, so fractional voxel positions are kept.

      :param coord_xyz: 3 values in X, Y, Z order.
      :param origin_xyz: 3 values; the XYZ position that maps to IRC (0, 0, 0).
      :param vxSize_xyz: 3 per-axis voxel sizes in X, Y, Z order, in the same
          units as the coordinates.
      :param direction_tup: direction matrix; see ``_direction_ary``.
      :return: ``IrcTuple`` of floats.
      :raises ValueError: if ``direction_tup`` is not supported.
      """
      direction_ary = _direction_ary(direction_tup)
      # _cri means Col,Row,Index: the same axis order as X,Y,Z.
      coord_cri = (np.array(coord_xyz) - np.array(origin_xyz)) / np.array(vxSize_xyz)
      coord_cri *= direction_ary
      return IrcTuple(*reversed(coord_cri.tolist()))


  def irc2xyz(coord_irc, origin_xyz, vxSize_xyz, direction_tup):
      """Convert an (Index, Row, Col) coordinate into an (X, Y, Z) coordinate.

      Exact inverse of ``xyz2irc``: reverses the axis order, applies the sign
      vector for ``direction_tup``, scales by ``vxSize_xyz`` and adds
      ``origin_xyz``.

      :param coord_irc: 3 values in Index, Row, Col order (may be fractional).
      :param origin_xyz: 3 values; the XYZ position of IRC (0, 0, 0).
      :param vxSize_xyz: 3 per-axis voxel sizes in X, Y, Z order.
      :param direction_tup: direction matrix; see ``_direction_ary``.
      :return: ``XyzTuple`` of floats.
      :raises ValueError: if ``direction_tup`` is not supported.
      """
      # _cri means Col,Row,Index: the same axis order as X,Y,Z.
      coord_cri = np.array(list(reversed(coord_irc)))
      direction_ary = _direction_ary(direction_tup)
      coord_xyz = coord_cri * direction_ary * np.array(vxSize_xyz) + np.array(origin_xyz)
      return XyzTuple(*coord_xyz.tolist())
  def importstr(module_str, from_=None):
      """Import a module, or one attribute of a module, from its dotted name.

      ``module_str`` may also be written as ``'package.module:attr'``; when
      ``from_`` is not given, the text after the last colon is used as
      ``from_``. A falsy ``from_`` returns the module itself.

      >>> importstr('math', 'fabs')
      <built-in function fabs>
      >>> import os.path
      >>> importstr('os.path:join') is os.path.join
      True

      :raises ImportError: if the module cannot be imported, or it has no
          attribute named ``from_`` (the AttributeError is chained as cause).
      """
      import importlib

      if from_ is None and ':' in module_str:
          # maxsplit=1 so extra colons end up in module_str and surface as an
          # ImportError instead of a confusing unpacking ValueError.
          module_str, from_ = module_str.rsplit(':', 1)
      module = importlib.import_module(module_str)
      if from_:
          try:
              return getattr(module, from_)
          except AttributeError as err:
              raise ImportError('{}.{}'.format(module_str, from_)) from err
      return module
  55. # class dotdict(dict):
  56. # '''dict where key can be access as attribute d.key -> d[key]'''
  57. # @classmethod
  58. # def deep(cls, dic_obj):
  59. # '''Initialize from dict with deep conversion'''
  60. # return cls(dic_obj).deepConvert()
  61. #
  62. # def __getattr__(self, attr):
  63. # if attr in self:
  64. # return self[attr]
  65. # log.error(sorted(self.keys()))
  66. # raise AttributeError(attr)
  67. # #return self.get(attr, None)
  68. # __setattr__= dict.__setitem__
  69. # __delattr__= dict.__delitem__
  70. #
  71. #
  72. # def __copy__(self):
  73. # return dotdict(self)
  74. #
  75. # def __deepcopy__(self, memo):
  76. # new_dict = dotdict()
  77. # for k, v in self.items():
  78. # new_dict[k] = copy.deepcopy(v, memo)
  79. # return new_dict
  80. #
  81. # # pylint: disable=multiple-statements
  82. # def __getstate__(self): return self.__dict__
  83. # def __setstate__(self, d): self.__dict__.update(d)
  84. #
  85. # def deepConvert(self):
  86. # '''Convert all dicts at all tree levels into dotdict'''
  87. # for k, v in self.items():
  88. # if type(v) is dict: # pylint: disable=unidiomatic-typecheck
  89. # self[k] = dotdict(v)
  90. # self[k].deepConvert()
  91. # try: # try enumerable types
  92. # for m, x in enumerate(v):
  93. # if type(x) is dict: # pylint: disable=unidiomatic-typecheck
  94. # x = dotdict(x)
  95. # x.deepConvert()
  96. # v[m] = x#
  97. # except TypeError:
  98. # pass
  99. # return self
  100. #
  101. # def copy(self):
  102. # # override dict.copy()
  103. # return dotdict(self)
  def prhist(ary, prefix_str=None, **kwargs):
      """Print a plain-text histogram of ``ary`` to stdout.

      Extra keyword arguments go straight to ``np.histogram`` (for example
      ``bins`` or ``range``). One row is printed per bin with the bin's left
      edge and its count, followed by a last row holding the final right edge.
      If ``prefix_str`` is given, every row starts with it plus one space.
      """
      lead_str = '' if prefix_str is None else prefix_str + ' '
      counts, edges = np.histogram(ary, **kwargs)
      for left_edge, count in zip(edges[:-1], counts):
          print("{}{:-8.2f}".format(lead_str, left_edge), "{:-10}".format(count))
      print("{}{:-8.2f}".format(lead_str, edges[-1]))
  113. # def dumpCuda():
  114. # # small_count = 0
  115. # total_bytes = 0
  116. # size2count_dict = collections.defaultdict(int)
  117. # size2bytes_dict = {}
  118. # for obj in gc.get_objects():
  119. # if isinstance(obj, torch.cuda._CudaBase):
  120. # nbytes = 4
  121. # for n in obj.size():
  122. # nbytes *= n
  123. #
  124. # size2count_dict[tuple([obj.get_device()] + list(obj.size()))] += 1
  125. # size2bytes_dict[tuple([obj.get_device()] + list(obj.size()))] = nbytes
  126. #
  127. # total_bytes += nbytes
  128. #
  129. # # print(small_count, "tensors equal to or less than than 16 bytes")
  130. # for size, count in sorted(size2count_dict.items(), key=lambda sc: (size2bytes_dict[sc[0]] * sc[1], sc[1], sc[0])):
  131. # print('{:4}x'.format(count), '{:10,}'.format(size2bytes_dict[size]), size)
  132. # print('{:10,}'.format(total_bytes), "total bytes")
  def enumerateWithEstimate(iter, desc_str, start_ndx=0, print_ndx=4, backoff=2, iter_len=None):
      """
      Enumerate ``iter`` while logging an estimate of when the loop will finish.

      Yields exactly the ``(index, item)`` pairs ``enumerate(iter)`` would.
      The caller's loop body runs while this generator is suspended at
      ``yield``, so the measured time per item includes the caller's work.
      A "starting" and a "done" message are logged at WARNING level; the
      progress estimates in between are logged at INFO level.

      :param iter: The iterable that will be passed into ``enumerate``. Required.
      :param desc_str: A short, human-readable description of what the loop is
          doing, such as ``"epoch 4 training"`` or ``"deleting temp files"``.
          It prefixes every log message.
      :param start_ndx: How many leading iterations to exclude from timing.
          Skipping a few iterations helps when early ones pay one-off startup
          costs (like filling a cache) that would otherwise skew the average
          time per iteration.
          NOTE: the time spent on the skipped iterations is not included in
          the displayed duration. Please account for this if you use the
          displayed duration for anything formal. Defaults to ``0``.
      :param print_ndx: The loop index at which the first estimate is logged.
          If it is smaller than ``start_ndx * backoff`` it is multiplied by
          ``backoff`` until it is not, so the average has a few timed
          iterations to stabilize before it is reported. Defaults to ``4``.
      :param backoff: Factor by which the index of the next estimate grows
          after each one is logged; the default ``2`` doubles the gap between
          messages, since frequent logging is less interesting later on.
          Must be at least ``2``.
      :param iter_len: Total number of items, needed to estimate the finish
          time. Defaults to ``len(iter)``, so it must be given explicitly for
          iterables without a length (e.g. generators).
      :return: A generator of ``(index, item)`` tuples.
      """
      if iter_len is None:
          iter_len = len(iter)

      assert backoff >= 2
      while print_ndx < start_ndx * backoff:
          # max(..., 1): a zero or negative print_ndx never grows when
          # multiplied by backoff, which used to make this loop spin forever.
          print_ndx = max(print_ndx, 1) * backoff

      log.warning("{} ----/{}, starting".format(
          desc_str,
          iter_len,
      ))
      start_ts = time.time()
      for (current_ndx, item) in enumerate(iter):
          yield (current_ndx, item)
          if current_ndx == print_ndx:
              # Average time per timed item so far, scaled up to every timed
              # item, is the estimated total duration measured from start_ts.
              duration_sec = ((time.time() - start_ts)
                              / (current_ndx - start_ndx + 1)
                              * (iter_len - start_ndx)
                              )

              done_dt = datetime.datetime.fromtimestamp(start_ts + duration_sec)
              done_td = datetime.timedelta(seconds=duration_sec)

              log.info("{} {:-4}/{}, done at {}, {}".format(
                  desc_str,
                  current_ndx,
                  iter_len,
                  str(done_dt).rsplit('.', 1)[0],
                  str(done_td).rsplit('.', 1)[0],
              ))

              # Same guard as above: with print_ndx == 0 a plain multiply
              # keeps it at 0, so no further estimate would ever be logged.
              print_ndx = max(print_ndx, 1) * backoff

          if current_ndx + 1 == start_ndx:
              # The skipped warm-up iterations are done; restart the clock so
              # they are excluded from the estimate.
              start_ts = time.time()

      log.warning("{} ----/{}, done at {}".format(
          desc_str,
          iter_len,
          str(datetime.datetime.now()).rsplit('.', 1)[0],
      ))
  # Optional plotting setup: select the non-interactive Agg backend and
  # register two custom colormaps ('mask' and 'maskinvert'). Everything is
  # skipped silently when matplotlib is not installed.
  try:
      import matplotlib
      try:
          matplotlib.use('agg', warn=False)
      except TypeError:
          # matplotlib >= 3.3 removed the `warn` keyword of use(); passing it
          # raises TypeError, which `except ImportError` below would not catch.
          matplotlib.use('agg')
      import matplotlib.pyplot as plt
      from matplotlib.colors import LinearSegmentedColormap

      def _register_cmap(name, data):
          """Build a LinearSegmentedColormap from ``data`` and register it as ``name``.

          Uses the ``matplotlib.colormaps`` registry when it provides
          ``register`` (``pyplot.register_cmap`` was removed in matplotlib 3.9),
          and falls back to ``pyplot.register_cmap`` on older releases.
          The lookup-table size follows ``rcParams['image.lut']``, as
          ``register_cmap(name=..., data=...)`` did.
          """
          cmap = LinearSegmentedColormap(name, data, matplotlib.rcParams['image.lut'])
          registry = getattr(matplotlib, 'colormaps', None)
          if hasattr(registry, 'register'):
              # force=True tolerates this module being executed again
              # (e.g. a reload) instead of raising "already registered".
              registry.register(cmap, force=True)
          else:
              plt.register_cmap(cmap=cmap)

      # 'mask': red shading to orange at the top end; opacity ramps from fully
      # transparent at 0.0 up to 0.5 at 0.75, and stays at 0.5 above that.
      cdict = {'red': ((0.0, 1.0, 1.0),
                       (1.0, 1.0, 1.0)),
               'green': ((0.0, 0.0, 0.0),
                         (0.5, 0.0, 0.0),
                         (1.0, 0.5, 0.5)),
               'blue': ((0.0, 0.0, 0.0),
                        (1.0, 0.0, 0.0)),
               'alpha': ((0.0, 0.0, 0.0),
                         (0.75, 0.5, 0.5),
                         (1.0, 0.5, 0.5))}
      _register_cmap('mask', cdict)

      # 'maskinvert': green at 0.0, yellow at 0.25, red from 0.5 up; opacity
      # goes 0.15 -> 0.3 at 0.5, then fades to fully transparent from 0.8 up.
      cdict = {'red': ((0.0, 0.0, 0.0),
                       (0.25, 1.0, 1.0),
                       (1.0, 1.0, 1.0)),
               'green': ((0.0, 1.0, 1.0),
                         (0.25, 1.0, 1.0),
                         (0.5, 0.0, 0.0),
                         (1.0, 0.0, 0.0)),
               'blue': ((0.0, 0.0, 0.0),
                        (1.0, 0.0, 0.0)),
               'alpha': ((0.0, 0.15, 0.15),
                         (0.5, 0.3, 0.3),
                         (0.8, 0.0, 0.0),
                         (1.0, 0.0, 0.0))}
      _register_cmap('maskinvert', cdict)
  except ImportError:
      pass