cyclegan.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # This is the code from the p1ch2/3_cyclegan notebook
  2. import torch
  3. import torch.nn as nn
  4. class ResNetBlock(nn.Module):
  5. def __init__(self, dim):
  6. super(ResNetBlock, self).__init__()
  7. self.conv_block = self.build_conv_block(dim)
  8. def build_conv_block(self, dim):
  9. conv_block = []
  10. conv_block += [nn.ReflectionPad2d(1)]
  11. conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
  12. nn.InstanceNorm2d(dim),
  13. nn.ReLU(True)]
  14. conv_block += [nn.ReflectionPad2d(1)]
  15. conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
  16. nn.InstanceNorm2d(dim)]
  17. return nn.Sequential(*conv_block)
  18. def forward(self, x):
  19. out = x + self.conv_block(x)
  20. return out
  21. class ResNetGenerator(nn.Module):
  22. def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):
  23. assert(n_blocks >= 0)
  24. super(ResNetGenerator, self).__init__()
  25. self.input_nc = input_nc
  26. self.output_nc = output_nc
  27. self.ngf = ngf
  28. model = [nn.ReflectionPad2d(3),
  29. nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
  30. nn.InstanceNorm2d(ngf),
  31. nn.ReLU(True)]
  32. n_downsampling = 2
  33. for i in range(n_downsampling):
  34. mult = 2**i
  35. model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
  36. stride=2, padding=1, bias=True),
  37. nn.InstanceNorm2d(ngf * mult * 2),
  38. nn.ReLU(True)]
  39. mult = 2**n_downsampling
  40. for i in range(n_blocks):
  41. model += [ResNetBlock(ngf * mult)]
  42. for i in range(n_downsampling):
  43. mult = 2**(n_downsampling - i)
  44. model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
  45. kernel_size=3, stride=2,
  46. padding=1, output_padding=1,
  47. bias=True),
  48. nn.InstanceNorm2d(int(ngf * mult / 2)),
  49. nn.ReLU(True)]
  50. model += [nn.ReflectionPad2d(3)]
  51. model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
  52. model += [nn.Tanh()]
  53. self.model = nn.Sequential(*model)
  54. def forward(self, input):
  55. # here we move to 0-1 input and 0-1 output
  56. # usually one would think about writing this differently
  57. # for efficiency (e.g. absorbing the 255 into the first conv
  58. return self.model(input * 255) / 2 + 0.5
  59. def get_pretrained_model(model_path, map_location=None):
  60. netG = ResNetGenerator()
  61. model_data = torch.load(model_path, map_location=map_location)
  62. netG.load_state_dict(model_data)
  63. netG.eval()
  64. for p in netG.parameters():
  65. netG.requires_grad_(False)
  66. return netG
  67. if __name__ == '__main__':
  68. import sys
  69. if len(sys.argv) < 3:
  70. print("Call as {} zebra_weights.pt traced_zebra_model.pt".format(sys.argv[0]))
  71. sys.exit(1)
  72. model = get_pretrained_model(sys.argv[1], map_location='cpu')
  73. traced_model = torch.jit.trace(model, torch.randn(1, 3, 227, 227))
  74. traced_model.save(sys.argv[2])
  75. # img = Image.open("../data/p1ch2/horse.jpg")
  76. # out_img.save('../data/p1ch2/zebra.jpg')