util.py

import collections
import copy
import datetime
import gc
import time

# import torch
import numpy as np

from util.logconf import logging
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

IrcTuple = collections.namedtuple('IrcTuple', ['index', 'row', 'col'])
XyzTuple = collections.namedtuple('XyzTuple', ['x', 'y', 'z'])


def xyz2irc(coord_xyz, origin_xyz, vxSize_xyz, direction_tup):
    if direction_tup == (1, 0, 0, 0, 1, 0, 0, 0, 1):
        direction_ary = np.ones((3,))
    elif direction_tup == (-1, 0, 0, 0, -1, 0, 0, 0, 1):
        direction_ary = np.array((-1, -1, 1))
    else:
        raise Exception(
            "Unsupported direction_tup: {}".format(direction_tup),
        )

    coord_cri = (
        np.array(coord_xyz)
        - np.array(origin_xyz)
    ) / np.array(vxSize_xyz)
    coord_cri *= direction_ary
    return IrcTuple(*list(reversed(coord_cri.tolist())))


def irc2xyz(coord_irc, origin_xyz, vxSize_xyz, direction_tup):
    coord_cri = np.array(list(reversed(coord_irc)))
    if direction_tup == (1, 0, 0, 0, 1, 0, 0, 0, 1):
        direction_ary = np.ones((3,))
    elif direction_tup == (-1, 0, 0, 0, -1, 0, 0, 0, 1):
        direction_ary = np.array((-1, -1, 1))
    else:
        raise Exception(
            "Unsupported direction_tup: {}".format(direction_tup),
        )

    coord_xyz = (coord_cri
                 * direction_ary
                 * np.array(vxSize_xyz)
                 + np.array(origin_xyz))
    return XyzTuple(*coord_xyz.tolist())
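

# Illustrative round-trip sketch (not part of the original module); the
# origin, voxel size, and direction values below are made up, whereas real
# values come from the CT metadata (e.g. SimpleITK's GetOrigin/GetSpacing/
# GetDirection):
#
# origin_xyz = (-150.0, -150.0, -300.0)
# vxSize_xyz = (0.7, 0.7, 2.5)
# direction_tup = (1, 0, 0, 0, 1, 0, 0, 0, 1)
# irc = xyz2irc((-100.0, 50.0, -200.0), origin_xyz, vxSize_xyz, direction_tup)
# xyz = irc2xyz(irc, origin_xyz, vxSize_xyz, direction_tup)
# # `xyz` recovers the original (x, y, z) coordinate up to float rounding.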


def importstr(module_str, from_=None):
    """
    >>> importstr('os')
    <module 'os' from '.../os.pyc'>
    >>> importstr('math', 'fabs')
    <built-in function fabs>
    """
    if from_ is None and ':' in module_str:
        module_str, from_ = module_str.rsplit(':', 1)

    module = __import__(module_str)
    for sub_str in module_str.split('.')[1:]:
        module = getattr(module, sub_str)

    if from_:
        try:
            return getattr(module, from_)
        except AttributeError:
            raise ImportError('{}.{}'.format(module_str, from_))
    return module
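

# Illustrative note (not part of the original module): the single-string
# 'module:attr' form is equivalent to passing `from_` explicitly, e.g.
#
# importstr('math:fabs')      # same as importstr('math', 'fabs')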


# class dotdict(dict):
#     '''dict where a key can be accessed as an attribute: d.key -> d[key]'''
#     @classmethod
#     def deep(cls, dic_obj):
#         '''Initialize from dict with deep conversion'''
#         return cls(dic_obj).deepConvert()
#
#     def __getattr__(self, attr):
#         if attr in self:
#             return self[attr]
#         log.error(sorted(self.keys()))
#         raise AttributeError(attr)
#         # return self.get(attr, None)
#     __setattr__ = dict.__setitem__
#     __delattr__ = dict.__delitem__
#
#     def __copy__(self):
#         return dotdict(self)
#
#     def __deepcopy__(self, memo):
#         new_dict = dotdict()
#         for k, v in self.items():
#             new_dict[k] = copy.deepcopy(v, memo)
#         return new_dict
#
#     # pylint: disable=multiple-statements
#     def __getstate__(self): return self.__dict__
#     def __setstate__(self, d): self.__dict__.update(d)
#
#     def deepConvert(self):
#         '''Convert all dicts at all tree levels into dotdict'''
#         for k, v in self.items():
#             if type(v) is dict:  # pylint: disable=unidiomatic-typecheck
#                 self[k] = dotdict(v)
#                 self[k].deepConvert()
#             try:  # try enumerable types
#                 for m, x in enumerate(v):
#                     if type(x) is dict:  # pylint: disable=unidiomatic-typecheck
#                         x = dotdict(x)
#                         x.deepConvert()
#                         v[m] = x
#             except TypeError:
#                 pass
#         return self
#
#     def copy(self):
#         # override dict.copy()
#         return dotdict(self)


def prhist(ary, prefix_str=None, **kwargs):
    if prefix_str is None:
        prefix_str = ''
    else:
        prefix_str += ' '

    count_ary, bins_ary = np.histogram(ary, **kwargs)
    for i in range(count_ary.shape[0]):
        print("{}{:-8.2f}".format(prefix_str, bins_ary[i]), "{:-10}".format(count_ary[i]))
    print("{}{:-8.2f}".format(prefix_str, bins_ary[-1]))
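

# Illustrative usage sketch (not part of the original module); the array and
# the `bins` keyword below are made up, and **kwargs is simply forwarded to
# np.histogram:
#
# prhist(np.random.randn(1000), prefix_str='sample', bins=10)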


# def dumpCuda():
#     # small_count = 0
#     total_bytes = 0
#     size2count_dict = collections.defaultdict(int)
#     size2bytes_dict = {}
#     for obj in gc.get_objects():
#         if isinstance(obj, torch.cuda._CudaBase):
#             nbytes = 4
#             for n in obj.size():
#                 nbytes *= n
#
#             size2count_dict[tuple([obj.get_device()] + list(obj.size()))] += 1
#             size2bytes_dict[tuple([obj.get_device()] + list(obj.size()))] = nbytes
#
#             total_bytes += nbytes
#
#     # print(small_count, "tensors equal to or less than 16 bytes")
#     for size, count in sorted(size2count_dict.items(), key=lambda sc: (size2bytes_dict[sc[0]] * sc[1], sc[1], sc[0])):
#         print('{:4}x'.format(count), '{:10,}'.format(size2bytes_dict[size]), size)
#     print('{:10,}'.format(total_bytes), "total bytes")


def enumerateWithEstimate(
        iter,
        desc_str,
        start_ndx=0,
        print_ndx=4,
        backoff=2,
        iter_len=None,
):
    """
    In terms of behavior, `enumerateWithEstimate` is almost identical
    to the standard `enumerate` (the differences are things like how
    our function returns a generator, while `enumerate` returns a
    specialized `<enumerate object at 0x...>`).

    However, the side effects (logging, specifically) are what make the
    function interesting.

    :param iter: `iter` is the iterable that will be passed into
        `enumerate`. Required.

    :param desc_str: This is a human-readable string that describes
        what the loop is doing. The value is arbitrary, but should be
        kept reasonably short. Things like `"epoch 4 training"` or
        `"deleting temp files"` or similar would all make sense.

    :param start_ndx: This parameter defines how many iterations of the
        loop should be skipped before timing actually starts. Skipping
        a few iterations can be useful if there are startup costs like
        caching that are only paid early on, resulting in a skewed
        average when those early iterations dominate the average time
        per iteration.

        NOTE: Using `start_ndx` to skip some iterations means the time
        spent performing those iterations is not included in the
        displayed duration. Please account for this if you use the
        displayed duration for anything formal.

        This parameter defaults to `0`.

    :param print_ndx: determines which loop iteration the timing
        logging will start on. The intent is that we don't start
        logging until we've given the loop a few iterations to let the
        average time-per-iteration stabilize a bit. We require that
        `print_ndx` not be less than `start_ndx` times `backoff`, since
        `start_ndx` greater than `0` implies that the early N
        iterations are unstable from a timing perspective.

        `print_ndx` defaults to `4`.

    :param backoff: This is used to control how many iterations to skip
        before logging again. Frequent logging is less interesting
        later on, so by default we double the gap between logging
        messages each time after the first.

        `backoff` defaults to `2`.

    :param iter_len: Since we need to know the number of items to
        estimate when the loop will finish, that can be provided by
        passing in a value for `iter_len`. If a value isn't provided,
        then it will be set by using the value of `len(iter)`.

    :return:
    """
    if iter_len is None:
        iter_len = len(iter)

    assert backoff >= 2
    while print_ndx < start_ndx * backoff:
        print_ndx *= backoff

    log.warning("{} ----/{}, starting".format(
        desc_str,
        iter_len,
    ))
    start_ts = time.time()
    for (current_ndx, item) in enumerate(iter):
        yield (current_ndx, item)
        if current_ndx == print_ndx:
            # Extrapolate the total duration from the average time per
            # iteration measured so far (excluding the skipped start_ndx
            # iterations), then log the estimated completion time.
            duration_sec = ((time.time() - start_ts)
                            / (current_ndx - start_ndx + 1)
                            * (iter_len - start_ndx)
                            )

            done_dt = datetime.datetime.fromtimestamp(start_ts + duration_sec)
            done_td = datetime.timedelta(seconds=duration_sec)

            log.info("{} {:-4}/{}, done at {}, {}".format(
                desc_str,
                current_ndx,
                iter_len,
                str(done_dt).rsplit('.', 1)[0],
                str(done_td).rsplit('.', 1)[0],
            ))

            print_ndx *= backoff

        if current_ndx + 1 == start_ndx:
            start_ts = time.time()

    log.warning("{} ----/{}, done at {}".format(
        desc_str,
        iter_len,
        str(datetime.datetime.now()).rsplit('.', 1)[0],
    ))
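

# Illustrative usage sketch (not part of the original module). `train_dl` is a
# hypothetical torch.utils.data.DataLoader; any iterable with a usable
# `len()` (or an explicit `iter_len=`) works the same way:
#
# for batch_ndx, batch_tup in enumerateWithEstimate(
#     train_dl,
#     "epoch 4 training",
#     start_ndx=train_dl.num_workers,
# ):
#     pass  # the real per-batch work goes here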


# try:
#     import matplotlib
#     matplotlib.use('agg', warn=False)
#
#     import matplotlib.pyplot as plt
#
#     # matplotlib color maps
#     cdict = {'red':   ((0.0, 1.0, 1.0),
#                        # (0.5, 1.0, 1.0),
#                        (1.0, 1.0, 1.0)),
#
#              'green': ((0.0, 0.0, 0.0),
#                        (0.5, 0.0, 0.0),
#                        (1.0, 0.5, 0.5)),
#
#              'blue':  ((0.0, 0.0, 0.0),
#                        # (0.5, 0.5, 0.5),
#                        # (0.75, 0.0, 0.0),
#                        (1.0, 0.0, 0.0)),
#
#              'alpha': ((0.0, 0.0, 0.0),
#                        (0.75, 0.5, 0.5),
#                        (1.0, 0.5, 0.5))}
#
#     plt.register_cmap(name='mask', data=cdict)
#
#     cdict = {'red':   ((0.0, 0.0, 0.0),
#                        (0.25, 1.0, 1.0),
#                        (1.0, 1.0, 1.0)),
#
#              'green': ((0.0, 1.0, 1.0),
#                        (0.25, 1.0, 1.0),
#                        (0.5, 0.0, 0.0),
#                        (1.0, 0.0, 0.0)),
#
#              'blue':  ((0.0, 0.0, 0.0),
#                        # (0.5, 0.5, 0.5),
#                        # (0.75, 0.0, 0.0),
#                        (1.0, 0.0, 0.0)),
#
#              'alpha': ((0.0, 0.15, 0.15),
#                        (0.5, 0.3, 0.3),
#                        (0.8, 0.0, 0.0),
#                        (1.0, 0.0, 0.0))}
#
#     plt.register_cmap(name='maskinvert', data=cdict)
# except ImportError:
#     pass