// cyclegan_cpp_api.cpp — CycleGAN horse-to-zebra inference via the LibTorch C++ API
  1. // tag::header[]
  2. #include <torch/torch.h> // <1>
  3. #define cimg_use_jpeg
  4. #include <CImg.h>
  5. using torch::Tensor; // <2>
  6. // end::header[]
  7. // at the time of writing this code (shortly after PyTorch 1.3),
  8. // the C++ api wasn't complete and (in the case of ReLU) bug-free,
  9. // so we define some Modules ad-hoc here.
  10. // Chances are, that you can take standard models if and when
  11. // they are done.
  12. struct ConvTranspose2d : torch::nn::Module {
  13. // we don't do any of the running stats business
  14. std::vector<int64_t> stride_;
  15. std::vector<int64_t> padding_;
  16. std::vector<int64_t> output_padding_;
  17. std::vector<int64_t> dilation_;
  18. Tensor weight;
  19. Tensor bias;
  20. ConvTranspose2d(int64_t in_channels, int64_t out_channels,
  21. int64_t kernel_size, int64_t stride, int64_t padding,
  22. int64_t output_padding)
  23. : stride_(2, stride), padding_(2, padding),
  24. output_padding_(2, output_padding), dilation_(2, 1) {
  25. // not good init...
  26. weight = register_parameter(
  27. "weight",
  28. torch::randn({out_channels, in_channels, kernel_size, kernel_size}));
  29. bias = register_parameter("bias", torch::randn({out_channels}));
  30. }
  31. Tensor forward(const Tensor &inp) {
  32. return conv_transpose2d(inp, weight, bias, stride_, padding_,
  33. output_padding_, /*groups=*/1, dilation_);
  34. }
  35. };
  36. // tag::block[]
  37. struct ResNetBlock : torch::nn::Module {
  38. torch::nn::Sequential conv_block;
  39. ResNetBlock(int64_t dim)
  40. : conv_block( // <1>
  41. torch::nn::ReflectionPad2d(1),
  42. torch::nn::Conv2d(torch::nn::Conv2dOptions(dim, dim, 3)),
  43. torch::nn::InstanceNorm2d(
  44. torch::nn::InstanceNorm2dOptions(dim)),
  45. torch::nn::ReLU(/*inplace=*/true),
  46. torch::nn::ReflectionPad2d(1),
  47. torch::nn::Conv2d(torch::nn::Conv2dOptions(dim, dim, 3)),
  48. torch::nn::InstanceNorm2d(
  49. torch::nn::InstanceNorm2dOptions(dim))) {
  50. register_module("conv_block", conv_block); // <2>
  51. }
  52. Tensor forward(const Tensor &inp) {
  53. return inp + conv_block->forward(inp); // <3>
  54. }
  55. };
  56. // end::block[]
  57. // tag::generator1[]
  58. struct ResNetGeneratorImpl : torch::nn::Module {
  59. torch::nn::Sequential model;
  60. ResNetGeneratorImpl(int64_t input_nc = 3, int64_t output_nc = 3,
  61. int64_t ngf = 64, int64_t n_blocks = 9) {
  62. TORCH_CHECK(n_blocks >= 0);
  63. model->push_back(torch::nn::ReflectionPad2d(3)); // <1>
  64. // end::generator1[]
  65. model->push_back(
  66. torch::nn::Conv2d(torch::nn::Conv2dOptions(input_nc, ngf, 7)));
  67. model->push_back(
  68. torch::nn::InstanceNorm2d(torch::nn::InstanceNorm2dOptions(7)));
  69. model->push_back(torch::nn::ReLU(/*inplace=*/true));
  70. constexpr int64_t n_downsampling = 2;
  71. for (int64_t i = 0; i < n_downsampling; i++) {
  72. int64_t mult = 1 << i;
  73. // tag::generator2[]
  74. model->push_back(torch::nn::Conv2d(
  75. torch::nn::Conv2dOptions(ngf * mult, ngf * mult * 2, 3)
  76. .stride(2)
  77. .padding(1))); // <3>
  78. // end::generator2[]
  79. model->push_back(torch::nn::InstanceNorm2d(
  80. torch::nn::InstanceNorm2dOptions(ngf * mult * 2)));
  81. model->push_back(torch::nn::ReLU(/*inplace=*/true));
  82. }
  83. int64_t mult = 1 << n_downsampling;
  84. for (int64_t i = 0; i < n_blocks; i++) {
  85. model->push_back(ResNetBlock(ngf * mult));
  86. }
  87. for (int64_t i = 0; i < n_downsampling; i++) {
  88. int64_t mult = 1 << (n_downsampling - i);
  89. model->push_back(
  90. ConvTranspose2d(ngf * mult, ngf * mult / 2, /*kernel_size=*/3,
  91. /*stride=*/2, /*padding=*/1, /*output_padding=*/1));
  92. model->push_back(torch::nn::InstanceNorm2d(
  93. torch::nn::InstanceNorm2dOptions((ngf * mult / 2))));
  94. model->push_back(torch::nn::ReLU(/*inplace=*/true));
  95. }
  96. model->push_back(torch::nn::ReflectionPad2d(3));
  97. model->push_back(
  98. torch::nn::Conv2d(torch::nn::Conv2dOptions(ngf, output_nc, 7)));
  99. model->push_back(torch::nn::Tanh());
  100. // tag::generator3[]
  101. register_module("model", model);
  102. }
  103. Tensor forward(const Tensor &inp) { return model->forward(inp); }
  104. };
  105. TORCH_MODULE(ResNetGenerator); // <4>
  106. // end::generator3[]
  107. int main(int argc, char **argv) {
  108. // tag::main1[]
  109. ResNetGenerator model; // <1>
  110. // end::main1[]
  111. if (argc != 3) {
  112. std::cerr << "call as " << argv[0] << " model_weights.pt image.jpg"
  113. << std::endl;
  114. return 1;
  115. }
  116. // tag::main2[]
  117. torch::load(model, argv[1]); // <2>
  118. // end::main2[]
  119. // you can print the model structure just like you would in PyTorch
  120. // std::cout << model << std::endl;
  121. // tag::main3[]
  122. cimg_library::CImg<float> image(argv[2]);
  123. image.resize(400, 400);
  124. auto input_ =
  125. torch::tensor(torch::ArrayRef<float>(image.data(), image.size()));
  126. auto input = input_.reshape({1, 3, image.height(), image.width()});
  127. torch::NoGradGuard no_grad; // <3>
  128. model->eval(); // <4>
  129. auto output = model->forward(input); // <5>
  130. // end::main3[]
  131. // tag::main4[]
  132. cimg_library::CImg<float> out_img(output.data_ptr<float>(),
  133. output.size(3), output.size(2),
  134. 1, output.size(1));
  135. cimg_library::CImgDisplay disp(out_img, "See a C++ API zebra!"); // <6>
  136. while (!disp.is_closed()) {
  137. disp.wait();
  138. }
  139. // end::main4[]
  140. return 0;
  141. }