test_affine.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. import math
  2. import random
  3. import numpy as np
  4. import scipy.ndimage
  5. import torch
  6. import pytest
  7. from util.logconf import logging
  8. log = logging.getLogger(__name__)
  9. # log.setLevel(logging.WARN)
  10. # log.setLevel(logging.INFO)
  11. log.setLevel(logging.DEBUG)
  12. from .affine import affine_grid_generator
  13. if torch.cuda.is_available():
  14. @pytest.fixture(params=['cpu', 'cuda'])
  15. def device(request):
  16. return request.param
  17. else:
  18. @pytest.fixture(params=['cpu'])
  19. def device(request):
  20. return request.param
  21. # @pytest.fixture(params=[0., 0.25])
  22. @pytest.fixture(params=[0.0, 0.5, 0.25, 0.125, 'random'])
  23. def angle_rad(request):
  24. if request.param == 'random':
  25. return random.random() * math.pi * 2
  26. return request.param * math.pi * 2
  27. @pytest.fixture(params=[(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), 'random'])
  28. def axis_vector(request):
  29. if request.param == 'random':
  30. t = (random.random(), random.random(), random.random())
  31. l = sum(x**2 for x in t)**0.5
  32. return tuple(x/l for x in t)
  33. return request.param
  34. @pytest.fixture(params=[torch.nn.functional.affine_grid, affine_grid_generator])
  35. def affine_func2d(request):
  36. return request.param
  37. @pytest.fixture(params=[affine_grid_generator])
  38. def affine_func3d(request):
  39. return request.param
  40. # @pytest.fixture(params=[[1, 1, 3, 5], [1, 1, 3, 3]])
  41. @pytest.fixture(params=[[1, 1, 3, 5], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 3, 4]])
  42. def input_size2d(request):
  43. return request.param
  44. # @pytest.fixture(params=[[1, 1, 5, 3], [1, 1, 3, 5], [1, 1, 5, 5]])
  45. @pytest.fixture(params=[[1, 1, 5, 3], [1, 1, 3, 5], [1, 1, 4, 3], [1, 1, 5, 5], [1, 1, 6, 6]])
  46. def output_size2d(request):
  47. return request.param
  48. @pytest.fixture(params=[[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 6, 6], ])
  49. def input_size2dsq(request):
  50. return request.param
  51. @pytest.fixture(params=[[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 5, 5], [1, 1, 6, 6], ])
  52. def output_size2dsq(request):
  53. return request.param
  54. # @pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 2, 3, 4]])
  55. @pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 2, 3, 4], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 3, 4, 5]])
  56. def input_size3d(request):
  57. return request.param
  58. @pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 6, 6, 6]])
  59. def input_size3dsq(request):
  60. return request.param
  61. @pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]])
  62. def output_size3dsq(request):
  63. return request.param
  64. # @pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 3, 4, 5]])
  65. @pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 3, 4, 5], [1, 1, 4, 3, 2], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]])
  66. def output_size3d(request):
  67. return request.param
  68. def _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad):
  69. print("_buildEquivalentTransforms2d", device, input_size, output_size, angle_rad * 180 / math.pi)
  70. input_center = [(x-1)/2 for x in input_size]
  71. output_center = [(x-1)/2 for x in output_size]
  72. s = math.sin(angle_rad)
  73. c = math.cos(angle_rad)
  74. intrans_ary = np.array([
  75. [1, 0, input_center[2]],
  76. [0, 1, input_center[3]],
  77. [0, 0, 1],
  78. ], dtype=np.float64)
  79. inscale_ary = np.array([
  80. [input_center[2], 0, 0],
  81. [0, input_center[3], 0],
  82. [0, 0, 1],
  83. ], dtype=np.float64)
  84. rotation_ary = np.array([
  85. [c, -s, 0],
  86. [s, c, 0],
  87. [0, 0, 1],
  88. ], dtype=np.float64)
  89. outscale_ary = np.array([
  90. [1/output_center[2], 0, 0],
  91. [0, 1/output_center[3], 0],
  92. [0, 0, 1],
  93. ], dtype=np.float64)
  94. outtrans_ary = np.array([
  95. [1, 0, -output_center[2]],
  96. [0, 1, -output_center[3]],
  97. [0, 0, 1],
  98. ], dtype=np.float64)
  99. reorder_ary = np.array([
  100. [0, 1, 0],
  101. [1, 0, 0],
  102. [0, 0, 1],
  103. ], dtype=np.float64)
  104. transform_ary = intrans_ary @ inscale_ary @ rotation_ary.T @ outscale_ary @ outtrans_ary
  105. grid_ary = reorder_ary @ rotation_ary.T @ outscale_ary @ outtrans_ary
  106. transform_t = torch.from_numpy((rotation_ary)).to(device, torch.float32)
  107. transform_t = transform_t[:2].unsqueeze(0)
  108. print('transform_t', transform_t.size(), transform_t.dtype, transform_t.device)
  109. print(transform_t)
  110. print('outtrans_ary', outtrans_ary.shape, outtrans_ary.dtype)
  111. print(outtrans_ary.round(3))
  112. print('outscale_ary', outscale_ary.shape, outscale_ary.dtype)
  113. print(outscale_ary.round(3))
  114. print('rotation_ary', rotation_ary.shape, rotation_ary.dtype)
  115. print(rotation_ary.round(3))
  116. print('inscale_ary', inscale_ary.shape, inscale_ary.dtype)
  117. print(inscale_ary.round(3))
  118. print('intrans_ary', intrans_ary.shape, intrans_ary.dtype)
  119. print(intrans_ary.round(3))
  120. print('transform_ary', transform_ary.shape, transform_ary.dtype)
  121. print(transform_ary.round(3))
  122. print('grid_ary', grid_ary.shape, grid_ary.dtype)
  123. print(grid_ary.round(3))
  124. def prtf(pt):
  125. print(pt, 'transformed', (transform_ary @ (pt + [1]))[:2].round(3))
  126. prtf([0, 0])
  127. prtf([1, 0])
  128. prtf([2, 0])
  129. print('')
  130. prtf([0, 0])
  131. prtf([0, 1])
  132. prtf([0, 2])
  133. prtf(output_center[2:])
  134. return transform_t, transform_ary, grid_ary
  135. def _buildEquivalentTransforms3d(device, input_size, output_size, angle_rad, axis_vector):
  136. print("_buildEquivalentTransforms2d", device, input_size, output_size, angle_rad * 180 / math.pi, axis_vector)
  137. input_center = [(x-1)/2 for x in input_size]
  138. output_center = [(x-1)/2 for x in output_size]
  139. s = math.sin(angle_rad)
  140. c = math.cos(angle_rad)
  141. c1 = 1 - c
  142. intrans_ary = np.array([
  143. [1, 0, 0, input_center[2]],
  144. [0, 1, 0, input_center[3]],
  145. [0, 0, 1, input_center[4]],
  146. [0, 0, 0, 1],
  147. ], dtype=np.float64)
  148. inscale_ary = np.array([
  149. [input_center[2], 0, 0, 0],
  150. [0, input_center[3], 0, 0],
  151. [0, 0, input_center[4], 0],
  152. [0, 0, 0, 1],
  153. ], dtype=np.float64)
  154. l, m, n = axis_vector
  155. scipyRotation_ary = np.array([
  156. [l*l*c1 + c, m*l*c1 - n*s, n*l*c1 + m*s, 0],
  157. [l*m*c1 + n*s, m*m*c1 + c, n*m*c1 - l*s, 0],
  158. [l*n*c1 - m*s, m*n*c1 + l*s, n*n*c1 + c, 0],
  159. [0, 0, 0, 1],
  160. ], dtype=np.float64)
  161. z, y, x = axis_vector
  162. torchRotation_ary = np.array([
  163. [x*x*c1 + c, y*x*c1 - z*s, z*x*c1 + y*s, 0],
  164. [x*y*c1 + z*s, y*y*c1 + c, z*y*c1 - x*s, 0],
  165. [x*z*c1 - y*s, y*z*c1 + x*s, z*z*c1 + c, 0],
  166. [0, 0, 0, 1],
  167. ], dtype=np.float64)
  168. outscale_ary = np.array([
  169. [1/output_center[2], 0, 0, 0],
  170. [0, 1/output_center[3], 0, 0],
  171. [0, 0, 1/output_center[4], 0],
  172. [0, 0, 0, 1],
  173. ], dtype=np.float64)
  174. outtrans_ary = np.array([
  175. [1, 0, 0, -output_center[2]],
  176. [0, 1, 0, -output_center[3]],
  177. [0, 0, 1, -output_center[4]],
  178. [0, 0, 0, 1],
  179. ], dtype=np.float64)
  180. reorder_ary = np.array([
  181. [0, 0, 1, 0],
  182. [0, 1, 0, 0],
  183. [1, 0, 0, 0],
  184. [0, 0, 0, 1],
  185. ], dtype=np.float64)
  186. transform_ary = intrans_ary @ inscale_ary @ np.linalg.inv(scipyRotation_ary) @ outscale_ary @ outtrans_ary
  187. grid_ary = reorder_ary @ np.linalg.inv(scipyRotation_ary) @ outscale_ary @ outtrans_ary
  188. transform_t = torch.from_numpy((torchRotation_ary)).to(device, torch.float32)
  189. transform_t = transform_t[:3].unsqueeze(0)
  190. print('transform_t', transform_t.size(), transform_t.dtype, transform_t.device)
  191. print(transform_t)
  192. print('outtrans_ary', outtrans_ary.shape, outtrans_ary.dtype)
  193. print(outtrans_ary.round(3))
  194. print('outscale_ary', outscale_ary.shape, outscale_ary.dtype)
  195. print(outscale_ary.round(3))
  196. print('rotation_ary', scipyRotation_ary.shape, scipyRotation_ary.dtype, axis_vector, angle_rad)
  197. print(scipyRotation_ary.round(3))
  198. print('inscale_ary', inscale_ary.shape, inscale_ary.dtype)
  199. print(inscale_ary.round(3))
  200. print('intrans_ary', intrans_ary.shape, intrans_ary.dtype)
  201. print(intrans_ary.round(3))
  202. print('transform_ary', transform_ary.shape, transform_ary.dtype)
  203. print(transform_ary.round(3))
  204. print('grid_ary', grid_ary.shape, grid_ary.dtype)
  205. print(grid_ary.round(3))
  206. def prtf(pt):
  207. print(pt, 'transformed', (transform_ary @ (pt + [1]))[:3].round(3))
  208. prtf([0, 0, 0])
  209. prtf([1, 0, 0])
  210. prtf([2, 0, 0])
  211. print('')
  212. prtf([0, 0, 0])
  213. prtf([0, 1, 0])
  214. prtf([0, 2, 0])
  215. print('')
  216. prtf([0, 0, 0])
  217. prtf([0, 0, 1])
  218. prtf([0, 0, 2])
  219. prtf(output_center[2:])
  220. return transform_t, transform_ary, grid_ary
  221. def test_affine_2d_rotate0(device, affine_func2d):
  222. input_size = [1, 1, 3, 3]
  223. input_ary = np.array(np.random.random(input_size), dtype=np.float32)
  224. output_size = [1, 1, 5, 5]
  225. angle_rad = 0.
  226. transform_t, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
  227. # reference
  228. # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
  229. scipy_ary = scipy.ndimage.affine_transform(
  230. input_ary[0,0],
  231. transform_ary,
  232. offset=offset,
  233. output_shape=output_size[2:],
  234. # output=None,
  235. order=1,
  236. mode='nearest',
  237. # cval=0.0,
  238. prefilter=False)
  239. print('input_ary', input_ary.shape, input_ary.dtype)
  240. print(input_ary)
  241. print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
  242. print(scipy_ary)
  243. affine_t = affine_func2d(
  244. transform_t,
  245. torch.Size(output_size)
  246. )
  247. print('affine_t', affine_t.size(), affine_t.dtype, affine_t.device)
  248. print(affine_t)
  249. gridsample_ary = torch.nn.functional.grid_sample(
  250. torch.tensor(input_ary, device=device).to(device),
  251. affine_t,
  252. padding_mode='border'
  253. ).to('cpu').numpy()
  254. print('input_ary', input_ary.shape, input_ary.dtype)
  255. print(input_ary)
  256. print('gridsample_ary', gridsample_ary.shape, gridsample_ary.dtype)
  257. print(gridsample_ary)
  258. print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
  259. print(scipy_ary)
  260. assert np.abs(scipy_ary.mean() - gridsample_ary.mean()) < 1e-6
  261. assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6
  262. # assert False
  263. def test_affine_2d_rotate90(device, affine_func2d, input_size2dsq, output_size2dsq):
  264. input_size = input_size2dsq
  265. input_ary = np.array(np.random.random(input_size), dtype=np.float32)
  266. output_size = output_size2dsq
  267. angle_rad = 0.25 * math.pi * 2
  268. transform_t, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
  269. # reference
  270. # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
  271. scipy_ary = scipy.ndimage.affine_transform(
  272. input_ary[0,0],
  273. transform_ary,
  274. offset=offset,
  275. output_shape=output_size[2:],
  276. # output=None,
  277. order=1,
  278. mode='nearest',
  279. # cval=0.0,
  280. prefilter=True)
  281. print('input_ary', input_ary.shape, input_ary.dtype, input_ary.mean())
  282. print(input_ary)
  283. print('scipy_ary', scipy_ary.shape, scipy_ary.dtype, scipy_ary.mean())
  284. print(scipy_ary)
  285. if input_size2dsq == output_size2dsq:
  286. assert np.abs(scipy_ary.mean() - input_ary.mean()) < 1e-6
  287. assert np.abs(scipy_ary[0,0] - input_ary[0,0,0,-1]).max() < 1e-6
  288. assert np.abs(scipy_ary[0,-1] - input_ary[0,0,-1,-1]).max() < 1e-6
  289. assert np.abs(scipy_ary[-1,-1] - input_ary[0,0,-1,0]).max() < 1e-6
  290. assert np.abs(scipy_ary[-1,0] - input_ary[0,0,0,0]).max() < 1e-6
  291. affine_t = affine_func2d(
  292. transform_t,
  293. torch.Size(output_size)
  294. )
  295. print('affine_t', affine_t.size(), affine_t.dtype, affine_t.device)
  296. print(affine_t)
  297. gridsample_ary = torch.nn.functional.grid_sample(
  298. torch.tensor(input_ary, device=device).to(device),
  299. affine_t,
  300. padding_mode='border'
  301. ).to('cpu').numpy()
  302. print('input_ary', input_ary.shape, input_ary.dtype)
  303. print(input_ary)
  304. print('gridsample_ary', gridsample_ary.shape, gridsample_ary.dtype)
  305. print(gridsample_ary)
  306. print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
  307. print(scipy_ary)
  308. assert np.abs(scipy_ary.mean() - gridsample_ary.mean()) < 1e-6
  309. assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6
  310. # assert False
  311. def test_affine_2d_rotate45(device, affine_func2d):
  312. input_size = [1, 1, 3, 3]
  313. input_ary = np.array(np.zeros(input_size), dtype=np.float32)
  314. input_ary[0,0,0,:] = 0.5
  315. input_ary[0,0,2,2] = 1.0
  316. output_size = [1, 1, 3, 3]
  317. angle_rad = 0.125 * math.pi * 2
  318. transform_t, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
  319. # reference
  320. # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
  321. scipy_ary = scipy.ndimage.affine_transform(
  322. input_ary[0,0],
  323. transform_ary,
  324. offset=offset,
  325. output_shape=output_size[2:],
  326. # output=None,
  327. order=1,
  328. mode='nearest',
  329. # cval=0.0,
  330. prefilter=False)
  331. print('input_ary', input_ary.shape, input_ary.dtype)
  332. print(input_ary)
  333. print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
  334. print(scipy_ary)
  335. affine_t = affine_func2d(
  336. transform_t,
  337. torch.Size(output_size)
  338. )
  339. print('affine_t', affine_t.size(), affine_t.dtype, affine_t.device)
  340. print(affine_t)
  341. gridsample_ary = torch.nn.functional.grid_sample(
  342. torch.tensor(input_ary, device=device).to(device),
  343. affine_t,
  344. padding_mode='border'
  345. ).to('cpu').numpy()
  346. print('input_ary', input_ary.shape, input_ary.dtype)
  347. print(input_ary)
  348. print('gridsample_ary', gridsample_ary.shape, gridsample_ary.dtype)
  349. print(gridsample_ary)
  350. print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
  351. print(scipy_ary)
  352. assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6
  353. # assert False
  354. def test_affine_2d_rotateRandom(device, affine_func2d, angle_rad, input_size2d, output_size2d):
  355. input_size = input_size2d
  356. input_ary = np.array(np.random.random(input_size), dtype=np.float32).round(3)
  357. output_size = output_size2d
  358. input_ary[0,0,0,0] = 2
  359. input_ary[0,0,0,-1] = 4
  360. input_ary[0,0,-1,0] = 6
  361. input_ary[0,0,-1,-1] = 8
  362. transform_t, transform_ary, grid_ary = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
  363. # reference
  364. # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
  365. scipy_ary = scipy.ndimage.affine_transform(
  366. input_ary[0,0],
  367. transform_ary,
  368. # offset=offset,
  369. output_shape=output_size[2:],
  370. # output=None,
  371. order=1,
  372. mode='nearest',
  373. # cval=0.0,
  374. prefilter=False)
  375. affine_t = affine_func2d(
  376. transform_t,
  377. torch.Size(output_size)
  378. )
  379. print('affine_t', affine_t.size(), affine_t.dtype, affine_t.device)
  380. print(affine_t)
  381. for r in range(affine_t.size(1)):
  382. for c in range(affine_t.size(2)):
  383. grid_out = grid_ary @ [r, c, 1]
  384. print(r, c, 'affine:', affine_t[0,r,c], 'grid:', grid_out[:2])
  385. gridsample_ary = torch.nn.functional.grid_sample(
  386. torch.tensor(input_ary, device=device).to(device),
  387. affine_t,
  388. padding_mode='border'
  389. ).to('cpu').numpy()
  390. print('input_ary', input_ary.shape, input_ary.dtype)
  391. print(input_ary.round(3))
  392. print('gridsample_ary', gridsample_ary.shape, gridsample_ary.dtype)
  393. print(gridsample_ary.round(3))
  394. print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
  395. print(scipy_ary.round(3))
  396. for r in range(affine_t.size(1)):
  397. for c in range(affine_t.size(2)):
  398. grid_out = grid_ary @ [r, c, 1]
  399. try:
  400. assert np.allclose(affine_t[0,r,c], grid_out[:2], atol=1e-5)
  401. except:
  402. print(r, c, 'affine:', affine_t[0,r,c], 'grid:', grid_out[:2])
  403. raise
  404. assert np.abs(scipy_ary - gridsample_ary).max() < 1e-5
  405. # assert False
  406. def test_affine_3d_rotateRandom(device, affine_func3d, angle_rad, axis_vector, input_size3d, output_size3d):
  407. input_size = input_size3d
  408. input_ary = np.array(np.random.random(input_size), dtype=np.float32)
  409. output_size = output_size3d
  410. input_ary[0,0, 0, 0, 0] = 2
  411. input_ary[0,0, 0, 0, -1] = 3
  412. input_ary[0,0, 0, -1, 0] = 4
  413. input_ary[0,0, 0, -1, -1] = 5
  414. input_ary[0,0, -1, 0, 0] = 6
  415. input_ary[0,0, -1, 0, -1] = 7
  416. input_ary[0,0, -1, -1, 0] = 8
  417. input_ary[0,0, -1, -1, -1] = 9
  418. transform_t, transform_ary, grid_ary = _buildEquivalentTransforms3d(device, input_size, output_size, angle_rad, axis_vector)
  419. # reference
  420. # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
  421. scipy_ary = scipy.ndimage.affine_transform(
  422. input_ary[0,0],
  423. transform_ary,
  424. # offset=offset,
  425. output_shape=output_size[2:],
  426. # output=None,
  427. order=1,
  428. mode='nearest',
  429. # cval=0.0,
  430. prefilter=False)
  431. affine_t = affine_func3d(
  432. transform_t,
  433. torch.Size(output_size)
  434. )
  435. print('affine_t', affine_t.size(), affine_t.dtype, affine_t.device)
  436. print(affine_t)
  437. for i in range(affine_t.size(1)):
  438. for r in range(affine_t.size(2)):
  439. for c in range(affine_t.size(3)):
  440. grid_out = grid_ary @ [i, r, c, 1]
  441. print(i, r, c, 'affine:', affine_t[0,i,r,c], 'grid:', grid_out[:3].round(3))
  442. print('input_ary', input_ary.shape, input_ary.dtype)
  443. print(input_ary.round(3))
  444. gridsample_ary = torch.nn.functional.grid_sample(
  445. torch.tensor(input_ary, device=device).to(device),
  446. affine_t,
  447. padding_mode='border'
  448. ).to('cpu').numpy()
  449. print('gridsample_ary', gridsample_ary.shape, gridsample_ary.dtype)
  450. print(gridsample_ary.round(3))
  451. print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
  452. print(scipy_ary.round(3))
  453. for i in range(affine_t.size(1)):
  454. for r in range(affine_t.size(2)):
  455. for c in range(affine_t.size(3)):
  456. grid_out = grid_ary @ [i, r, c, 1]
  457. try:
  458. assert np.allclose(affine_t[0,i,r,c], grid_out[:3], atol=1e-5)
  459. except:
  460. print(i, r, c, 'affine:', affine_t[0,i,r,c], 'grid:', grid_out[:3].round(3))
  461. raise
  462. assert np.abs(scipy_ary - gridsample_ary).max() < 1e-5
  463. # assert False