| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
# This code is taken from the p1ch2/3_cyclegan notebook.
- import torch
- import torch.nn as nn
class ResNetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus a skip connection.

    Both the channel count (``dim``) and the spatial size are preserved, so the
    branch output can be added element-wise to the block input.
    """

    def __init__(self, dim):
        super().__init__()
        # The attribute name ``conv_block`` and the layer order built below
        # determine the state_dict keys (conv_block.1.*, conv_block.5.*), which
        # pretrained weights are loaded against -- keep both stable.
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        """Return the residual branch as an ``nn.Sequential``.

        Layout: pad -> conv -> norm -> ReLU -> pad -> conv -> norm.  There is no
        activation after the second norm; ``forward`` adds the branch output to
        the input as-is.
        """
        def padded_conv():
            # 1-pixel reflection padding followed by an unpadded 3x3 conv keeps H and W.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                nn.InstanceNorm2d(dim),
            ]

        layers = padded_conv() + [nn.ReLU(True)] + padded_conv()
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        residual = self.conv_block(x)
        return x + residual
class ResNetGenerator(nn.Module):
    """CycleGAN-style ResNet generator taking [0, 1] inputs to [0, 1] outputs.

    Layout: a 7x7 stem conv, ``n_downsampling`` stride-2 convs (each doubling
    the channel count), ``n_blocks`` residual blocks (``ResNetBlock``), the
    mirror-image stride-2 transposed convs, then a 7x7 output conv and tanh.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: channel count after the stem conv; the residual bottleneck has
            ``ngf * 2**n_downsampling`` channels.
        n_blocks: number of residual blocks in the bottleneck (must be >= 0).
        n_downsampling: number of stride-2 down/up-sampling stages (must be
            >= 0).  The default of 2 reproduces the original layer layout, so
            existing state_dicts keep loading.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9,
                 n_downsampling=2):
        assert n_blocks >= 0
        assert n_downsampling >= 0
        super().__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        # Stem: reflection-pad by 3 so the unpadded 7x7 conv keeps H and W.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        # Downsampling: each stage halves H and W (rounding up) and doubles channels.
        for i in range(n_downsampling):
            in_ch = ngf * 2**i
            out_ch = in_ch * 2
            model += [nn.Conv2d(in_ch, out_ch, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(out_ch),
                      nn.ReLU(True)]

        # Bottleneck: residual blocks at the deepest channel count.
        bottleneck_ch = ngf * 2**n_downsampling
        model += [ResNetBlock(bottleneck_ch) for _ in range(n_blocks)]

        # Upsampling mirrors the downsampling path; output_padding=1 makes each
        # transposed conv exactly double H and W.
        for i in range(n_downsampling):
            in_ch = ngf * 2**(n_downsampling - i)
            out_ch = in_ch // 2  # in_ch is always even here
            model += [nn.ConvTranspose2d(in_ch, out_ch,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(out_ch),
                      nn.ReLU(True)]

        # Output head: 7x7 conv back to output_nc channels, squashed to [-1, 1].
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                  nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Map an image batch with values in [0, 1] to an output in [0, 1].

        NOTE(review): ``input`` presumably has shape (N, input_nc, H, W) with H
        and W divisible by ``2**n_downsampling`` so the output size matches;
        confirm with callers.
        """
        # Here we move to 0-1 input and 0-1 output.  The tanh head yields
        # [-1, 1], so ``/ 2 + 0.5`` rescales it to [0, 1].  The ``* 255`` input
        # scaling is largely cancelled by the InstanceNorm2d that follows the
        # first (biased) conv, since that norm normalizes each channel of each
        # sample (up to its eps).  For efficiency one would usually absorb the
        # 255 into the first conv's weights instead of multiplying at runtime.
        return self.model(input * 255) / 2 + 0.5
def get_pretrained_model(model_path, map_location=None):
    """Build a default ``ResNetGenerator``, load saved weights into it, and freeze it.

    Args:
        model_path: passed to ``torch.load``; must hold a state_dict compatible
            with ``ResNetGenerator()`` (``load_state_dict`` is strict by default).
        map_location: forwarded to ``torch.load`` (e.g. ``'cpu'``).

    Returns:
        The generator in eval mode with gradients disabled on all parameters.
    """
    netG = ResNetGenerator()
    # NOTE(review): torch.load unpickles the file, so only load weights from a
    # trusted source.
    model_data = torch.load(model_path, map_location=map_location)
    netG.load_state_dict(model_data)
    netG.eval()
    # Module.requires_grad_ already applies to every parameter of the module,
    # so a single call freezes the whole network.
    netG.requires_grad_(False)
    return netG
if __name__ == '__main__':
    # Command-line entry point: load pretrained generator weights and save a
    # TorchScript version produced with torch.jit.trace.
    # Usage: <script> zebra_weights.pt traced_zebra_model.pt
    import sys

    if len(sys.argv) < 3:
        # Usage errors go to stderr so they don't mix with regular output.
        print("Call as {} zebra_weights.pt traced_zebra_model.pt".format(sys.argv[0]),
              file=sys.stderr)
        sys.exit(1)
    model = get_pretrained_model(sys.argv[1], map_location='cpu')
    # A random 1x3x227x227 batch serves only as the example input for tracing.
    example_input = torch.randn(1, 3, 227, 227)
    traced_model = torch.jit.trace(model, example_input)
    traced_model.save(sys.argv[2])
|