{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Grayscale conversion with plain and named tensors\n",
    "\n",
    "Collapse the channel dimension of an image tensor (and of a batch of images) into a single value per pixel, three ways: a plain mean, a weighted sum via broadcasting, and `torch.einsum`. The weighted sum is then repeated with PyTorch's experimental *named tensors*, which refer to dimensions by name rather than by position."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "../c10/core/TensorImpl.h:860: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# The example tensors below come from torch.randn; seed so that re-runs see the same values.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Constructing a named tensor emits an 'experimental feature' UserWarning (shown below).\n",
    "# Trigger it here on a throwaway tensor so it doesn't clutter the outputs of later cells.\n",
    "_ = torch.tensor([0.2126, 0.7152, 0.0722], names=['c'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plain tensors: dimensions addressed by position\n",
    "\n",
    "A single image is laid out as `[channels, rows, columns]`; a batch adds a leading batch dimension. Addressing the channel dimension as `-3` (third from the right) lets the same code handle both layouts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_t = torch.randn(3, 5, 5)  # shape [channels, rows, columns]\n",
    "# One weight per channel (the Rec. 709 luma coefficients, assuming RGB channel order).\n",
    "weights = torch.tensor([0.2126, 0.7152, 0.0722])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_t = torch.randn(2, 3, 5, 5)  # shape [batch, channels, rows, columns]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The naive approach averages the channels, giving each one equal weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([5, 5]), torch.Size([2, 5, 5]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reduce over the channel dim, counted from the right so any leading batch dim is ignored.\n",
    "img_gray_naive = img_t.mean(-3)\n",
    "batch_gray_naive = batch_t.mean(-3)\n",
    "img_gray_naive.shape, batch_gray_naive.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To weight the channels, reshape `weights` from `[3]` to `[3, 1, 1]` so it broadcasts against `[..., channels, rows, columns]`. Then multiply and sum over the channel dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([2, 3, 5, 5]), torch.Size([2, 3, 5, 5]), torch.Size([3, 1, 1]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unsqueezed_weights = weights.unsqueeze(-1).unsqueeze(-1)  # [3] -> [3, 1, 1]\n",
    "img_weights = img_t * unsqueezed_weights\n",
    "batch_weights = batch_t * unsqueezed_weights\n",
    "img_gray_weighted = img_weights.sum(-3)\n",
    "batch_gray_weighted = batch_weights.sum(-3)\n",
    "batch_weights.shape, batch_t.shape, unsqueezed_weights.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`torch.einsum` expresses the same weighted sum in one call. `...` absorbs any leading (batch) dimensions, and `c`, which does not appear in the output spec, is summed over."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 5, 5])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)\n",
    "batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)\n",
    "\n",
    "# Sanity check: einsum agrees with the explicit broadcast-and-sum from the previous cell.\n",
    "assert torch.allclose(img_gray_weighted_fancy, img_gray_weighted)\n",
    "assert torch.allclose(batch_gray_weighted_fancy, batch_gray_weighted)\n",
    "\n",
    "batch_gray_weighted_fancy.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Named tensors: dimensions addressed by name\n",
    "\n",
    "Positional indices like `-3` are easy to get wrong. A named tensor carries a name for each dimension, and operations can refer to those names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.2126, 0.7152, 0.0722], names=('channels',))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weights_named = torch.tensor([0.2126, 0.7152, 0.0722], names=['channels'])\n",
    "weights_named"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`refine_names` attaches names to an existing tensor's dimensions. The leading `...` leaves any extra leading dimensions (here, the batch dimension) unnamed, as the `None` in the output shows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "img named: torch.Size([3, 5, 5]) ('channels', 'rows', 'columns')\n",
      "batch named: torch.Size([2, 3, 5, 5]) (None, 'channels', 'rows', 'columns')\n"
     ]
    }
   ],
   "source": [
    "img_named = img_t.refine_names(..., 'channels', 'rows', 'columns')\n",
    "batch_named = batch_t.refine_names(..., 'channels', 'rows', 'columns')\n",
    "print(\"img named:\", img_named.shape, img_named.names)\n",
    "print(\"batch named:\", batch_named.shape, batch_named.names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`align_as` adds size-1 dimensions to `weights_named` so that its dimensions line up with `img_named` by name. This replaces the manual `unsqueeze` calls, and the reduction can then name the dimension it sums over."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([3, 1, 1]), ('channels', 'rows', 'columns'))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weights_aligned = weights_named.align_as(img_named)\n",
    "weights_aligned.shape, weights_aligned.names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([5, 5]), ('rows', 'columns'))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gray_named = (img_named * weights_aligned).sum('channels')\n",
    "\n",
    "# Same values as the positional version; only the dimension names differ.\n",
    "assert torch.allclose(gray_named.rename(None), img_gray_weighted)\n",
    "\n",
    "gray_named.shape, gray_named.names"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Names also catch mistakes that shapes alone would let through. Slicing the columns down to 3 gives shape `[3, 5, 3]`, which broadcasts against the `[3]` weights. However, broadcasting aligns dimensions from the right, so `'columns'` would be paired with `'channels'`, and PyTorch refuses. The error is expected here: it is caught and printed, and the cell fails loudly if it is *not* raised."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error when attempting to broadcast dims ['channels', 'rows', 'columns'] and dims ['channels']: dim 'columns' and dim 'channels' are at the same position from the right but do not match.\n"
     ]
    }
   ],
   "source": [
    "# Not assigned to a variable: this is meant to fail, and must not clobber `gray_named`.\n",
    "try:\n",
    "    (img_named[..., :3] * weights_named).sum('channels')\n",
    "except RuntimeError as e:\n",
    "    print(e)\n",
    "else:\n",
    "    raise AssertionError(\"expected a dimension-name mismatch error\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`rename(None)` drops all dimension names, e.g. to hand the result to code that expects plain tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([5, 5]), (None, None))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gray_plain = gray_named.rename(None)\n",
    "gray_plain.shape, gray_plain.names"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}