目录

决策树2-简单调参

目录

代码

1
2
3
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
1
wine = load_wine()
1
wine.data.shape
(178, 13)
1
wine.target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2])
1
2
import pandas as pd
pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis=1)

0 1 2 3 4 5 6 7 8 9 10 11 12 0
0 14.23 1.71 2.43 15.6 127.0 2.80 3.06 0.28 2.29 5.64 1.04 3.92 1065.0 0
1 13.20 1.78 2.14 11.2 100.0 2.65 2.76 0.26 1.28 4.38 1.05 3.40 1050.0 0
2 13.16 2.36 2.67 18.6 101.0 2.80 3.24 0.30 2.81 5.68 1.03 3.17 1185.0 0
3 14.37 1.95 2.50 16.8 113.0 3.85 3.49 0.24 2.18 7.80 0.86 3.45 1480.0 0
4 13.24 2.59 2.87 21.0 118.0 2.80 2.69 0.39 1.82 4.32 1.04 2.93 735.0 0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
173 13.71 5.65 2.45 20.5 95.0 1.68 0.61 0.52 1.06 7.70 0.64 1.74 740.0 2
174 13.40 3.91 2.48 23.0 102.0 1.80 0.75 0.43 1.41 7.30 0.70 1.56 750.0 2
175 13.27 4.28 2.26 20.0 120.0 1.59 0.69 0.43 1.35 10.20 0.59 1.56 835.0 2
176 13.17 2.59 2.37 20.0 120.0 1.65 0.68 0.53 1.46 9.30 0.60 1.62 840.0 2
177 14.13 4.10 2.74 24.5 96.0 2.05 0.76 0.56 1.35 9.20 0.61 1.60 560.0 2

178 rows × 14 columns

1
wine.feature_names
['alcohol',
 'malic_acid',
 'ash',
 'alcalinity_of_ash',
 'magnesium',
 'total_phenols',
 'flavanoids',
 'nonflavanoid_phenols',
 'proanthocyanins',
 'color_intensity',
 'hue',
 'od280/od315_of_diluted_wines',
 'proline']
1
wine.target_names
array(['class_0', 'class_1', 'class_2'], dtype='<U7')
1
2
#XXYY
Xtrain,Xtest,ytrain,ytest=train_test_split(wine.data,wine.target,test_size=0.3)
1
Xtrain.shape
(124, 13)
1
Xtest.shape
(54, 13)
1
ytrain
array([2, 0, 0, 1, 1, 2, 0, 2, 0, 0, 1, 0, 1, 2, 1, 2, 1, 0, 1, 1, 0, 1,
       1, 1, 0, 0, 0, 2, 0, 2, 1, 0, 2, 1, 1, 1, 1, 2, 1, 0, 2, 0, 0, 2,
       0, 2, 1, 2, 1, 1, 1, 0, 0, 0, 0, 1, 2, 0, 1, 1, 0, 1, 1, 0, 0, 0,
       1, 1, 2, 2, 0, 0, 1, 2, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 2, 1, 1, 0,
       2, 2, 2, 2, 1, 0, 1, 2, 0, 0, 1, 0, 0, 2, 1, 0, 0, 2, 2, 1, 2, 0,
       2, 1, 2, 1, 1, 1, 0, 1, 1, 1, 1, 2, 0, 0])
1
2
3
4
5
6
7
#通过特征计算不纯度进行分类
clf = tree.DecisionTreeClassifier(criterion="entropy",random_state=30,splitter="random")#不纯度
#random_state=30 设置随机数种子复现结果,设置随机参数
#splitter 设置为best 或random 一个设置更重要分支进行分枝,一个更随机防止过拟合
clf = clf.fit(Xtrain,ytrain)
score = clf.score(Xtest,ytest)#返回预测的准确度accuracy
score
0.8888888888888888
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
feature_names=wine.feature_names
#画树
import graphviz
dot_data = tree.export_graphviz(clf
                                ,feature_names=feature_names
                                ,class_names=["白酒","红酒","伏特加"]
                                ,filled=True #填充颜色
                                ,rounded=True#直角框变成圆角框
                                ,out_file=None
                                )
graph = graphviz.Source(dot_data)
graph

https://gitee.com/spiritlhl/picture/raw/master/output_13_0.svg

1
clf.feature_importances_
array([0.26346976, 0.        , 0.        , 0.        , 0.        ,
       0.02075703, 0.36896763, 0.        , 0.        , 0.03519618,
       0.048165  , 0.20016672, 0.0632777 ])
1
2
#特征重要性-元组列表
[*zip(feature_names,clf.feature_importances_)]
[('alcohol', 0.2634697570171739),
 ('malic_acid', 0.0),
 ('ash', 0.0),
 ('alcalinity_of_ash', 0.0),
 ('magnesium', 0.0),
 ('total_phenols', 0.020757027052013897),
 ('flavanoids', 0.36896762677034245),
 ('nonflavanoid_phenols', 0.0),
 ('proanthocyanins', 0.0),
 ('color_intensity', 0.035196177552898056),
 ('hue', 0.04816499533787735),
 ('od280/od315_of_diluted_wines', 0.20016671726273566),
 ('proline', 0.0632776990069589)]
1
2
scroe_train = clf.score(Xtrain,ytrain)
scroe_train
1.0

剪枝策略

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#限制深度 max_depth 3
#min_samples_left 分支满足条件才能向下生长 5~1 或者用百分比0.几
#通过特征计算不纯度进行分类
clf = tree.DecisionTreeClassifier(criterion="entropy"#不纯度
                                  ,random_state=30
                                  ,splitter="random"
                                  #,max_depth=3
                                  ,min_samples_leaf=10
                                   #,min_impurity_split=1
                                 )
#random_state=30 设置随机数种子复现结果,设置随机参数
#splitter 设置为best 或random 一个设置更重要分支进行分枝,一个更随机防止过拟合
clf = clf.fit(Xtrain,ytrain)
score = clf.score(Xtest,ytest)#返回预测的准确度accuracy
score
feature_names=wine.feature_names
#画树
import graphviz
dot_data = tree.export_graphviz(clf
                                ,feature_names=feature_names
                                ,class_names=["白酒","红酒","伏特加"]
                                ,filled=True #填充颜色
                                ,rounded=True#直角框变成圆角框
                                ,out_file=None
                                )
graph = graphviz.Source(dot_data)
graph

https://gitee.com/spiritlhl/picture/raw/master/output_18_0.svg

1
2
score = clf.score(Xtest,ytest)#返回预测的准确度accuracy
score
0.8518518518518519
1
2
#max_features 设置特征数量用几个
#min_impurity_decrease 限制信息增益 父子节点信息增益差
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
#画超参数曲线看打分确定参数
import matplotlib.pyplot as plt
test = []
for i in range(10):
    clf = tree.DecisionTreeClassifier(max_depth=i+1
                                      ,criterion="entropy"
                                      ,random_state=30
                                      ,splitter="random"
                                     )
    clf = clf.fit(Xtrain,ytrain)
    score = clf.score(Xtest,ytest)#返回预测的准确度accuracy
    test.append(score)
plt.plot(range(1,11),test,color="red",label="max_depth")
plt.legend()
plt.show()

https://gitee.com/spiritlhl/picture/raw/master/output_21_0.png

1
2
3
4
# 目标权重参数
#class_weight
#class_weight_fraction_leaf
#默认偏向平衡,偏向主导类,可设置为偏向少数类
1
2
#apply返回每个测试样本所在的叶子节点的索引
clf.apply(Xtest)
array([ 9,  4, 32, 16,  4,  4, 32, 22, 28, 32, 27, 16, 25, 16, 32, 16, 10,
       16,  4,  4, 32,  4, 16,  4, 25, 16, 16, 16, 32, 32,  6, 22, 10,  4,
        4,  4, 32, 29, 25, 13, 16,  4,  9, 25, 32,  4,  4, 22,  4, 22, 32,
       32, 22,  4])
1
2
#predict返回每个测试样本的分类/回归结果
clf.predict(Xtest)
array([1, 2, 0, 1, 2, 2, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 2, 1, 2, 2, 0, 2,
       1, 2, 1, 1, 1, 1, 0, 0, 1, 1, 2, 2, 2, 2, 0, 1, 1, 1, 1, 2, 1, 1,
       0, 2, 2, 1, 2, 1, 0, 0, 1, 2])
1
#特征维度起码是2维,如果是一维reshape(-1,1)来给数据增加维度

交叉验证

1
2
3
from sklearn.datasets import load_boston
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor
1
2
boston = load_boston()
boston.data
array([[6.3200e-03, 1.8000e+01, 2.3100e+00, ..., 1.5300e+01, 3.9690e+02,
        4.9800e+00],
       [2.7310e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9690e+02,
        9.1400e+00],
       [2.7290e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9283e+02,
        4.0300e+00],
       ...,
       [6.0760e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
        5.6400e+00],
       [1.0959e-01, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9345e+02,
        6.4800e+00],
       [4.7410e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
        7.8800e+00]])
1
boston.target
array([24. , 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15. ,
       18.9, 21.7, 20.4, 18.2, 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6,
       15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21. , 12.7, 14.5, 13.2,
       13.1, 13.5, 18.9, 20. , 21. , 24.7, 30.8, 34.9, 26.6, 25.3, 24.7,
       21.2, 19.3, 20. , 16.6, 14.4, 19.4, 19.7, 20.5, 25. , 23.4, 18.9,
       35.4, 24.7, 31.6, 23.3, 19.6, 18.7, 16. , 22.2, 25. , 33. , 23.5,
       19.4, 22. , 17.4, 20.9, 24.2, 21.7, 22.8, 23.4, 24.1, 21.4, 20. ,
       20.8, 21.2, 20.3, 28. , 23.9, 24.8, 22.9, 23.9, 26.6, 22.5, 22.2,
       23.6, 28.7, 22.6, 22. , 22.9, 25. , 20.6, 28.4, 21.4, 38.7, 43.8,
       33.2, 27.5, 26.5, 18.6, 19.3, 20.1, 19.5, 19.5, 20.4, 19.8, 19.4,
       21.7, 22.8, 18.8, 18.7, 18.5, 18.3, 21.2, 19.2, 20.4, 19.3, 22. ,
       20.3, 20.5, 17.3, 18.8, 21.4, 15.7, 16.2, 18. , 14.3, 19.2, 19.6,
       23. , 18.4, 15.6, 18.1, 17.4, 17.1, 13.3, 17.8, 14. , 14.4, 13.4,
       15.6, 11.8, 13.8, 15.6, 14.6, 17.8, 15.4, 21.5, 19.6, 15.3, 19.4,
       17. , 15.6, 13.1, 41.3, 24.3, 23.3, 27. , 50. , 50. , 50. , 22.7,
       25. , 50. , 23.8, 23.8, 22.3, 17.4, 19.1, 23.1, 23.6, 22.6, 29.4,
       23.2, 24.6, 29.9, 37.2, 39.8, 36.2, 37.9, 32.5, 26.4, 29.6, 50. ,
       32. , 29.8, 34.9, 37. , 30.5, 36.4, 31.1, 29.1, 50. , 33.3, 30.3,
       34.6, 34.9, 32.9, 24.1, 42.3, 48.5, 50. , 22.6, 24.4, 22.5, 24.4,
       20. , 21.7, 19.3, 22.4, 28.1, 23.7, 25. , 23.3, 28.7, 21.5, 23. ,
       26.7, 21.7, 27.5, 30.1, 44.8, 50. , 37.6, 31.6, 46.7, 31.5, 24.3,
       31.7, 41.7, 48.3, 29. , 24. , 25.1, 31.5, 23.7, 23.3, 22. , 20.1,
       22.2, 23.7, 17.6, 18.5, 24.3, 20.5, 24.5, 26.2, 24.4, 24.8, 29.6,
       42.8, 21.9, 20.9, 44. , 50. , 36. , 30.1, 33.8, 43.1, 48.8, 31. ,
       36.5, 22.8, 30.7, 50. , 43.5, 20.7, 21.1, 25.2, 24.4, 35.2, 32.4,
       32. , 33.2, 33.1, 29.1, 35.1, 45.4, 35.4, 46. , 50. , 32.2, 22. ,
       20.1, 23.2, 22.3, 24.8, 28.5, 37.3, 27.9, 23.9, 21.7, 28.6, 27.1,
       20.3, 22.5, 29. , 24.8, 22. , 26.4, 33.1, 36.1, 28.4, 33.4, 28.2,
       22.8, 20.3, 16.1, 22.1, 19.4, 21.6, 23.8, 16.2, 17.8, 19.8, 23.1,
       21. , 23.8, 23.1, 20.4, 18.5, 25. , 24.6, 23. , 22.2, 19.3, 22.6,
       19.8, 17.1, 19.4, 22.2, 20.7, 21.1, 19.5, 18.5, 20.6, 19. , 18.7,
       32.7, 16.5, 23.9, 31.2, 17.5, 17.2, 23.1, 24.5, 26.6, 22.9, 24.1,
       18.6, 30.1, 18.2, 20.6, 17.8, 21.7, 22.7, 22.6, 25. , 19.9, 20.8,
       16.8, 21.9, 27.5, 21.9, 23.1, 50. , 50. , 50. , 50. , 50. , 13.8,
       13.8, 15. , 13.9, 13.3, 13.1, 10.2, 10.4, 10.9, 11.3, 12.3,  8.8,
        7.2, 10.5,  7.4, 10.2, 11.5, 15.1, 23.2,  9.7, 13.8, 12.7, 13.1,
       12.5,  8.5,  5. ,  6.3,  5.6,  7.2, 12.1,  8.3,  8.5,  5. , 11.9,
       27.9, 17.2, 27.5, 15. , 17.2, 17.9, 16.3,  7. ,  7.2,  7.5, 10.4,
        8.8,  8.4, 16.7, 14.2, 20.8, 13.4, 11.7,  8.3, 10.2, 10.9, 11. ,
        9.5, 14.5, 14.1, 16.1, 14.3, 11.7, 13.4,  9.6,  8.7,  8.4, 12.8,
       10.5, 17.1, 18.4, 15.4, 10.8, 11.8, 14.9, 12.6, 14.1, 13. , 13.4,
       15.2, 16.1, 17.8, 14.9, 14.1, 12.7, 13.5, 14.9, 20. , 16.4, 17.7,
       19.5, 20.2, 21.4, 19.9, 19. , 19.1, 19.1, 20.1, 19.9, 19.6, 23.2,
       29.8, 13.8, 13.3, 16.7, 12. , 14.6, 21.4, 23. , 23.7, 25. , 21.8,
       20.6, 21.2, 19.1, 20.6, 15.2,  7. ,  8.1, 13.6, 20.1, 21.8, 24.5,
       23.1, 19.7, 18.3, 21.2, 17.5, 16.8, 22.4, 20.6, 23.9, 22. , 11.9])
1
2
3
4
5
6
7
regressor = DecisionTreeRegressor(random_state=0)#实例化
cross_val_score(regressor      #模型
                ,boston.data   #数据
                ,boston.target #特征
                ,cv=10         #交叉验证次数
                ,scoring="neg_mean_squared_error" #返回负的均方误差,不写默认返回R方
               )
array([-18.08941176, -10.61843137, -16.31843137, -44.97803922,
       -17.12509804, -49.71509804, -12.9986    , -88.4514    ,
       -55.7914    , -25.0816    ])