affine.py

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import torch.backends.cudnn as cudnn

from util.logconf import logging
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

def affine_grid_generator(theta, size):
    if theta.data.is_cuda and len(size) == 4:
        if not cudnn.enabled:
            raise RuntimeError("AffineGridGenerator needs CuDNN for "
                               "processing CUDA inputs, but CuDNN is not enabled")
        if not cudnn.is_acceptable(theta.data):
            raise RuntimeError("AffineGridGenerator theta is not acceptable for CuDNN")
        N, C, H, W = size
        return torch.cudnn_affine_grid_generator(theta, N, C, H, W)
    else:
        return AffineGridGenerator.apply(theta, size)
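
# Dispatch note: the wrapper above routes 4D (spatial) CUDA inputs through
# torch.cudnn_affine_grid_generator, the fast cuDNN kernel binding, while CPU
# and 5D (volumetric) inputs fall back to the pure-Python AffineGridGenerator
# Function defined below.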

class AffineGridGenerator(Function):
    @staticmethod
    def _enforce_cudnn(input):
        if not cudnn.enabled:
            raise RuntimeError("AffineGridGenerator needs CuDNN for "
                               "processing CUDA inputs, but CuDNN is not enabled")
        assert cudnn.is_acceptable(input)

    @staticmethod
    def forward(ctx, theta, size):
        assert type(size) == torch.Size
        if len(size) == 5:
            # Volumetric case: build an (N, D, H, W, 4) grid of homogeneous
            # (x, y, z, 1) coordinates spanning [-1, 1] along each axis.
            N, C, D, H, W = size
            ctx.size = size
            ctx.is_cuda = theta.is_cuda
            base_grid = theta.new(N, D, H, W, 4)
            w_points = (torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1]))
            h_points = (torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])).unsqueeze(-1)
            d_points = (torch.linspace(-1, 1, D) if D > 1 else torch.Tensor([-1])).unsqueeze(-1).unsqueeze(-1)
            base_grid[:, :, :, :, 0] = w_points
            base_grid[:, :, :, :, 1] = h_points
            base_grid[:, :, :, :, 2] = d_points
            base_grid[:, :, :, :, 3] = 1
            ctx.base_grid = base_grid
            # Apply the 3x4 affine matrix to every homogeneous point:
            # (N, D*H*W, 4) @ (N, 4, 3) -> (N, D*H*W, 3).
            grid = torch.bmm(base_grid.view(N, D * H * W, 4), theta.transpose(1, 2))
            grid = grid.view(N, D, H, W, 3)
        elif len(size) == 4:
            # Spatial case: (x, y, 1) homogeneous coordinates. This branch is
            # reached only for CPU tensors; 4D CUDA inputs take the cuDNN path
            # in affine_grid_generator, hence the assert below.
            N, C, H, W = size
            ctx.size = size
            if theta.is_cuda:
                AffineGridGenerator._enforce_cudnn(theta)
                assert False
            ctx.is_cuda = False
            base_grid = theta.new(N, H, W, 3)
            linear_points = torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])
            base_grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(base_grid[:, :, :, 0])
            linear_points = torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])
            base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(base_grid[:, :, :, 1])
            base_grid[:, :, :, 2] = 1
            ctx.base_grid = base_grid
            # (N, H*W, 3) @ (N, 3, 2) -> (N, H*W, 2).
            grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2))
            grid = grid.view(N, H, W, 2)
        else:
            raise RuntimeError("AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs.")
        return grid
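
    # Backward derivation: forward computes grid = B @ theta^T, with B the
    # (K, 4) [or (K, 3)] base grid of K = D*H*W [or H*W] points per sample,
    # so dL/dtheta = (dL/dgrid)^T @ B. The bmm calls below compute
    # B^T @ grad_grid = (dL/dtheta)^T and then transpose the result.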
    @staticmethod
    @once_differentiable
    def backward(ctx, grad_grid):
        if len(ctx.size) == 5:
            N, C, D, H, W = ctx.size
            assert grad_grid.size() == torch.Size([N, D, H, W, 3])
            assert ctx.is_cuda == grad_grid.is_cuda
            # if grad_grid.is_cuda:
            #     AffineGridGenerator._enforce_cudnn(grad_grid)
            #     assert False
            base_grid = ctx.base_grid
            grad_theta = torch.bmm(
                base_grid.view(N, D * H * W, 4).transpose(1, 2),
                grad_grid.view(N, D * H * W, 3))
            grad_theta = grad_theta.transpose(1, 2)
        elif len(ctx.size) == 4:
            N, C, H, W = ctx.size
            assert grad_grid.size() == torch.Size([N, H, W, 2])
            assert ctx.is_cuda == grad_grid.is_cuda
            if grad_grid.is_cuda:
                AffineGridGenerator._enforce_cudnn(grad_grid)
                assert False
            base_grid = ctx.base_grid
            grad_theta = torch.bmm(
                base_grid.view(N, H * W, 3).transpose(1, 2),
                grad_grid.view(N, H * W, 2))
            grad_theta = grad_theta.transpose(1, 2)
        else:
            assert False
        return grad_theta, None
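
# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the original module. It assumes a
# PyTorch build where this legacy Function/ger API still runs, and exercises
# the 4D CPU path with an identity transform theta = [[1, 0, 0], [0, 1, 0]]:
# the resulting grid should reproduce the base (x, y) coordinates in [-1, 1],
# and the backward pass should yield an (N, 2, 3) gradient for theta.
if __name__ == '__main__':
    N, C, H, W = 2, 3, 4, 5
    theta = torch.zeros(N, 2, 3)
    theta[:, 0, 0] = 1.0
    theta[:, 1, 1] = 1.0
    theta.requires_grad_(True)

    grid = affine_grid_generator(theta, torch.Size([N, C, H, W]))
    log.debug("grid shape: %s", grid.shape)  # expected (N, H, W, 2)
    log.debug("corner point: %s", grid[0, 0, 0])  # approximately (-1, -1)

    grid.sum().backward()
    log.debug("grad_theta shape: %s", theta.grad.shape)  # expected (N, 2, 3)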