{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Load the Adult income data with pandas\n",
"import pandas as pd\n",
"\n",
"\n",
"data_path = \"./data/adult.data\"\n",
"# read_csv takes the first row as the header, so the file must name the columns selected below\n",
"raw_data = pd.read_csv(data_path)\n",
"# Keep only the columns used in the analysis\n",
"cols = [\"workclass\", \"sex\", \"age\", \"education_num\",\n",
"        \"capital_gain\", \"capital_loss\", \"hours_per_week\", \"label\"]\n",
"# .copy() makes `data` an independent frame rather than a slice of raw_data, so adding\n",
"# columns to it later (e.g. label_code) does not raise SettingWithCopyWarning.\n",
"data = raw_data[cols].copy()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" workclass | \n",
" sex | \n",
" age | \n",
" education_num | \n",
" capital_gain | \n",
" capital_loss | \n",
" hours_per_week | \n",
" label | \n",
" label_code | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" State-gov | \n",
" Male | \n",
" 39 | \n",
" 13 | \n",
" 2174 | \n",
" 0 | \n",
" 40 | \n",
" <=50K | \n",
" 0 | \n",
"
\n",
" \n",
" | 1 | \n",
" Self-emp-not-inc | \n",
" Male | \n",
" 50 | \n",
" 13 | \n",
" 0 | \n",
" 0 | \n",
" 13 | \n",
" <=50K | \n",
" 0 | \n",
"
\n",
" \n",
" | 2 | \n",
" Private | \n",
" Male | \n",
" 38 | \n",
" 9 | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" <=50K | \n",
" 0 | \n",
"
\n",
" \n",
" | 3 | \n",
" Private | \n",
" Male | \n",
" 53 | \n",
" 7 | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" <=50K | \n",
" 0 | \n",
"
\n",
" \n",
" | 4 | \n",
" Private | \n",
" Female | \n",
" 28 | \n",
" 13 | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" <=50K | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" workclass sex age education_num capital_gain capital_loss \\\n",
"0 State-gov Male 39 13 2174 0 \n",
"1 Self-emp-not-inc Male 50 13 0 0 \n",
"2 Private Male 38 9 0 0 \n",
"3 Private Male 53 7 0 0 \n",
"4 Private Female 28 13 0 0 \n",
"\n",
" hours_per_week label label_code \n",
"0 40 <=50K 0 \n",
"1 13 <=50K 0 \n",
"2 40 <=50K 0 \n",
"3 40 <=50K 0 \n",
"4 40 <=50K 0 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Encode the income label as a 0/1 target for the logistic regression.\n",
"# The raw labels carry a leading space (\" <=50K\" / \" >50K\"); an explicit mapping pins\n",
"# which class is 1 instead of relying on the alphabetical order of categorical codes.\n",
"label_map = {\" <=50K\": 0, \" >50K\": 1}\n",
"data[\"label_code\"] = data[\"label\"].map(label_map)\n",
"# Any label outside the two expected values maps to NaN; fail loudly instead of\n",
"# silently training on a mis-encoded target.\n",
"assert data[\"label_code\"].notna().all(), \"unexpected values in data['label']\"\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[' Male' ' Female']\n",
"[' State-gov' ' Self-emp-not-inc' ' Private' ' Federal-gov' ' Local-gov'\n",
" ' ?' ' Self-emp-inc' ' Without-pay' ' Never-worked']\n"
]
}
],
"source": [
"# Inspect the distinct values of the two categorical features\n",
"for column in (\"sex\", \"workclass\"):\n",
"    print(data[column].unique())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"显示sex, label交叉报表:\n",
"sex label \n",
" Female <=50K 9592\n",
" >50K 1179\n",
" Male <=50K 15128\n",
" >50K 6662\n",
"dtype: int64\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAADuCAYAAAAOR30qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAADPZJREFUeJzt3XlwlPUdx/HPZnfJxlwYAjRRzFURJDGBBiWx1XqVSOmYAWy9ULQeUy8Uj+pYNWO1Ok49K9UOSqXj2CriIEJEa3GUGjMUhAiBoCGEICEQMOYAErJH/6BGMoQqkN39Zvf9+ovN85ud70OG9zw8+zz7OAKBgAAA4RcT7gEAAAcQZAAwgiADgBEEGQCMIMgAYARBBgAjCDIAGEGQAcAIggwARriOZPHglNRA2omZQRoFweJwSMe5pb3dEjdmAqFXs3bVrkAgMPS71h1RkNNOzNTcJSuPfiqETdEI6ZOt4Z4CiE7FJzm2fJ91nLIAACMIMgAYQZABwAiCDABGEGQAMIIgA4ARBBkAjCDIAGAEQQYAIwgyABhBkAHACIIMAEYQZAAwgiADgBEEGQCMIMgAYARBBgAjCDIAGEGQAcAIggwARhBkADCCIAOAEQQZAIwgyABgBEEGACMIMgAYQZABwAiCDABGEGQAMIIgA4ARBBkAjCDIAGAEQQYAIwgyABhBkAHACIIMAEYQZAAwgiADgBEEGQCMIMgAYARBBgAjCDIAGEGQAcAIggwARhBkADCCIAOAEQQZAIwgyABgBEEGACMIMgAYQZABwAiCDABGEGQAMIIgA4ARBBkAjCDIAGAEQQYAIwgyABhBkAHACIIMAEYQZAAwgiADgBEEGQCMIMgAYARBBgAjCDIAGEGQAUMenjVDLz5ZppUfL9OU4szDrrt9eknohkLIuMI9AIDeGhvqNDglVbGeuHCPghAjyIAxDodDOxq36gcnZKirs1MP3TZdg2I9yv1RkaZeeeMh68vnz9O6Tz9Ry+6d+t0TLys+MSkMU6M/cMoCMMbldsvbvf9/rwKa/KtrNDJ3rP79z0V9rn/tpaeUkDRYnrh41X1eHbpB0e8IMmDMCRk/1J6ONknS+jUr9P6if2j8j89X3HHxfa5PTD5eN977mKZedZMSk48P5ajoZ5yyAIzJHjlGTqdTzU3blHT8EHW0t2rBvNn6+qtdfa4vmTJd9990iXxer+7749wQT4v+5AgEAt978ejTCgNzl6wM4jgIlqIR0idbwz0FEJ2KT3KsCgQChd+1jlMWAGAEQQaM8fl84R4BYUKQAUMaGzbrb7Mf7XPbnTMm69G7r9XsR+6WJP3r7dd1x1WTel5/c7PInx+9R8sWzw/NwOhXfKgHGFFfW6MPlszXVbfcpw/fXajqTyt7tl1y7e0anj5Cd/3heUlSIBDQa3Of1l/e/FivPP+46mtrJEmrKj7Q3j3tOnfyxWHZBxwbggwY8MX6Kq346D1dPfN+SdLZE0t19sTSnu0bqlbqiw1VKrv1cmWfkqdpM27W8LQRcjgcOnlMgbbV18rv8+mh26brhTc/Dtdu4BgRZMAAt3uQvN5u+f1+xcTE6IPyBVr36Sc92y+/4S796e/LFOvx6K6rf6FLrr1dPp9XkuTzerWno01+v1+XXDdLr899RjMfeDJcu4JjQJABAzJPHq3YuOM054kHdM3MB3TOpKk6Z9LUnu2rKz9U8vGpyhp5qjraW+VwOCQ55PV6tXHtKhWeeZ5cbrcuvW6WZl52gXbt2K7U4Wnh2yEcFT7UA4xIOzFDF8+4Ra+88Pgh207KPkXPPXKn7r1+iiaWXi73oEGaNuNm3X/jL/XVrh3KKyzuWXvFb3572A8GYRs3hkQJbgwBwocbQwBggCHIAGAEQQYAIwgyABhBkAHACIIMAEZE9I0hLz5ZpurVlUpOSVVScopmPfTsUb/PhJ+WKHfchH6eEAC+FdFBlqRf315GSAEMCB
Ef5IN5vV49XTZT3u5upQ5PV9qITH38/tvauf1LDUs7UVOuvEm1G6pU89lK+f1+PfjMK3I6nb3eY/uXW/TSU2Xq3LtHF069UmeePzlMewMg0kT8OeS/Pvt7PTxrhrZu/kIrPnpP27ZsUtLgFG2p3SBJKj5vsgbFevTzi6/W1rrPlTuuSMXnTVZD3Ubt2tF4yPu9OW+2XC63hqWP0PqqFaHeHQARLOKPkK++9f6eUxYNdZ9r/E8u0GXX36El81+WJMV64hTrietZ//xj9+ia2x7UqLxC+ft4ckNAAZVefoOyRo7RR+8tDMk+AIgOEX+EfLAzzvqZatdXqWzmFdq9s6nPNUOGpWn5e29pw2f/UVtryyHbp0y/UfOee0QP3nKpkganBHtkAFGELxeKEny5EBA+fLnQYfAASQBWRVWQeYAkAMsi/kO9b/AASQDWRUWQeYAkgIEgKoLMAyQBDARREWQeIAlgIIiaD/V4gCQA67gOOUpwHTIQPlyHDAADDEEGACMIMgAYQZABwAiCDABGEGQAMIIgA4ARBBkAjCDIAGAEQQYAIwgyABhBkAHACIIMAEYQZAAwgiADgBEEGQCMIMgAYARBBgAjCDIAGEGQAcAIggwARhBkADCCIAOAEQQZAIwgyABgBEEGACMIMgAYQZABwAiCDABGEGQAMIIgA4ARBBkAjCDIAGAEQQYAIwgyABhBkAHACIIMAEYQZAAwgiADgBEEGQCMIMgAYARBBgAjCDIAGEGQAcAIggwARhDkKOLmtw2Y5gr3AAidxXPK1NHREe4xABwGx0xRhBgDthFkADCCIAOAEQQZAIwgyABgBEEGACMIMgAYQZABwAiCDABGEGQAMIIgA4ARBBkAjCDIAGAEQQYAIwgyABhBkAHACIIMAEYQZAAwgiADgBEEGQCMIMgAYARBBgAjCDIAGEGQAcAIggwARhBkwJClS5eqoqJCDQ0NmjNnzmHXLViwIIRTIVRc4R4AQG+tra2Ki4uTy8U/z2jDbxwwqL29XUlJSfJ6vXrnnXfkdDqVnp6ugoKCQ9ZWV1ersbFRe/fuVUlJiWJjY8MwMfoDpywAY2JiYuTz+SRJgUBAubm5GjZsmDZt2tTn+lWrVik2NlZut1u7d+8O5ajoZwQZMGbw4MHav3+/JKmpqUk1NTXKyMiQ2+3uc73H49FZZ52lgoICjo4HOIIMGJOamqqUlBRJUlxcnLq6urRmzRrt27evz/WjR4/W4sWLtXLlSiUkJIRyVPQzziEDhpSUlEiScnJyNH78eElSaWnpIeumTp3a8+e8vDzl5eWFZkAEFUfIAGAEQQYM8fv94R4BYUSQAUMaGhq0YcOGXj979dVX9e6772rFihWSDlxVsXDhwp7X39wkUl5ersbGxtAOjH5FkAFDMjMz5fF4tG7dOkkHbhLJysrSxIkTdfrpp6uzs1P19fUqLS1Ve3t7z9UY69atU2pqqtLT08M5Po4RQQaMycrKUnJystasWaMdO3Zo27ZtWrRokaqrq9XR0aGhQ4dKOnA1Rltbm/bt26fly5dr7NixYZ4cx4qrLACDfD6fXC6XsrKylJOTI6fTqTfeeEPp6ek955n9fr+6urrk9/uVn5+vtWvXaty4cWGeHMeCI2TAmI0bN2r//v3Kzc1VXV2durq65PP55PV6lZycrLa2NknSzp07lZSUpPj4eJ1xxhmqqalRd3d3mKfHseAIGTBk8+bNio2NVWZmpiRpyJAhKi8vl9PpVGFhoWJiYjRy5Ei99dZbGjp0qBITEyVJTqdTubm5qqqqUmFhYRj3AMeCIAOGZGVl9XqdmpqqadOm9frZqFGjNGrUqJ7X39wkctpppwV/QATVEZ2ycDiCNQaCyc2JKWBAOKIj5K+bG7XgmTuDNQuCJCEhQYVlZeEeA8B3OKJjJ+4iGpg6OjrCPQKA74H/zAKAEQQZAIwgyABgRERf9lZRUaHt27crLi5OHo9H55577lG/T2ZmJt8TACCoIjrIkl
RUVERIAQwIER/kg/n9fi1btkx+v18JCQlKSkpSXV2d2tvblZiYqIKCAjU3N6upqUmBQECTJk1STEzvszptbW2qqKhQd3e3Tj31VOXk5IRpbwBEmog/h1xZWamlS5eqpaVF9fX1am1tlcfj6Xk6b3Z2tlwul8aMGaOWlhalpaUpOztbLS0tfV4utnr1asXExCgxMVFNTU2h3h0AESzij5AnTJjQc8qipaVFGRkZKiws7Pm+WZfLJZfr27+G5cuXq6ioSMOHD1cgEOjzPfPz8zVkyBDV1tYGfwcARI2IP0I+WEZGhpqbm1VeXq49e/b0uSY+Pl6bNm1SU1OTOjs7D9leUFCgyspKLVmyRB6PJ9gjA4giEX2EXFxc3Ou10+nUhRdeeMi6g7+oJT8///++T3Jysi666KJ+nBIADoiqI2Ru/QZgWVQFmQdIArAsqoLMAyQBWBZVQZZ4gCQAuyL6Q73D4QGSACyKuiNkHiAJwKqoOkLmAZIALIuqIPMASQCWRd0pCwCwiiADgBEEGQCMcBzuG836XOxwNEvaErxxACAiZQQCgaHfteiIggwACB5OWQCAEQQZAIwgyABgBEEGACMIMgAYQZABwAiCDABGEGQAMIIgA4AR/wUwIyjJHdS22AAAAABJRU5ErkJggg==\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Use a sex x label contingency table to get an intuitive feel for the data\n",
"import matplotlib.pyplot as plt\n",
"from statsmodels.graphics.mosaicplot import mosaic\n",
"\n",
"\n",
"def tile_style(key):\n",
"    \"\"\"Mosaic tile colour: dark grey for ' >50K' tiles, light blue for the rest.\"\"\"\n",
"    if ' >50K' in key:\n",
"        return {\"color\": \"0.45\"}\n",
"    return {\"color\": \"#C6E2FF\"}\n",
"\n",
"\n",
"# Count every (sex, label) combination and print it as a long table\n",
"sex_label_table = pd.crosstab(data[\"sex\"], data[\"label\"])\n",
"print(\"显示sex, label交叉报表:\")\n",
"print(sex_label_table.stack())\n",
"# Draw the same counts as a mosaic plot, with the ' >50K' column first\n",
"ordered_counts = sex_label_table[[\" >50K\", \" <=50K\"]].stack()\n",
"mosaic(ordered_counts, properties=tile_style, axes_label=False)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Split the data into a training set and a test set (80/20, fixed seed for reproducibility)\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"\n",
"train_set, test_set = train_test_split(data, test_size=0.2, random_state=2111)\n",
"# Work on independent copies: prediction columns are added to test_set later, and pandas\n",
"# may flag the split outputs as copies of `data`, raising SettingWithCopyWarning on write.\n",
"train_set, test_set = train_set.copy(), test_set.copy()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: Maximum number of iterations has been exceeded.\n",
" Current function value: 0.405259\n",
" Iterations: 35\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/anaconda3/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n",
" \"Check mle_retvals\", ConvergenceWarning)\n"
]
}
],
"source": [
"# Train a logistic regression on all features; sex and workclass enter as dummy-coded categoricals\n",
"import statsmodels.api as sm\n",
"\n",
"\n",
"c_formula = \"label_code ~ C(sex) + C(workclass) + education_num + capital_gain + capital_loss + hours_per_week\"\n",
"c_logit = sm.Logit.from_formula(c_formula, data=train_set)\n",
"# NOTE: on this data the fit stops at the iteration limit with a ConvergenceWarning; the\n",
"# Never-worked and Without-pay dummies get huge standard errors in the summary, which the\n",
"# later cells address by folding those levels into the baseline category.\n",
"c_model = c_logit.fit()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Logit Regression Results \n",
"==============================================================================\n",
"Dep. Variable: label_code No. Observations: 26048\n",
"Model: Logit Df Residuals: 26034\n",
"Method: MLE Df Model: 13\n",
"Date: Sun, 02 Jun 2019 Pseudo R-squ.: 0.2648\n",
"Time: 20:13:33 Log-Likelihood: -10556.\n",
"converged: False LL-Null: -14359.\n",
" LLR p-value: 0.000\n",
"=====================================================================================================\n",
" coef std err z P>|z| [0.025 0.975]\n",
"-----------------------------------------------------------------------------------------------------\n",
"Intercept -7.6282 0.145 -52.505 0.000 -7.913 -7.343\n",
"C(sex)[T. Male] 1.2558 0.045 27.958 0.000 1.168 1.344\n",
"C(workclass)[T. Federal-gov] 1.0879 0.132 8.223 0.000 0.829 1.347\n",
"C(workclass)[T. Local-gov] 0.6188 0.117 5.288 0.000 0.389 0.848\n",
"C(workclass)[T. Never-worked] -13.7636 1825.543 -0.008 0.994 -3591.762 3564.234\n",
"C(workclass)[T. Private] 0.5003 0.101 4.944 0.000 0.302 0.699\n",
"C(workclass)[T. Self-emp-inc] 1.2942 0.129 10.000 0.000 1.041 1.548\n",
"C(workclass)[T. Self-emp-not-inc] 0.3986 0.116 3.443 0.001 0.172 0.625\n",
"C(workclass)[T. State-gov] 0.4011 0.130 3.084 0.002 0.146 0.656\n",
"C(workclass)[T. Without-pay] -14.3432 1052.727 -0.014 0.989 -2077.649 2048.963\n",
"education_num 0.3286 0.008 41.730 0.000 0.313 0.344\n",
"capital_gain 0.0003 1.08e-05 30.013 0.000 0.000 0.000\n",
"capital_loss 0.0007 3.64e-05 19.958 0.000 0.001 0.001\n",
"hours_per_week 0.0292 0.002 19.288 0.000 0.026 0.032\n",
"=====================================================================================================\n"
]
}
],
"source": [
"# Show the fitted full model's coefficient table\n",
"c_summary = c_model.summary()\n",
"print(c_summary)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0., 0.],\n",
" [0., 0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 0., 1., 0.],\n",
" [0., 0., 0., 0., 0., 1.]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Drop the insignificant workclass dummies by folding their levels into the baseline\n",
"import numpy as np\n",
"\n",
"\n",
"# Order of the workclass levels. The first n_baseline levels (\" ?\", \" Never-worked\",\n",
"# \" Without-pay\") all map to the baseline; each later level gets its own dummy column.\n",
"l = [\" ?\", \" Never-worked\", \" Without-pay\", \" State-gov\",\n",
"     \" Self-emp-not-inc\", \" Private\", \" Federal-gov\",\n",
"     \" Local-gov\", \" Self-emp-inc\"]\n",
"n_baseline = 3\n",
"# Contrast matrix: one row per level, one column per dummy. Baseline rows are all zeros and\n",
"# the remaining rows form an identity block. Deriving the shape from `l` (instead of the\n",
"# hard-coded np.eye(9, 6, k=-3)) keeps it consistent if the level list changes.\n",
"contrast = np.eye(len(l), len(l) - n_baseline, k=-n_baseline)\n",
"contrast"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ContrastMatrix(array([[0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0., 0.],\n",
" [0., 0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 0., 1., 0.],\n",
" [0., 0., 0., 0., 0., 1.]]),\n",
" [' State-gov',\n",
" ' Self-emp-not-inc',\n",
" ' Private',\n",
" ' Federal-gov',\n",
" ' Local-gov',\n",
" ' Self-emp-inc'])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Name each dummy column after the workclass level it encodes (all non-baseline levels)\n",
"from patsy import ContrastMatrix\n",
"\n",
"\n",
"dummy_names = l[3:]\n",
"contrast_mat = ContrastMatrix(contrast, dummy_names)\n",
"contrast_mat"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Optimization terminated successfully.\n",
" Current function value: 0.405321\n",
" Iterations 8\n"
]
}
],
"source": [
"# Refit without the insignificant dummies: workclass is coded with the custom contrast_mat,\n",
"# which patsy looks up (together with `l`) in the calling namespace\n",
"m_formula = \"\"\"label_code ~ C(workclass, contrast_mat, levels=l)\n",
" + C(sex) + education_num + capital_gain\n",
" + capital_loss + hours_per_week\"\"\"\n",
"m_logit = sm.Logit.from_formula(m_formula, data=train_set)\n",
"m_model = m_logit.fit()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Logit Regression Results \n",
"==============================================================================\n",
"Dep. Variable: label_code No. Observations: 26048\n",
"Model: Logit Df Residuals: 26036\n",
"Method: MLE Df Model: 11\n",
"Date: Sun, 02 Jun 2019 Pseudo R-squ.: 0.2647\n",
"Time: 20:13:41 Log-Likelihood: -10558.\n",
"converged: True LL-Null: -14359.\n",
" LLR p-value: 0.000\n",
"=========================================================================================================================\n",
" coef std err z P>|z| [0.025 0.975]\n",
"-------------------------------------------------------------------------------------------------------------------------\n",
"Intercept -7.6422 0.145 -52.624 0.000 -7.927 -7.358\n",
"C(workclass, contrast_mat, levels=l) State-gov 0.4149 0.130 3.193 0.001 0.160 0.670\n",
"C(workclass, contrast_mat, levels=l) Self-emp-not-inc 0.4127 0.116 3.569 0.000 0.186 0.639\n",
"C(workclass, contrast_mat, levels=l) Private 0.5143 0.101 5.090 0.000 0.316 0.712\n",
"C(workclass, contrast_mat, levels=l) Federal-gov 1.1018 0.132 8.335 0.000 0.843 1.361\n",
"C(workclass, contrast_mat, levels=l) Local-gov 0.6327 0.117 5.412 0.000 0.404 0.862\n",
"C(workclass, contrast_mat, levels=l) Self-emp-inc 1.3084 0.129 10.117 0.000 1.055 1.562\n",
"C(sex)[T. Male] 1.2554 0.045 27.948 0.000 1.167 1.343\n",
"education_num 0.3286 0.008 41.738 0.000 0.313 0.344\n",
"capital_gain 0.0003 1.08e-05 30.018 0.000 0.000 0.000\n",
"capital_loss 0.0007 3.64e-05 19.964 0.000 0.001 0.001\n",
"hours_per_week 0.0292 0.002 19.284 0.000 0.026 0.032\n",
"=========================================================================================================================\n"
]
}
],
"source": [
"# Show the refitted model's coefficient table\n",
"m_summary = m_model.summary()\n",
"print(m_summary)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Optimization terminated successfully.\n",
" Current function value: 0.426517\n",
" Iterations 8\n"
]
}
],
"source": [
"# Baseline model that uses no categorical features, for comparison with m_model\n",
"b_formula = \"\"\"label_code ~ education_num + capital_gain\n",
" + capital_loss + hours_per_week\"\"\"\n",
"b_logit = sm.Logit.from_formula(b_formula, data=train_set)\n",
"b_model = b_logit.fit()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" workclass | \n",
" sex | \n",
" age | \n",
" education_num | \n",
" capital_gain | \n",
" capital_loss | \n",
" hours_per_week | \n",
" label | \n",
" label_code | \n",
" b_prob | \n",
" m_prob | \n",
"
\n",
" \n",
" \n",
" \n",
" | 1704 | \n",
" ? | \n",
" Female | \n",
" 22 | \n",
" 10 | \n",
" 0 | \n",
" 0 | \n",
" 35 | \n",
" <=50K | \n",
" 0 | \n",
" 0.127790 | \n",
" 0.034448 | \n",
"
\n",
" \n",
" | 1376 | \n",
" Self-emp-not-inc | \n",
" Male | \n",
" 44 | \n",
" 9 | \n",
" 0 | \n",
" 0 | \n",
" 55 | \n",
" <=50K | \n",
" 0 | \n",
" 0.188333 | \n",
" 0.196298 | \n",
"
\n",
" \n",
" | 14634 | \n",
" Private | \n",
" Female | \n",
" 39 | \n",
" 10 | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" <=50K | \n",
" 0 | \n",
" 0.151123 | \n",
" 0.064593 | \n",
"
\n",
" \n",
" | 21554 | \n",
" Private | \n",
" Female | \n",
" 29 | \n",
" 13 | \n",
" 0 | \n",
" 0 | \n",
" 45 | \n",
" >50K | \n",
" 1 | \n",
" 0.360671 | \n",
" 0.176411 | \n",
"
\n",
" \n",
" | 20959 | \n",
" Private | \n",
" Female | \n",
" 43 | \n",
" 9 | \n",
" 0 | \n",
" 0 | \n",
" 44 | \n",
" <=50K | \n",
" 0 | \n",
" 0.131304 | \n",
" 0.052917 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" workclass sex age education_num capital_gain \\\n",
"1704 ? Female 22 10 0 \n",
"1376 Self-emp-not-inc Male 44 9 0 \n",
"14634 Private Female 39 10 0 \n",
"21554 Private Female 29 13 0 \n",
"20959 Private Female 43 9 0 \n",
"\n",
" capital_loss hours_per_week label label_code b_prob m_prob \n",
"1704 0 35 <=50K 0 0.127790 0.034448 \n",
"1376 0 55 <=50K 0 0.188333 0.196298 \n",
"14634 0 40 <=50K 0 0.151123 0.064593 \n",
"21554 0 45 >50K 1 0.360671 0.176411 \n",
"20959 0 44 <=50K 0 0.131304 0.052917 "
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Predicted probabilities that label_code == 1, from both models, on the test set.\n",
"# assign() returns a new frame, so this never writes into a frame pandas considers a copy\n",
"# (no SettingWithCopyWarning), and re-running the cell simply overwrites the two columns.\n",
"test_set = test_set.assign(b_prob=b_model.predict(test_set),\n",
"                           m_prob=m_model.predict(test_set))\n",
"test_set.head()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# ROC curves on the test set: model without categorical features (b_model)\n",
"# versus the model that uses them (m_model)\n",
"from sklearn import metrics\n",
"\n",
"# ROC curve and AUC of b_model\n",
"b_fpr, b_tpr, _ = metrics.roc_curve(test_set[\"label_code\"], test_set[\"b_prob\"])\n",
"b_auc = metrics.auc(b_fpr, b_tpr)\n",
"# ROC curve and AUC of m_model\n",
"m_fpr, m_tpr, _ = metrics.roc_curve(test_set[\"label_code\"], test_set[\"m_prob\"])\n",
"m_auc = metrics.auc(m_fpr, m_tpr)\n",
"# One figure with a single axes; everything below draws on `ax` explicitly\n",
"fig, ax = plt.subplots(figsize=(6, 6), dpi=80)\n",
"ax.plot(b_fpr, b_tpr, \"k\",\n",
"        label=\"%s; %s = %0.4f\" % (\"未使用定性变量的ROC曲线\", \"曲线下面积(AUC)\", b_auc))\n",
"ax.plot(m_fpr, m_tpr, \"b-.\",\n",
"        label=\"%s; %s = %0.4f\" % (\"使用定性变量的ROC曲线\", \"曲线下面积(AUC)\", m_auc))\n",
"# Diagonal = ROC curve of random guessing, as a visual reference (no label, so not in the legend)\n",
"ax.plot([0, 1], [0, 1], color=\"0.7\", linestyle=\":\")\n",
"ax.set_xlim([0, 1])\n",
"ax.set_ylim([0, 1])\n",
"ax.set_xlabel(\"False positive rate\")\n",
"ax.set_ylabel(\"True positive rate\")\n",
"# NOTE(review): the legend text is Chinese and only renders if matplotlib is configured\n",
"# with a CJK-capable font (e.g. via plt.rcParams[\"font.sans-serif\"]); confirm locally.\n",
"legend = ax.legend(shadow=True)\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}