cyclegan_jit.cpp 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. // tag::part1[]
  2. #include "torch/script.h" // <1>
  3. #define cimg_use_jpeg
  4. #include "CImg.h"
  5. using namespace cimg_library;
  6. int main(int argc, char **argv) {
  7. // end::part1[]
  8. if (argc != 4) {
  9. std::cerr << "Call as " << argv[0] << " model.pt input.jpg output.jpg"
  10. << std::endl;
  11. return 1;
  12. }
  13. // tag::part2[]
  14. CImg<float> image(argv[2]); // <2>
  15. image = image.resize(227, 227); // <3>
  16. // end::part2[]
  17. // tag::part3[]
  18. auto input_ = torch::tensor(
  19. torch::ArrayRef<float>(image.data(), image.size())); // <1>
  20. auto input = input_.reshape({1, 3, image.height(),
  21. image.width()}).div_(255); // <2>
  22. auto module = torch::jit::load(argv[1]); // <3>
  23. std::vector<torch::jit::IValue> inputs; // <4>
  24. inputs.push_back(input);
  25. auto output_ = module.forward(inputs).toTensor(); // <5>
  26. auto output = output_.contiguous().mul_(255); // <6>
  27. // end::part3[]
  28. // tag::part4[]
  29. CImg<float> out_img(output.data_ptr<float>(), output.size(2), // <4>
  30. output.size(3), 1, output.size(1));
  31. out_img.save(argv[3]); // <5>
  32. return 0;
  33. }
  34. // end::part4[]