|
|
@@ -38,12 +38,12 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "def unbalanced_data(X, Y, zeroTimes):\n",
|
|
|
- " \"\"\"\n",
|
|
|
- " 通过将类别0的数据重复zeroTimes次,将均衡数据集变为非均衡数据集\n",
|
|
|
- " \"\"\"\n",
|
|
|
- " X0 = np.repeat(X[np.where(Y == 0)[0]], zeroTimes, axis=0)\n",
|
|
|
- " Y0 = np.repeat(Y[np.where(Y == 0)[0]], zeroTimes, axis=0)\n",
|
|
|
+ "def unbalanced_data(X, Y, zero_times):\n",
|
|
|
+ " '''\n",
|
|
|
+ " 通过将类别0的数据重复zero_times次,将均衡数据集变为非均衡数据集\n",
|
|
|
+ " '''\n",
|
|
|
+ " X0 = np.repeat(X[np.where(Y == 0)[0]], zero_times, axis=0)\n",
|
|
|
+ " Y0 = np.repeat(Y[np.where(Y == 0)[0]], zero_times, axis=0)\n",
|
|
|
" X1 = X[np.where(Y > 0)[0]]\n",
|
|
|
" Y1 = Y[np.where(Y > 0)[0]]\n",
|
|
|
" _X = np.append(X0, X1, axis=0)\n",
|
|
|
@@ -73,9 +73,9 @@
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
"def evaluate_model(Y, pred):\n",
|
|
|
- " \"\"\"\n",
|
|
|
+ " '''\n",
|
|
|
" 评估模型效果,其中包括ACC,AUC以及预测结果中类别1的个数\n",
|
|
|
- " \"\"\"\n",
|
|
|
+ " '''\n",
|
|
|
" pred_positive = []\n",
|
|
|
" true_positive = []\n",
|
|
|
" aucs = []\n",
|
|
|
@@ -98,11 +98,11 @@
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
"def visualize(ratios, pred_positive, true_positive, aucs, accuracies):\n",
|
|
|
- " \"\"\"\n",
|
|
|
+ " '''\n",
|
|
|
" 将模型结果可视化\n",
|
|
|
- " \"\"\"\n",
|
|
|
+ " '''\n",
|
|
|
" # 为在Matplotlib中显示中文,设置特殊字体\n",
|
|
|
- " plt.rcParams[\"font.sans-serif\"] = [\"SimHei\"]\n",
|
|
|
+ " plt.rcParams['font.sans-serif'] = ['SimHei']\n",
|
|
|
" # 正确显示负号\n",
|
|
|
" plt.rcParams['axes.unicode_minus'] = False\n",
|
|
|
" plt.rcParams.update({'font.size': 13})\n",
|
|
|
@@ -110,18 +110,18 @@
|
|
|
" fig = plt.figure(figsize=(12, 6), dpi=100)\n",
|
|
|
" # 在图形框里画两幅图\n",
|
|
|
" ax = fig.add_subplot(1, 2, 1)\n",
|
|
|
- " ax.plot(ratios, pred_positive, label=\"%s\" % \"预测结果里类别1的个数\")\n",
|
|
|
- " ax.plot(ratios, true_positive, \"k--\", label=\"%s\" % \"原始数据里类别1的个数\")\n",
|
|
|
+ " ax.plot(ratios, pred_positive, label='%s' % '预测结果里类别1的个数')\n",
|
|
|
+ " ax.plot(ratios, true_positive, 'k--', label='%s' % '原始数据里类别1的个数')\n",
|
|
|
" ax.set_xlim([0, 0.5])\n",
|
|
|
" ax.invert_xaxis()\n",
|
|
|
- " legend = plt.legend(shadow=True, loc=\"best\")\n",
|
|
|
+ " legend = plt.legend(shadow=True, loc='best')\n",
|
|
|
" ax1 = fig.add_subplot(1, 2, 2)\n",
|
|
|
- " ax1.plot(ratios, aucs, \"r\", label=\"%s\" % \"曲线下面积(AUC)\")\n",
|
|
|
- " ax1.plot(ratios, accuracies, \"k-.\", label=\"%s\" % \"准确度(ACC)\")\n",
|
|
|
+ " ax1.plot(ratios, aucs, 'r', label='%s' % '曲线下面积(AUC)')\n",
|
|
|
+ " ax1.plot(ratios, accuracies, 'k-.', label='%s' % '准确度(ACC)')\n",
|
|
|
" ax1.set_xlim([0, 0.5])\n",
|
|
|
" ax1.set_ylim([0.5, 1])\n",
|
|
|
" ax1.invert_xaxis()\n",
|
|
|
- " legend = plt.legend(shadow=True, loc=\"best\")\n",
|
|
|
+ " legend = plt.legend(shadow=True, loc='best')\n",
|
|
|
" return fig"
|
|
|
]
|
|
|
},
|