目录

决策树 3：泰坦尼克号实例（网格搜索、超参数调优）

目录

前言

数据集来自 Kaggle 的 Titanic 竞赛：

链接：https://www.kaggle.com/c/titanic/data

页面中的 train.csv 和 test.csv 即为所需的数据集（下文代码只读取了 train.csv）。

代码

1
2
3
4
5
6
7
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier  #分类器 只能分类数字
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
1
2
# Load the Kaggle Titanic training data and preview its first rows.
data = pd.read_csv("train.csv")
data.head(n=5)

PassengerId Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
0 1 0 3 Braund, Mr. Owen Harris male 22.0 1 0 A/5 21171 7.2500 NaN S
1 2 1 1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 0 PC 17599 71.2833 C85 C
2 3 1 3 Heikkinen, Miss. Laina female 26.0 0 0 STON/O2. 3101282 7.9250 NaN S
3 4 1 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 0 113803 53.1000 C123 S
4 5 0 3 Allen, Mr. William Henry male 35.0 0 0 373450 8.0500 NaN S
1
# Column dtypes and non-null counts of the raw data (printed to stdout).
data.info(buf=None)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
1
2
3
4
5
6
# Feature selection: drop the free-text Name/Ticket columns and the mostly empty Cabin column.
data = data.drop(columns=["Name", "Ticket", "Cabin"])
# Impute missing ages with the mean age.
data["Age"] = data["Age"].fillna(data["Age"].mean())
# Embarked is now the only column with gaps (2 rows), so those rows are dropped.
data = data.dropna()
1
# Re-check dtypes and non-null counts after cleaning (printed to stdout).
data.info(buf=None)
<class 'pandas.core.frame.DataFrame'>
Int64Index: 889 entries, 0 to 890
Data columns (total 9 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  889 non-null    int64  
 1   Survived     889 non-null    int64  
 2   Pclass       889 non-null    int64  
 3   Sex          889 non-null    object 
 4   Age          889 non-null    float64
 5   SibSp        889 non-null    int64  
 6   Parch        889 non-null    int64  
 7   Fare         889 non-null    float64
 8   Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(2)
memory usage: 69.5+ KB
1
# Distinct Embarked values in order of first appearance; a value's position
# in this list is used as its integer code in the next cell.
labels = list(data["Embarked"].unique())
1
2
3
# Encode the multi-class Embarked column as integer codes: each port maps to
# its position in `labels`. A dict lookup replaces a linear labels.index()
# search per row. (The encoding for Sex is handled separately as a 0/1 flag.)
port_codes = {port: code for code, port in enumerate(labels)}
data["Embarked"] = data["Embarked"].map(port_codes)
1
2
3
# Encode the binary Sex column as a 0/1 flag (male -> 1, female -> 0).
data["Sex"] = (data["Sex"] == "male").astype("int")
# Show the encoded column. Re-evaluating (data["Sex"] == "male") at this point
# would compare the already-encoded 0/1 integers with "male" and print all zeros.
data["Sex"]
0      0
1      0
2      0
3      0
4      0
      ..
886    0
887    0
888    0
889    0
890    0
Name: Sex, Length: 889, dtype: int64
1
# Preview the encoded frame: every remaining column is now numeric.
data

PassengerId Survived Pclass Sex Age SibSp Parch Fare Embarked
0 1 0 3 1 22.000000 1 0 7.2500 0
1 2 1 1 0 38.000000 1 0 71.2833 1
2 3 1 3 0 26.000000 0 0 7.9250 0
3 4 1 1 0 35.000000 1 0 53.1000 0
4 5 0 3 1 35.000000 0 0 8.0500 0
... ... ... ... ... ... ... ... ... ...
886 887 0 2 1 27.000000 0 0 13.0000 0
887 888 1 1 0 19.000000 0 0 30.0000 0
888 889 0 3 0 29.699118 1 2 23.4500 0
889 890 1 1 1 26.000000 0 0 30.0000 1
890 891 0 3 1 32.000000 0 0 7.7500 2

889 rows × 9 columns

1
2
# Feature matrix: every column except the Survived target, in the original order.
# NOTE(review): PassengerId looks like a sequential row identifier but is still
# used as a feature here; consider dropping it as well.
x = data.drop(columns="Survived")
x

PassengerId Pclass Sex Age SibSp Parch Fare Embarked
0 1 3 1 22.000000 1 0 7.2500 0
1 2 1 0 38.000000 1 0 71.2833 1
2 3 3 0 26.000000 0 0 7.9250 0
3 4 1 0 35.000000 1 0 53.1000 0
4 5 3 1 35.000000 0 0 8.0500 0
... ... ... ... ... ... ... ... ...
886 887 2 1 27.000000 0 0 13.0000 0
887 888 1 0 19.000000 0 0 30.0000 0
888 889 3 0 29.699118 1 2 23.4500 0
889 890 1 1 26.000000 0 0 30.0000 1
890 891 3 1 32.000000 0 0 7.7500 2

889 rows × 8 columns

1
2
# Target: the Survived column, kept as a one-column DataFrame.
y = data[["Survived"]]
y

Survived
0 0
1 1
2 1
3 1
4 0
... ...
886 0
887 1
888 0
889 1
890 0

889 rows × 1 columns

1
# Hold out 30% of the rows as a test set. A fixed random_state (the same seed
# the trees use) makes the split, and every score computed from it, reproducible.
Xtrain, Xtest, ytrain, ytest = train_test_split(x, y, test_size=0.3, random_state=25)
1
2
3
4
# Replace the shuffled index left by train_test_split with a fresh 0..n-1 index.
# With drop=True the old index is discarded outright; only without drop=True
# would reset_index insert it back as an extra "index" column.
Xtrain.reset_index(drop=True, inplace=True)
1
2
# Give all four split pieces a fresh 0..n-1 index.
for part in (Xtrain, Xtest, ytrain, ytest):
    part.index = range(len(part))
1
2
3
4
# Baseline: a default tree fitted on the training split and scored on the hold-out split.
clf = DecisionTreeClassifier(random_state=25).fit(Xtrain, ytrain)
score = clf.score(Xtest, ytest)
score
0.7602996254681648
1
2
3
4
# 10-fold cross-validation on the whole dataset: cross_val_score does the
# train/test splitting itself, so no manual split is needed here.
clf = DecisionTreeClassifier(random_state=25)
fold_scores = cross_val_score(clf, x, y, cv=10)
score = fold_scores.mean()
score
0.7469611848825333
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# Sweep max_depth over 1..10 with the entropy criterion. For each depth, record
# the accuracy on the training split and the mean 10-fold cross-validation score.
tr = []  # score on Xtrain of the tree fitted on Xtrain
te = []  # mean 10-fold cross-validation score on the full x, y
for depth in range(1, 11):
    clf = DecisionTreeClassifier(
        random_state=25,
        max_depth=depth,
        criterion="entropy",  # information-entropy split criterion instead of Gini
    )
    clf = clf.fit(Xtrain, ytrain)
    tr.append(clf.score(Xtrain, ytrain))
    # cross_val_score fits on its own folds of the whole dataset.
    te.append(cross_val_score(clf, x, y, cv=10).mean())
print(max(te))
0.8166624106230849
1
2
3
4
5
# Plot both score lists against max_depth 1..10.
# The blue "test" curve is the cross-validation score collected in te.
depths = range(1, 11)
plt.plot(depths, tr, color="red", label="train")
plt.plot(depths, te, color="blue", label="test")
plt.xticks(depths)
plt.legend()
plt.show()

https://gitee.com/spiritlhl/picture/raw/master/output_18_0.png

网格搜索：同时调节多个参数

1
# Materialize a stepped range as a list: 0, 5, ..., 45.
list(range(0, 50, 5))
[0, 5, 10, 15, 20, 25, 30, 35, 40, 45]
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
# Candidate values for min_impurity_decrease: 20 evenly spaced points in [0, 0.5].
# 0.5 is the largest Gini impurity for a two-class target; entropy can reach 1,
# so a wider range would be needed to cover the entropy criterion fully.
gini_thresholds = np.linspace(0, 0.5, 20)

# The parameter grid: for each DecisionTreeClassifier parameter, the values
# GridSearchCV should try. Every combination is evaluated.
parameters = {
    "criterion": ("gini", "entropy"),        # Gini impurity vs. information entropy
    "splitter": ("best", "random"),
    "max_depth": [*range(1, 10)],            # depths 1..9
    "min_samples_leaf": [*range(1, 50, 5)],  # 1, 6, 11, ..., 46
    # Reuse the threshold array above instead of recomputing it inline.
    "min_impurity_decrease": [*gini_thresholds],
}

clf = DecisionTreeClassifier(random_state=25)  # fixed seed for reproducible trees
# 2 * 2 * 9 * 10 * 20 = 7,200 combinations x 10 folds = 72,000 fits, so this is slow.
GS = GridSearchCV(clf, parameters, cv=10)
GS = GS.fit(Xtrain, ytrain)
1
GS.best_params_  # best-scoring combination from the parameter values supplied in `parameters`
{'criterion': 'entropy',
 'max_depth': 8,
 'min_impurity_decrease': 0.0,
 'min_samples_leaf': 1,
 'splitter': 'random'}
1
GS.best_score_  # mean 10-fold CV score of the best combination, computed on Xtrain only (Xtest not used)
0.8200204813108039