
p3ch15: add missing .cpp, CMake, and Android files; update p2ch14 code

Eli Stevens committed 5 years ago
commit 0635650b7f

+ 12 - 4
p2ch11/model.py

@@ -36,13 +36,17 @@ class LunaModel(nn.Module):
                 nn.ConvTranspose2d,
                 nn.ConvTranspose3d,
             }:
-                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
+                nn.init.kaiming_normal_(
+                    m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
+                )
                 if m.bias is not None:
-                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    fan_in, fan_out = \
+                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                     bound = 1 / math.sqrt(fan_out)
                     nn.init.normal_(m.bias, -bound, bound)
 
 
+
     def forward(self, input_batch):
         bn_output = self.tail_batchnorm(input_batch)
 
@@ -64,9 +68,13 @@ class LunaBlock(nn.Module):
     def __init__(self, in_channels, conv_channels):
         super().__init__()
 
-        self.conv1 = nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.conv1 = nn.Conv3d(
+            in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
+        )
         self.relu1 = nn.ReLU(inplace=True)
-        self.conv2 = nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.conv2 = nn.Conv3d(
+            conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
+        )
         self.relu2 = nn.ReLU(inplace=True)
 
         self.maxpool = nn.MaxPool3d(2, 2)

+ 3 - 1
p2ch14/check_nodule_fp_rate.py

@@ -413,13 +413,15 @@ class FalsePosRateCheckApp:
         clean_a = clean_g.cpu().numpy()
         candidateLabel_a, candidate_count = measure.label(clean_a)
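+        # clip HU to [-1000, 1000] and shift by +1001 so every center_of_mass weight is
+        # strictly positive; unclipped padding values far below -1000 can push a component's
+        # total weight to zero or negative and produce non-finite centers (see assert below)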
         centerIrc_list = measure.center_of_mass(
-            ct.hu_a + 1001,
+            ct.hu_a.clip(-1000, 1000) + 1001,
             labels=candidateLabel_a,
             index=list(range(1, candidate_count+1)),
         )
 
+
         candidateInfo_list = []
         for i, center_irc in enumerate(centerIrc_list):
+            assert np.isfinite(center_irc).all(), repr([series_uid, i, candidate_count, (ct.hu_a[candidateLabel_a == i+1]).sum(), center_irc])
             center_xyz = irc2xyz(
                 center_irc,
                 ct.origin_xyz,

+ 5 - 5
p2ch14/dsets.py

@@ -381,12 +381,12 @@ class MalignantLunaDataset(LunaDataset):
 
     def __getitem__(self, ndx):
         if self.ratio_int:
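+            # 2:1:1 sampling: odd indices draw malignant, ndx % 4 == 0 benign,
+            # and the remainder (ndx % 4 == 2) negative candidates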
-            if ndx % 4 < 2:
-                candidateInfo_tup = self.mal_list[(ndx // 3) % len(self.mal_list)]
-            elif ndx % 4 == 2:
-                candidateInfo_tup = self.ben_list[(ndx // 3) % len(self.ben_list)]
+            if ndx % 2 != 0:
+                candidateInfo_tup = self.mal_list[(ndx // 2) % len(self.mal_list)]
+            elif ndx % 4 == 0:
+                candidateInfo_tup = self.ben_list[(ndx // 4) % len(self.ben_list)]
             else:
-                candidateInfo_tup = self.neg_list[(ndx // 3) % len(self.neg_list)]
+                candidateInfo_tup = self.neg_list[(ndx // 4) % len(self.neg_list)]
         else:
             if ndx >= len(self.ben_list):
                 candidateInfo_tup = self.mal_list[ndx - len(self.ben_list)]

+ 1 - 1
p2ch14/model.py

@@ -20,7 +20,7 @@ def augment3d(inp):
             if random.random() > 0.5:
                 transform_t[i,i] *= -1
         if True: #'offset' in augmentation_dict:
-            offset_float = 0.1 # 8 # augmentation_dict['offset']
+            offset_float = 0.1
             random_float = (random.random() * 2 - 1)
             transform_t[3,i] = offset_float * random_float
     if True:

+ 31 - 34
p2ch14/nodule_analysis.py

@@ -31,16 +31,15 @@ logging.getLogger("p2ch13.dsets").setLevel(logging.WARNING)
 logging.getLogger("p2ch14.dsets").setLevel(logging.WARNING)
 
 def print_confusion(label, confusions, do_mal):
+    row_labels = ['Non-Nodules', 'Benign', 'Malignant']
+
     if do_mal:
-        col_labels = ['', 'Complete Miss', 'Filtered', 'Benign', 'Malignant']
-        row_labels = ['Non-Nodules', 'Benign', 'Malignant']
+        col_labels = ['', 'Complete Miss', 'Filtered Out', 'Pred. Benign', 'Pred. Malignant']
     else:
-        col_labels = ['', 'Complete Miss', 'Filtered', 'Detected']
-        row_labels = ['Non-Nodules', 'Nodules']
-        confusions[-2] += confusions[-1]
+        col_labels = ['', 'Complete Miss', 'Filtered Out', 'Pred. Nodule']
         confusions[:, -2] += confusions[:, -1]
-        confusions = confusions[:-1, :-1]
-    cell_width = 14
+        confusions = confusions[:, :-1]
+    cell_width = 16
     f = '{:>' + str(cell_width) + '}'
     print(label)
     print(' | '.join([f.format(s) for s in col_labels]))
@@ -72,7 +71,7 @@ def match_and_score(detections, truth, threshold=0.5, threshold_mal=0.5):
     confusion = np.zeros((3, 4), dtype=np.int)
     if len(detected_xyz) == 0:
         for tn in true_nodules:
-            confusiion[2 if tn.isMal_bool else 1, 0] += 1
+            confusion[2 if tn.isMal_bool else 1, 0] += 1
     elif len(truth_xyz) == 0:
         for dc in detected_classes:
             confusion[0, dc] += 1
@@ -124,7 +123,7 @@ class NoduleAnalysisApp:
         parser.add_argument('--segmentation-path',
             help="Path to the saved segmentation model",
             nargs='?',
-            default=None,
+            default='data/part2/models/seg_2020-01-26_19.45.12_w4d3c1-bal_1_nodupe-label_pos-d1_fn8-adam.best.state',
         )
 
         parser.add_argument('--cls-model',
@@ -135,13 +134,14 @@ class NoduleAnalysisApp:
         parser.add_argument('--classification-path',
             help="Path to the saved classification model",
             nargs='?',
-            default=None,
+            default='data/part2/models/cls_2020-02-06_14.16.55_final-nodule-nonnodule.best.state',
         )
 
         parser.add_argument('--malignancy-model',
             help="What to model class name to use for the malignancy classifier.",
             action='store',
-            default='ModifiedLunaModel',
+            default='LunaModel',
+            # default='ModifiedLunaModel',
         )
         parser.add_argument('--malignancy-path',
             help="Path to the saved malignancy classification model",
@@ -303,7 +303,6 @@ class NoduleAnalysisApp:
         val_list = sorted(series_set & val_set)
 
 
-        candidateInfo_list = []
         candidateInfo_dict = getCandidateInfoDict()
         series_iter = enumerateWithEstimate(
             val_list + train_list,
@@ -314,10 +313,8 @@ class NoduleAnalysisApp:
             ct = getCt(series_uid)
             mask_a = self.segmentCt(ct, series_uid)
 
-            candidateInfo_list = self.clusterSegmentationOutput(
-                series_uid,
-                ct,
-                mask_a,
+            candidateInfo_list = self.groupSegmentationOutput(
+                series_uid, ct, mask_a
             )
             classifications_list = self.classifyCandidates(ct, candidateInfo_list)
 
@@ -339,7 +336,6 @@ class NoduleAnalysisApp:
         print_confusion("Total", all_confusion, self.malignancy_model is not None)
 
 
-
     def classifyCandidates(self, ct, candidateInfo_list):
         cls_dl = self.initClassificationDl(candidateInfo_list)
         classifications_list = []
@@ -348,22 +344,26 @@ class NoduleAnalysisApp:
 
             input_g = input_t.to(self.device)
             with torch.no_grad():
-                _, probability_g = self.cls_model(input_g)
+                _, probability_nodule_g = self.cls_model(input_g)
                 if self.malignancy_model is not None:
                     _, probability_mal_g = self.malignancy_model(input_g)
                 else:
-                    probability_mal_g = torch.zeros_like(probability_g)
+                    probability_mal_g = torch.zeros_like(probability_nodule_g)
 
-            for center_irc, prob, prob_mal in zip(center_list,
-                                                  probability_g[:,1].tolist(),
-                                                  probability_mal_g[:,1].tolist()
-                                                  ):
+            zip_iter = zip(
+                center_list,
+                probability_nodule_g[:,1].tolist(),
+                probability_mal_g[:,1].tolist(),
+            )
+            for center_irc, prob_nodule, prob_mal in zip_iter:
                 center_xyz = irc2xyz(
                     center_irc,
                     direction_a=ct.direction_a,
                     origin_xyz=ct.origin_xyz,
-                    vxSize_xyz=ct.vxSize_xyz)
-                classifications_list.append((prob, prob_mal, center_xyz, center_irc))
+                    vxSize_xyz=ct.vxSize_xyz,
+                )
+                cls_tup = (prob_nodule, prob_mal, center_xyz, center_irc)
+                classifications_list.append(cls_tup)
         return classifications_list
 
     def segmentCt(self, ct, series_uid):
@@ -371,26 +371,23 @@ class NoduleAnalysisApp:
             output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
             seg_dl = self.initSegmentationDl(series_uid)
             for batch_tup in seg_dl:
-                input_t = batch_tup[0]
-                ndx_list = batch_tup[4]
+                input_t, label_t, series_list, slice_ndx_list = batch_tup
 
                 input_g = input_t.to(self.device)
                 prediction_g = self.seg_model(input_g)
 
-                for i, sample_ndx in enumerate(ndx_list):
-                    output_a[sample_ndx] = prediction_g[i].cpu().numpy()
+                for i, slice_ndx in enumerate(slice_ndx_list):
+                    output_a[slice_ndx] = prediction_g[i].cpu().numpy()
 
-            # mask_a = output_a > 0.25
             mask_a = output_a > 0.5
-            # mask_a = morphology.binary_erosion(mask_a, iterations=1)
-            # mask_a = morphology.binary_dilation(mask_a, iterations=2)
+            mask_a = morphology.binary_erosion(mask_a, iterations=1)
 
         return mask_a
 
-    def clusterSegmentationOutput(self, series_uid,  ct, clean_a):
+    def groupSegmentationOutput(self, series_uid,  ct, clean_a):
         candidateLabel_a, candidate_count = measurements.label(clean_a)
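+        # clipped, shifted HU keeps every center_of_mass weight positive
+        # (same fix as in check_nodule_fp_rate.py)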
         centerIrc_list = measurements.center_of_mass(
-            ct.hu_a + 1001,
+            ct.hu_a.clip(-1000, 1000) + 1001,
             labels=candidateLabel_a,
             index=np.arange(1, candidate_count+1),
         )

+ 32 - 6
p2ch14/training.py

@@ -6,6 +6,7 @@ import shutil
 import sys
 
 import numpy as np
+from matplotlib import pyplot
 
 from torch.utils.tensorboard import SummaryWriter
 
@@ -121,11 +122,19 @@ class ClassificationTrainingApp:
 
         if self.cli_args.finetune:
             d = torch.load(self.cli_args.finetune, map_location='cpu')
-            model_blocks = [n for n, subm in model.named_children()
-                            if len(list(subm.parameters())) > 0]
+            model_blocks = [
+                n for n, subm in model.named_children()
+                if len(list(subm.parameters())) > 0
+            ]
             finetune_blocks = model_blocks[-self.cli_args.finetune_depth:]
             log.info(f"finetuning from {self.cli_args.finetune}, blocks {' '.join(finetune_blocks)}")
-            model.load_state_dict(d['model_state'])
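+            # load pretrained weights for every block except the last one, which keeps
+            # its fresh initialization and gets retrained during finetuning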
+            model.load_state_dict(
+                {
+                    k: v for k,v in d['model_state'].items()
+                    if k.split('.')[0] not in model_blocks[-1]
+                },
+                strict=False,
+            )
             for n, p in model.named_parameters():
                 if n.split('.')[0] not in finetune_blocks:
                     p.requires_grad_(False)
@@ -138,7 +147,7 @@ class ClassificationTrainingApp:
 
     def initOptimizer(self):
         lr = 0.003 if self.cli_args.finetune else 0.001
-        return SGD(self.model.parameters(), weight_decay=1e-4, lr=lr)
+        return SGD(self.model.parameters(), lr=lr, weight_decay=1e-4)
         #return Adam(self.model.parameters(), lr=3e-4)
 
     def initTrainDl(self):
@@ -398,12 +407,21 @@ class ClassificationTrainingApp:
         metrics_dict['pr/f1_score'] = \
             2 * (precision * recall) / (precision + recall)
 
+        threshold = torch.linspace(1, 0)
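+        # TPR and FPR over a sweep of thresholds from 1 down to 0; the trapezoidal
+        # sum over the resulting ROC curve approximates the AUC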
+        tpr = (metrics_t[None, METRICS_PRED_P_NDX, posLabel_mask] >= threshold[:, None]).sum(1).float() / pos_count
+        fpr = (metrics_t[None, METRICS_PRED_P_NDX, negLabel_mask] >= threshold[:, None]).sum(1).float() / neg_count
+        fp_diff = fpr[1:]-fpr[:-1]
+        tp_avg  = (tpr[1:]+tpr[:-1])/2
+        auc = (fp_diff * tp_avg).sum()
+        metrics_dict['auc'] = auc
+
         log.info(
             ("E{} {:8} {loss/all:.4f} loss, "
                  + "{correct/all:-5.1f}% correct, "
                  + "{pr/precision:.4f} precision, "
                  + "{pr/recall:.4f} recall, "
-                 + "{pr/f1_score:.4f} f1 score"
+                 + "{pr/f1_score:.4f} f1 score, "
+                 + "{auc:.4f} auc"
             ).format(
                 epoch_ndx,
                 mode_str,
@@ -461,6 +479,11 @@ class ClassificationTrainingApp:
             key = key.replace('neg', neg)
             writer.add_scalar(key, value, self.totalTrainingSamples_count)
 
+        fig = pyplot.figure()
+        pyplot.plot(fpr, tpr)
+        writer.add_figure('roc', fig, self.totalTrainingSamples_count)
+
+        writer.add_scalar('auc', auc, self.totalTrainingSamples_count)
 # # tag::logMetrics_writer_prcurve[]
 #        writer.add_pr_curve(
 #            'pr',
@@ -485,7 +508,10 @@ class ClassificationTrainingApp:
             bins=bins
         )
 
-        score = metrics_dict['pr/f1_score']
+        if not self.cli_args.malignant:
+            score = metrics_dict['pr/f1_score']
+        else:
+            score = metrics_dict['auc']
 
         return score
 

+ 27 - 0
p3ch15/CMakeLists.txt

@@ -0,0 +1,27 @@
+cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
+project(cyclegan-jit)
+
+find_package(Torch REQUIRED)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
+
+add_executable(cyclegan-jit cyclegan_jit.cpp)
+target_link_libraries(cyclegan-jit pthread jpeg X11)
+target_link_libraries(cyclegan-jit "${TORCH_LIBRARIES}")
+set_property(TARGET cyclegan-jit PROPERTY CXX_STANDARD 14)
+
+add_executable(cyclegan-cpp-api cyclegan_cpp_api.cpp)
+target_link_libraries(cyclegan-cpp-api pthread jpeg X11)
+target_link_libraries(cyclegan-cpp-api "${TORCH_LIBRARIES}")
+set_property(TARGET cyclegan-cpp-api PROPERTY CXX_STANDARD 14)
+
+# The following code block is suggested to be used on Windows.
+# According to https://github.com/pytorch/pytorch/issues/25457,
+# the DLLs need to be copied to avoid memory errors.
+if (MSVC)
+  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
+  add_custom_command(TARGET cyclegan-jit
+                     POST_BUILD
+                     COMMAND ${CMAKE_COMMAND} -E copy_if_different
+                     ${TORCH_DLLS}
+                     $<TARGET_FILE_DIR:cyclegan-jit>)
+endif (MSVC)

+ 122 - 0
p3ch15/android/MainActivity.java

@@ -0,0 +1,122 @@
+package de.lernapparat.zebraify;
+
+import android.content.Context;
+import android.content.Intent;
+import android.graphics.Bitmap;
+import android.provider.MediaStore;
+import android.support.v7.app.AppCompatActivity;
+import android.os.Bundle;
+import android.util.Log;
+import android.view.View;
+import android.widget.ImageView;
+import android.widget.TextView;
+
+import org.pytorch.IValue;
+import org.pytorch.Module;
+import org.pytorch.Tensor;
+import org.pytorch.torchvision.TensorImageUtils;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
+public class MainActivity extends AppCompatActivity {
+    static final int REQUEST_IMAGE_CAPTURE = 1;
+    private org.pytorch.Module model;
+
+    @Override
+    protected void onCreate(Bundle savedInstanceState) {
+        super.onCreate(savedInstanceState);
+
+        setContentView(R.layout.activity_main);
+
+        TextView tv = (TextView) findViewById(R.id.headline);
+        tv.setOnClickListener(new View.OnClickListener() {
+            public void onClick(View v) {
+                Intent takePictureIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
+                // takePictureIntent.putExtra(android.provider.MediaStore.EXTRA_OUTPUT, android.provider.MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
+                if (takePictureIntent.resolveActivity(getPackageManager()) != null) {
+                    startActivityForResult(takePictureIntent, REQUEST_IMAGE_CAPTURE);
+                }
+            }
+        });
+
+
+        try {
+            model = Module.load(assetFilePath(this, "traced_zebra_model.pt"));
+        } catch (IOException e) {
+            Log.e("Zebraify", "Error reading assets", e);
+            finish();
+        }
+
+    }
+
+    @Override
+    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
+        if (requestCode == REQUEST_IMAGE_CAPTURE && resultCode == RESULT_OK) {
+            // this gets called when the camera app has returned a picture
+            Bitmap bitmap = (Bitmap) data.getExtras().get("data");
+
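+            // zero means and unit stds leave the tensor in the [0, 1] range
+            // that bitmapToFloat32Tensor produces (no extra normalization)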
+            final float[] means = {0.0f, 0.0f, 0.0f};
+            final float[] stds = {1.0f, 1.0f, 1.0f};
+            // preparing input tensor
+            final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
+                    means, stds);
+
+            // running the model
+            final Tensor outputTensor = model.forward(IValue.from(inputTensor)).toTensor();
+            Bitmap output_bitmap = tensorToBitmap(outputTensor, means, stds, Bitmap.Config.RGB_565);
+
+            ImageView image_view = (ImageView) findViewById(R.id.imageView);
+            image_view.setImageBitmap(output_bitmap);
+        }
+    }
+
+    // This is intended to be the inverse of bitmapToFloat32Tensor
+    static Bitmap tensorToBitmap(Tensor tensor, float[] normMeanRGB, float[] normStdRGB, Bitmap.Config bc) {
+        final float[] outputArray = tensor.getDataAsFloatArray();
+        final long[] shape = tensor.shape();
+        int width = (int) shape[shape.length - 1];
+        int height = (int) shape[shape.length - 2];
+        Bitmap output_bitmap = Bitmap.createBitmap(width, height, bc);
+
+        int numPixels = width * height;
+        int[] pixels = new int[numPixels];
+        for (int i = 0; i < numPixels; i++) {
+            pixels[i] = ((int) ((outputArray[0 * numPixels + i] * normStdRGB[0] + normMeanRGB[0]) * 255 + 0.49999) << 16)
+                      + ((int) ((outputArray[1 * numPixels + i] * normStdRGB[1] + normMeanRGB[1]) * 255 + 0.49999) << 8)
+                      + ((int) ((outputArray[2 * numPixels + i] * normStdRGB[2] + normMeanRGB[2]) * 255 + 0.49999));
+        }
+        output_bitmap.setPixels(pixels, 0, width, 0, 0, width, height);
+        return output_bitmap;
+    }
+
+    /**
+     * Taken from PyTorch's HelloWorld Android app.
+     *
+     * Copies the specified asset to a file in the app's files directory and returns that file's absolute path.
+     *
+     * @return absolute file path
+     */
+    public static String assetFilePath(Context context, String assetName) throws IOException {
+        File file = new File(context.getFilesDir(), assetName);
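+        // the "false &&" disables the cached-file check, so the model asset is
+        // copied fresh from the APK on every launch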
+        if (false && file.exists() && file.length() > 0) {
+            return file.getAbsolutePath();
+        }
+
+        try (InputStream is = context.getAssets().open(assetName)) {
+            try (OutputStream os = new FileOutputStream(file, false)) {
+                byte[] buffer = new byte[4 * 1024];
+                int read;
+                while ((read = is.read(buffer)) != -1) {
+                    os.write(buffer, 0, read);
+                }
+                os.flush();
+            }
+            return file.getAbsolutePath();
+        }
+    }
+
+}

+ 32 - 0
p3ch15/android/activity_main.xml

@@ -0,0 +1,32 @@
+<?xml version="1.0" encoding="utf-8"?>
+<android.support.constraint.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
+    xmlns:app="http://schemas.android.com/apk/res-auto"
+    xmlns:tools="http://schemas.android.com/tools"
+    android:layout_width="match_parent"
+    android:layout_height="match_parent"
+    tools:context=".MainActivity">
+
+    <TextView
+        android:id="@+id/headline"
+        android:layout_width="wrap_content"
+        android:layout_height="wrap_content"
+        android:text="Click to run"
+        app:layout_constraintBottom_toBottomOf="parent"
+        app:layout_constraintHorizontal_bias="0.046"
+        app:layout_constraintLeft_toLeftOf="parent"
+        app:layout_constraintRight_toRightOf="parent"
+        app:layout_constraintTop_toTopOf="parent"
+        app:layout_constraintVertical_bias="0.022" />
+
+    <ImageView
+        android:id="@+id/imageView"
+        android:layout_width="0dp"
+        android:layout_height="0dp"
+        android:layout_marginTop="10dp"
+        app:layout_constraintBottom_toBottomOf="parent"
+        app:layout_constraintEnd_toEndOf="parent"
+        app:layout_constraintStart_toStartOf="parent"
+        app:layout_constraintTop_toBottomOf="@+id/headline"
+        tools:srcCompat="@tools:sample/backgrounds/scenic" />
+
+</android.support.constraint.ConstraintLayout>

+ 31 - 0
p3ch15/android/build.gradle

@@ -0,0 +1,31 @@
+apply plugin: 'com.android.application'
+
+android {
+    compileSdkVersion 28
+    buildToolsVersion "29.0.2"
+    defaultConfig {
+        applicationId "de.lernapparat.zebraify"
+        minSdkVersion 23
+        targetSdkVersion 28
+        versionCode 1
+        versionName "1.0"
+        testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+    }
+    buildTypes {
+        release {
+            minifyEnabled false
+            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
+        }
+    }
+}
+
+dependencies {
+    implementation fileTree(dir: 'libs', include: ['*.jar'])
+    implementation 'com.android.support:appcompat-v7:28.0.0'
+    implementation 'com.android.support.constraint:constraint-layout:1.1.3'
+    testImplementation 'junit:junit:4.12'
+    androidTestImplementation 'com.android.support.test:runner:1.0.2'
+    androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
+    implementation 'org.pytorch:pytorch_android:1.3.0'
+    implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
+}

+ 147 - 0
p3ch15/cyclegan_cpp_api.cpp

@@ -0,0 +1,147 @@
+// tag::header[]
+#include <torch/torch.h>
+#define cimg_use_jpeg
+#include <CImg.h>
+using torch::Tensor;
+// end::header[]
+
+// At the time of writing this code (shortly after PyTorch 1.3),
+// the C++ API was neither complete nor (in the case of ReLU) bug-free,
+// so we define some modules ad hoc here.
+// Chances are you can use the standard modules if and when
+// they are done.
+
+struct ConvTranspose2d : torch::nn::Module {
+  // we don't do any of the running stats business
+  std::vector<int64_t> stride_;
+  std::vector<int64_t> padding_;
+  std::vector<int64_t> output_padding_;
+  std::vector<int64_t> dilation_;
+  Tensor weight;
+  Tensor bias;
+
+  ConvTranspose2d(int64_t in_channels, int64_t out_channels,
+                  int64_t kernel_size, int64_t stride, int64_t padding,
+                  int64_t output_padding)
+      : stride_(2, stride), padding_(2, padding),
+        output_padding_(2, output_padding), dilation_(2, 1) {
+    // not good init...
+    weight = register_parameter(
+        "weight",
+        torch::randn({out_channels, in_channels, kernel_size, kernel_size}));
+    bias = register_parameter("bias", torch::randn({out_channels}));
+  }
+  Tensor forward(const Tensor &inp) {
+    return conv_transpose2d(inp, weight, bias, stride_, padding_,
+                            output_padding_, /*groups=*/1, dilation_);
+  }
+};
+
+// tag::block[]
+struct ResNetBlock : torch::nn::Module {
+  torch::nn::Sequential conv_block;
+  ResNetBlock(int64_t dim)
+      : conv_block(  // <1>
+            torch::nn::ReflectionPad2d(1),
+            torch::nn::Conv2d(torch::nn::Conv2dOptions(dim, dim, 3)),
+            torch::nn::InstanceNorm2d(torch::nn::InstanceNorm2dOptions(dim)),
+            torch::nn::ReLU(/*inplace=*/true), torch::nn::ReflectionPad2d(1),
+            torch::nn::Conv2d(torch::nn::Conv2dOptions(dim, dim, 3)),
+            torch::nn::InstanceNorm2d(torch::nn::InstanceNorm2dOptions(dim))) {
+    register_module("conv_block", conv_block); // <2>
+  }
+
+  Tensor forward(const Tensor &inp) {
+    return inp + conv_block->forward(inp); // <3>
+  }
+};
+// end::block[]
+
+// tag::generator1[]
+struct ResNetGeneratorImpl : torch::nn::Module {
+  torch::nn::Sequential model;
+  ResNetGeneratorImpl(int64_t input_nc = 3, int64_t output_nc = 3,
+                      int64_t ngf = 64, int64_t n_blocks = 9) {
+    TORCH_CHECK(n_blocks >= 0);
+    model->push_back(torch::nn::ReflectionPad2d(3)); // <1>
+                                                     // end::generator1[]
+    model->push_back(
+        torch::nn::Conv2d(torch::nn::Conv2dOptions(input_nc, ngf, 7)));
+    model->push_back(
+        torch::nn::InstanceNorm2d(torch::nn::InstanceNorm2dOptions(ngf)));
+    model->push_back(torch::nn::ReLU(/*inplace=*/true));
+    constexpr int64_t n_downsampling = 2;
+
+    for (int64_t i = 0; i < n_downsampling; i++) {
+      int64_t mult = 1 << i;
+      // tag::generator2[]
+      model->push_back(torch::nn::Conv2d(
+          torch::nn::Conv2dOptions(ngf * mult, ngf * mult * 2, 3)
+              .stride(2)
+              .padding(1))); // <3>
+                             // end::generator2[]
+      model->push_back(torch::nn::InstanceNorm2d(
+          torch::nn::InstanceNorm2dOptions(ngf * mult * 2)));
+      model->push_back(torch::nn::ReLU(/*inplace=*/true));
+    }
+
+    int64_t mult = 1 << n_downsampling;
+    for (int64_t i = 0; i < n_blocks; i++) {
+      model->push_back(ResNetBlock(ngf * mult));
+    }
+    for (int64_t i = 0; i < n_downsampling; i++) {
+      int64_t mult = 1 << (n_downsampling - i);
+      model->push_back(
+          ConvTranspose2d(ngf * mult, ngf * mult / 2, /*kernel_size=*/3,
+                          /*stride=*/2, /*padding=*/1, /*output_padding=*/1));
+      model->push_back(torch::nn::InstanceNorm2d(
+          torch::nn::InstanceNorm2dOptions((ngf * mult / 2))));
+      model->push_back(torch::nn::ReLU(/*inplace=*/true));
+    }
+    model->push_back(torch::nn::ReflectionPad2d(3));
+    model->push_back(
+        torch::nn::Conv2d(torch::nn::Conv2dOptions(ngf, output_nc, 7)));
+    model->push_back(torch::nn::Tanh());
+    // tag::generator3[]
+    register_module("model", model);
+  }
+  Tensor forward(const Tensor &inp) { return model->forward(inp); }
+};
+
+TORCH_MODULE(ResNetGenerator); // <4>
+// end::generator3[]
+
+int main(int argc, char **argv) {
+  // tag::main1[]
+  ResNetGenerator model; // <1>
+                         // end::main1[]
+  if (argc != 3) {
+    std::cerr << "call as " << argv[0] << " model_weights.pt image.jpg"
+              << std::endl;
+    return 1;
+  }
+  // tag::main2[]
+  torch::load(model, argv[1]); // <2>
+                               // end::main2[]
+  // you can print the model structure just like you would in PyTorch
+  // std::cout << model << std::endl;
+  // tag::main3[]
+  cimg_library::CImg<float> image(argv[2]);
+  image.resize(400, 400);
+  auto input_ =
+      torch::tensor(torch::ArrayRef<float>(image.data(), image.size()));
+  auto input = input_.reshape({1, 3, image.height(), image.width()});
+  torch::NoGradGuard no_grad;          // <3>
+  model->eval();                       // <4>
+  auto output = model->forward(input); // <5>
+                                       // end::main3[]
+                                       // tag::main4[]
+  cimg_library::CImg<float> out_img(output.data_ptr<float>(), output.size(3),
+                                    output.size(2), 1, output.size(1));
+  cimg_library::CImgDisplay disp(out_img, "See a C++ API zebra!"); // <6>
+  while (!disp.is_closed()) {
+    disp.wait();
+  }
+  // end::main4[]
+  return 0;
+}

+ 33 - 0
p3ch15/cyclegan_jit.cpp

@@ -0,0 +1,33 @@
+// tag::part1[]
+#include "torch/script.h" // <1>
+#define cimg_use_jpeg
+#include "CImg.h"
+using namespace cimg_library;
+int main(int argc, char **argv) {
+  // end::part1[]
+  if (argc != 4) {
+    std::cerr << "Call as " << argv[0] << " model.pt input.jpg output.jpg"
+              << std::endl;
+    return 1;
+  }
+  // tag::part2[]
+  CImg<float> image(argv[2]); // <2>
+  image = image.resize(227, 227); // <3>
+  // end::part2[]
+  // tag::part3[]
+  auto input_ = torch::tensor(torch::ArrayRef<float>(image.data(),
+                                                     image.size())); // <1>
+  auto input = input_.reshape({1, 3, image.height(), image.width()}).div_(255); // <2>
+  auto module = torch::jit::load(argv[1]); // <3>
+  std::vector<torch::jit::IValue> inputs; // <4>
+  inputs.push_back(input);
+  auto output_ = module.forward(inputs).toTensor(); // <5>
+  auto output = output_.contiguous().mul_(255); // <6>
+// end::part3[]
+// tag::part4[]
+  CImg<float> out_img(output.data_ptr<float>(), output.size(2), // <4>
+                      output.size(3), 1, output.size(1));
+  out_img.save(argv[3]); // <5>
+  return 0;
+}
+// end::part4[]