{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ResNet generator inference\n",
    "\n",
    "This notebook defines a `ResNetGenerator`, loads its weights from `horse2zebra_0.4.0.pth`, runs it on `horse.jpg` and displays the generated image.\n",
    "\n",
    "Both files are expected in the notebook's working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from PIL import Image\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResNetBlock(nn.Module):\n",
    "    \"\"\"Residual block: two reflection-padded 3x3 convolutions plus a skip connection.\n",
    "\n",
    "    Args:\n",
    "        dim: number of input and output channels of both convolutions.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.conv_block = self.build_conv_block(dim)\n",
    "\n",
    "    def build_conv_block(self, dim):\n",
    "        \"\"\"Return the convolutional branch as an ``nn.Sequential``.\n",
    "\n",
    "        ``nn.Sequential`` names its children by position, so the layer order\n",
    "        below determines the state-dict keys and must stay as it is while\n",
    "        previously saved weights are being loaded.\n",
    "        \"\"\"\n",
    "        conv_block = [\n",
    "            nn.ReflectionPad2d(1),\n",
    "            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),\n",
    "            nn.InstanceNorm2d(dim),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.ReflectionPad2d(1),\n",
    "            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),\n",
    "            nn.InstanceNorm2d(dim),\n",
    "        ]\n",
    "        return nn.Sequential(*conv_block)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``x`` plus the output of the convolutional branch.\"\"\"\n",
    "        return x + self.conv_block(x)\n",
    "\n",
    "\n",
    "class ResNetGenerator(nn.Module):\n",
    "    \"\"\"Image-to-image network built from residual blocks.\n",
    "\n",
    "    Layer stack: a 7x7 convolution, two stride-2 convolutions that double\n",
    "    the channel count each time, ``n_blocks`` residual blocks, two stride-2\n",
    "    transposed convolutions that halve the channel count each time, and a\n",
    "    final 7x7 convolution followed by ``tanh``.\n",
    "\n",
    "    Args:\n",
    "        input_nc: number of channels of the input image.\n",
    "        output_nc: number of channels of the generated image.\n",
    "        ngf: number of output channels of the first convolution.\n",
    "        n_blocks: number of ``ResNetBlock`` modules; must be >= 0.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):\n",
    "        assert n_blocks >= 0\n",
    "        super().__init__()\n",
    "\n",
    "        self.input_nc = input_nc\n",
    "        self.output_nc = output_nc\n",
    "        self.ngf = ngf\n",
    "\n",
    "        model = [nn.ReflectionPad2d(3),\n",
    "                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),\n",
    "                 nn.InstanceNorm2d(ngf),\n",
    "                 nn.ReLU(inplace=True)]\n",
    "\n",
    "        n_downsampling = 2\n",
    "        for i in range(n_downsampling):\n",
    "            mult = 2 ** i\n",
    "            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,\n",
    "                                stride=2, padding=1, bias=True),\n",
    "                      nn.InstanceNorm2d(ngf * mult * 2),\n",
    "                      nn.ReLU(inplace=True)]\n",
    "\n",
    "        mult = 2 ** n_downsampling\n",
    "        for _ in range(n_blocks):\n",
    "            model += [ResNetBlock(ngf * mult)]\n",
    "\n",
    "        for i in range(n_downsampling):\n",
    "            mult = 2 ** (n_downsampling - i)\n",
    "            # mult is at least 2 here, so the floor division is exact.\n",
    "            out_channels = ngf * mult // 2\n",
    "            model += [nn.ConvTranspose2d(ngf * mult, out_channels,\n",
    "                                         kernel_size=3, stride=2,\n",
    "                                         padding=1, output_padding=1,\n",
    "                                         bias=True),\n",
    "                      nn.InstanceNorm2d(out_channels),\n",
    "                      nn.ReLU(inplace=True)]\n",
    "\n",
    "        model += [nn.ReflectionPad2d(3)]\n",
    "        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]\n",
    "        model += [nn.Tanh()]\n",
    "\n",
    "        self.model = nn.Sequential(*model)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Run ``input`` through the full layer stack and return the result.\"\"\"\n",
    "        return self.model(input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "netG = ResNetGenerator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = 'horse2zebra_0.4.0.pth'\n",
    "# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.\n",
    "# map_location='cpu' lets a checkpoint that was saved on a GPU load on a CPU-only machine.\n",
    "model_data = torch.load(model_path, map_location='cpu')\n",
    "netG.load_state_dict(model_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "netG.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the input image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocess = transforms.Compose([transforms.Resize(256),\n",
    "                                 transforms.ToTensor()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The generator was built with input_nc=3, so force three channels in case\n",
    "# the file is grayscale, CMYK or carries an alpha channel.\n",
    "img = Image.open(\"horse.jpg\").convert(\"RGB\")\n",
    "img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_t = preprocess(img)\n",
    "# Add a leading batch dimension of size 1.\n",
    "batch_t = torch.unsqueeze(img_t, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference only: no gradients are needed, so skip autograd bookkeeping.\n",
    "with torch.no_grad():\n",
    "    batch_out = netG(batch_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The network ends in tanh, so values lie in [-1, 1]; rescale them to [0, 1].\n",
    "# squeeze(0) removes only the batch dimension added before the forward pass.\n",
    "out_t = (batch_out.detach().squeeze(0) + 1.0) / 2.0\n",
    "out_img = transforms.ToPILImage()(out_t)\n",
    "# out_img.save('zebra.jpg')\n",
    "out_img"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}