|
@@ -2,27 +2,11 @@
|
|
|
"cells": [
|
|
"cells": [
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 1,
|
|
|
|
|
|
|
+ "execution_count": null,
|
|
|
"metadata": {},
|
|
"metadata": {},
|
|
|
- "outputs": [
|
|
|
|
|
- {
|
|
|
|
|
- "name": "stdout",
|
|
|
|
|
- "output_type": "stream",
|
|
|
|
|
- "text": [
|
|
|
|
|
- "Requirement already satisfied: torch in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (2.0.1)\n",
|
|
|
|
|
- "Requirement already satisfied: jinja2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch) (2.11.2)\n",
|
|
|
|
|
- "Requirement already satisfied: filelock in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch) (3.0.12)\n",
|
|
|
|
|
- "Requirement already satisfied: sympy in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch) (1.6.2)\n",
|
|
|
|
|
- "Requirement already satisfied: typing-extensions in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch) (3.7.4.3)\n",
|
|
|
|
|
- "Requirement already satisfied: networkx in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch) (2.5)\n",
|
|
|
|
|
- "Requirement already satisfied: MarkupSafe>=0.23 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from jinja2->torch) (1.1.1)\n",
|
|
|
|
|
- "Requirement already satisfied: mpmath>=0.19 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from sympy->torch) (1.1.0)\n",
|
|
|
|
|
- "Requirement already satisfied: decorator>=4.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from networkx->torch) (4.4.2)\n"
|
|
|
|
|
- ]
|
|
|
|
|
- }
|
|
|
|
|
- ],
|
|
|
|
|
|
|
+ "outputs": [],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "# 安装PyTorch\n",
|
|
|
|
|
|
|
+ "# 安装第三方库\n",
|
|
|
"!pip install torch"
|
|
"!pip install torch"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
@@ -48,8 +32,8 @@
|
|
|
"source": [
|
|
"source": [
|
|
|
"import torch\n",
|
|
"import torch\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
- "# 创建tensor\n",
|
|
|
|
|
- "## 使用tensor封装的函数创建tensor\n",
|
|
|
|
|
|
|
+ "# 创建张量(tensor)\n",
|
|
|
|
|
+ "## 使用封装的函数创建张量\n",
|
|
|
"zeros = torch.zeros(2, 3)\n",
|
|
"zeros = torch.zeros(2, 3)\n",
|
|
|
"print(zeros)\n",
|
|
"print(zeros)\n",
|
|
|
"\n",
|
|
"\n",
|
|
@@ -88,7 +72,7 @@
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "# 创建tensor\n",
|
|
|
|
|
|
|
+ "# 创建张量(tensor)\n",
|
|
|
"## 从Python对象创建\n",
|
|
"## 从Python对象创建\n",
|
|
|
"data = [[2, 3, 4], [1, 0, 1]]\n",
|
|
"data = [[2, 3, 4], [1, 0, 1]]\n",
|
|
|
"t_data = torch.tensor(data)\n",
|
|
"t_data = torch.tensor(data)\n",
|
|
@@ -101,7 +85,7 @@
|
|
|
"tn_data = torch.from_numpy(n_data)\n",
|
|
"tn_data = torch.from_numpy(n_data)\n",
|
|
|
"print(tn_data)\n",
|
|
"print(tn_data)\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
- "## Numpy bridge,也就是对numpy对象的改变会传导到tensor\n",
|
|
|
|
|
|
|
+ "## Numpy bridge,也就是对numpy对象的改变会传导到张量\n",
|
|
|
"n_data += 1\n",
|
|
"n_data += 1\n",
|
|
|
"torch.all(torch.from_numpy(n_data) == tn_data)"
|
|
"torch.all(torch.from_numpy(n_data) == tn_data)"
|
|
|
]
|
|
]
|
|
@@ -124,8 +108,7 @@
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "# 变换tensor维度\n",
|
|
|
|
|
- "\n",
|
|
|
|
|
|
|
+ "# 变换张量维度\n",
|
|
|
"## 增加或减少数据的维度\n",
|
|
"## 增加或减少数据的维度\n",
|
|
|
"a = torch.rand(3, 4)\n",
|
|
"a = torch.rand(3, 4)\n",
|
|
|
"print(a.shape)\n",
|
|
"print(a.shape)\n",
|
|
@@ -161,7 +144,7 @@
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "# 变换tensor形状\n",
|
|
|
|
|
|
|
+ "# 变换张量形状\n",
|
|
|
"data = torch.tensor(range(0, 10))\n",
|
|
"data = torch.tensor(range(0, 10))\n",
|
|
|
"print(data, data.shape)\n",
|
|
"print(data, data.shape)\n",
|
|
|
"view1 = data.view(2, 5)\n",
|
|
"view1 = data.view(2, 5)\n",
|
|
@@ -244,7 +227,7 @@
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
- "## tensor广播,tensor broadcasting\n",
|
|
|
|
|
|
|
+ "## 广播机制(tensor broadcasting)\n",
|
|
|
"a = torch.tensor(range(1, 7)).view(2, 3)\n",
|
|
"a = torch.tensor(range(1, 7)).view(2, 3)\n",
|
|
|
"b = torch.tensor(range(1, 4)).view( 3)\n",
|
|
"b = torch.tensor(range(1, 4)).view( 3)\n",
|
|
|
"print(a)\n",
|
|
"print(a)\n",
|
|
@@ -301,15 +284,15 @@
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
"# 向量运算\n",
|
|
"# 向量运算\n",
|
|
|
- "# vector x vector\n",
|
|
|
|
|
|
|
+ "# 向量与向量\n",
|
|
|
"vec1 = torch.randn(3)\n",
|
|
"vec1 = torch.randn(3)\n",
|
|
|
"vec2 = torch.randn(3)\n",
|
|
"vec2 = torch.randn(3)\n",
|
|
|
"print((vec1 @ vec2).shape)\n",
|
|
"print((vec1 @ vec2).shape)\n",
|
|
|
- "# matrix x vector\n",
|
|
|
|
|
|
|
+ "# 矩阵与向量\n",
|
|
|
"mat = torch.randn(3, 4)\n",
|
|
"mat = torch.randn(3, 4)\n",
|
|
|
"vec = torch.randn(4)\n",
|
|
"vec = torch.randn(4)\n",
|
|
|
"print((mat @ vec).shape)\n",
|
|
"print((mat @ vec).shape)\n",
|
|
|
- "# batched matrix x broadcasted vector\n",
|
|
|
|
|
|
|
+ "# 张量与向量\n",
|
|
|
"mat = torch.randn(10, 3, 4)\n",
|
|
"mat = torch.randn(10, 3, 4)\n",
|
|
|
"vec = torch.randn(4)\n",
|
|
"vec = torch.randn(4)\n",
|
|
|
"print((mat @ vec).shape)"
|
|
"print((mat @ vec).shape)"
|