{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tracing vs. scripting with `torch.jit`\n",
    "\n",
    "This notebook runs a small function with a size-dependent Python loop through both `torch.jit.trace` and `torch.jit.script` and compares the generated code. It then loads a trained segmentation model (`UNetWrapper`) and traces it.\n",
    "\n",
    "## Setup\n",
    "\n",
    "`xprint` prints an object after removing `#` comments; it is used further down to print a TorchScript graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import re\n",
    "def xprint(s):\n",
    "    \"\"\"Print ``str(s)`` with every ``#`` comment removed.\n",
    "\n",
    "    On each line, the first ``#`` and everything after it is deleted,\n",
    "    together with any spaces directly in front of the ``#``.\n",
    "    \"\"\"\n",
    "    text = re.sub(r' *#.*', '', str(s))\n",
    "    print(text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A function with a size-dependent loop\n",
    "\n",
    "`myfn` adds up the slices `x[0]`, `x[1]`, ... along dimension 0. The number of loop iterations is taken from `x.size(0)`, i.e. it depends on the input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def myfn(x):\n",
    "    y = x[0]\n",
    "    for i in range(1, x.size(0)):  # trip count comes from x.size(0)\n",
    "        y = y + x[i]\n",
    "    return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tracing\n",
    "\n",
    "With a 5x5 example input, the traced code below contains four explicit `torch.add` calls: the loop has been unrolled for this particular input size. PyTorch also emits a `TracerWarning` saying that the trace might not generalize to other inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def myfn(x: Tensor) -> Tensor:\n",
      " y = torch.select(x, 0, 0)\n",
      " y0 = torch.add(y, torch.select(x, 0, 1), alpha=1)\n",
      " y1 = torch.add(y0, torch.select(x, 0, 2), alpha=1)\n",
      " y2 = torch.add(y1, torch.select(x, 0, 3), alpha=1)\n",
      " _0 = torch.add(y2, torch.select(x, 0, 4), alpha=1)\n",
      " return _0\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/ipykernel_launcher.py:3: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      " This is separate from the ipykernel package so we can avoid doing imports until\n"
     ]
    }
   ],
   "source": [
    "inp = torch.randn(5,5)\n",
    "traced_fn = torch.jit.trace(myfn, inp)\n",
    "print(traced_fn.code)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scripting\n",
    "\n",
    "The scripted version keeps the `for` loop; its length is computed from `torch.size(x, 0)` instead of being fixed to the example input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def myfn(x: Tensor) -> Tensor:\n",
      " y = torch.select(x, 0, 0)\n",
      " _0 = torch.__range_length(1, torch.size(x, 0), 1)\n",
      " y0 = y\n",
      " for _1 in range(_0):\n",
      " i = torch.__derive_index(_1, 1, 1)\n",
      " y0 = torch.add(y0, torch.select(x, 0, i), alpha=1)\n",
      " return y0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "scripted_fn = torch.jit.script(myfn)\n",
    "print(scripted_fn.code)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The graph of the scripted function shows the same loop as a `prim::Loop` node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "graph(%x.1 : Tensor):\n",
      " %10 : bool = prim::Constant[value=1]()\n",
      " %2 : int = prim::Constant[value=0]()\n",
      " %5 : int = prim::Constant[value=1]()\n",
      " %y.1 : Tensor = aten::select(%x.1, %2, %2)\n",
      " %7 : int = aten::size(%x.1, %2)\n",
      " %9 : int = aten::__range_length(%5, %7, %5)\n",
      " %y : Tensor = prim::Loop(%9, %10, %y.1)\n",
      " block0(%11 : int, %y.6 : Tensor):\n",
      " %i.1 : int = aten::__derive_index(%11, %5, %5)\n",
      " %18 : Tensor = aten::select(%x.1, %2, %i.1)\n",
      " %y.3 : Tensor = aten::add(%y.6, %18, %5)\n",
      " -> (%10, %y.3)\n",
      " return (%y)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "xprint(scripted_fn.graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tracing the segmentation model\n",
    "\n",
    "The next cells append the parent directory to `sys.path` and import `UNetWrapper` from `p2ch13.model_seg` (presumably located there)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')  # presumably where the p2ch13 package lives -- verify"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from p2ch13.model_seg import UNetWrapper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the saved weights on the CPU, put the model into eval mode, turn off gradients for all parameters, and trace the model with a dummy input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "seg_dict = torch.load('../data-unversioned/part2/models/p2ch13/seg_2019-10-20_15.57.21_none.best.state', map_location='cpu')\n",
    "seg_model = UNetWrapper(in_channels=8, n_classes=1, depth=4, wf=3, padding=True, batch_norm=True, up_mode='upconv')\n",
    "seg_model.load_state_dict(seg_dict['model_state'])\n",
    "seg_model.eval()\n",
    "for p in seg_model.parameters():\n",
    "    p.requires_grad_(False)\n",
    "\n",
    "# torch.jit.trace needs the callable and an example input; calling it with\n",
    "# no arguments raises a TypeError.\n",
    "# The channel dimension matches in_channels=8 above. Batch size 1 and the\n",
    "# 512x512 spatial size are assumptions -- TODO confirm against the data the\n",
    "# model was trained on.\n",
    "dummy_input = torch.randn(1, 8, 512, 512)\n",
    "traced_seg_model = torch.jit.trace(seg_model, dummy_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}