《机器学习实战》9.2树回归之树剪枝(tree pruning)
搜索微信公众号:‘AI-ming3526’或者’计算机视觉这件小事’ 获取更多人工智能、机器学习干货
csdn:https://blog.csdn.net/baidu_31657889/
github:https://github.com/aimi-cn/AILearners
本文出现的所有代码,均可在github上下载,不妨来个Star把谢谢~:Github代码地址
一、引言
本篇文章将会根据上节的回归树的构建过程是否得当来引入的剪枝(tree pruning)技术。
二、树剪枝
一棵树如果结点过多,表明该模型可能对数据进行了“过拟合”。
通过降低树的复杂度来避免过拟合的过程称为剪枝(pruning)。上小节我们也已经提到,设置tolS和tolN就是一种预剪枝操作。另一种形式的剪枝需要使用测试集和训练集,称作后剪枝(postpruning)。本节将分析后剪枝的有效性,但首先来看一下预剪枝的不足之处。
2.1 预剪枝
预剪枝有一定的局限性,比如我们现在使用一个新的数据集。
数据集下载地址:数据集下载
用上节的代码绘制数据集看一下:
可以看到,对于这个数据集与我们使用的第一个数据集很相似,但是区别在于y的数量级差100倍,数据分布相似,因此构建出的树应该也是只有两个叶结点。但是我们使用默认tolS和tolN参数创建树,你会发现运行结果如下所示:
可以看到,构建出的树有很多叶结点。产生这个现象的原因在于,停止条件tolS对误差的数量级十分敏感。如果在选项中花费时间并对上述误差容忍度取平均值,或许也能得到仅有两个叶结点组成的树:
可以看到,将参数tolS修改为10000后,构建的树就是只有两个叶结点。然而,显然这个值,需要我们经过不断测试得来,显然通过不断修改停止条件来得到合理结果并不是很好的办法。事实上,我们常常甚至不确定到底需要寻找什么样的结果。因为对于一个很多维度的数据集,你也不知道构建的树需要多少个叶结点。
可见,预剪枝有很大的局限性。接下来,我们讨论后剪枝,即利用测试集来对树进行剪枝。由于不需要用户指定参数,后剪枝是一个更理想化的剪枝方法。
2.2 后剪枝
使用后剪枝方法需要将数据集分成测试集和训练集。首先指定参数,使得构建出的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶结点,用测试集来判断这些叶结点合并是否能降低测试集误差。如果是的话就合并。
后剪枝 prune() 的伪代码如下:
基于已有的树切分测试数据:
如果存在任一子集是一棵树,则在该子集递归剪枝过程
计算将当前两个叶节点合并后的误差
计算不合并的误差
如果合并会降低误差的话,就将叶节点合并
为了演示后剪枝,我们使用ex2.txt文件作为训练集,而使用的新数据集ex2test.txt文件作为测试集。
测试集下载地址:数据集下载
现在我们使用ex2.txt训练回归树,然后利用ex2test.txt对回归树进行剪枝。我们需要创建三个函数isTree()、getMean()、prune()。其中isTree()用于测试输入变量是否是一棵树,返回布尔类型的结果。换句话说,该函数用于判断当前处理的结点是否是叶结点。第二个函数getMean()是一个递归函数,它从上往下遍历树直到叶结点为止。如果找到两个叶结点则计算它们的平均值。该函数对树进行塌陷处理(即返回树平均值)。而第三个函数prune()则为后剪枝函数。创建treePruning.py 编写代码如下:
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : treePruning.py
@Time : 2019/08/05 21:47:48
@Author : xiao ming
@Version : 1.0
@Contact : xiaoming3526@gmail.com
@Desc : 树回归之后剪枝
@github : https://github.com/aimi-cn/AILearners
'''
# here put the import lib
import matplotlib.pyplot as plt
import numpy as np
'''
@description: 加载数据
@param: fileName - 文件名
@return: dataMat - 数据矩阵
'''
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float, curLine)) #转化为float类型
dataMat.append(fltLine)
return dataMat
'''
@description: 根据特征切分数据集合
@param: dataSet - 数据集合
feature - 带切分的特征
value - 该特征的值
@return: mat0 - 切分的数据集合0
mat1 - 切分的数据集合1
'''
def binSplitDataSet(dataSet, feature, value):
mat0 = dataSet[np.nonzero(dataSet[:,feature] > value)[0],:]
mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:]
return mat0, mat1
'''
@description: 生成叶结点
@param: dataSet - 数据集合
@return: 目标变量的均值
'''
def regLeaf(dataSet):
return np.mean(dataSet[:,-1])
'''
@description: 误差估计函数
@param: dataSet - 数据集合
@return: 目标变量的总方差
'''
def regErr(dataSet):
return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]
'''
@description: 找到数据的最佳二元切分方式函数
@param: dataSet - 数据集合
leafType - 生成叶结点
regErr - 误差估计函数
ops - 用户定义的参数构成的元组
@return: bestIndex - 最佳切分特征
bestValue - 最佳特征值
'''
def chooseBestSplit(dataSet, leafType = regLeaf, errType = regErr, ops = (1,4)):
import types
#tolS允许的误差下降值,tolN切分的最少样本数
tolS = ops[0]; tolN = ops[1]
#如果当前所有值相等,则退出。(根据set的特性)
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
#统计数据集合的行m和列n
m, n = np.shape(dataSet)
#默认最后一个特征为最佳切分特征,计算其误差估计
S = errType(dataSet)
#分别为最佳误差,最佳特征切分的索引值,最佳特征值
bestS = float('inf'); bestIndex = 0; bestValue = 0
#遍历所有特征列
for featIndex in range(n - 1):
#遍历所有特征值
for splitVal in set(dataSet[:,featIndex].T.A.tolist()[0]):
#根据特征和特征值切分数据集
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
#如果数据少于tolN,则退出
if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
#计算误差估计
newS = errType(mat0) + errType(mat1)
#如果误差估计更小,则更新特征索引值和特征值
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
#如果误差减少不大则退出
if (S - bestS) < tolS:
return None, leafType(dataSet)
#根据最佳的切分特征和特征值切分数据集合
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
#如果切分出的数据集很小则退出
if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
return None, leafType(dataSet)
#返回最佳切分特征和特征值
return bestIndex, bestValue
'''
@description: 树构建函数
@param: dataSet - 数据集合
leafType - 建立叶结点的函数
errType - 误差计算函数
ops - 包含树构建所有其他参数的元组
@return: retTree - 构建的回归树
'''
def createTree(dataSet, leafType = regLeaf, errType = regErr, ops = (1, 4)):
#选择最佳切分特征和特征值
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
#r如果没有特征,则返回特征值
if feat == None: return val
#回归树
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
#分成左数据集和右数据集
lSet, rSet = binSplitDataSet(dataSet, feat, val)
#创建左子树和右子树
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
'''
@description: 判断测试输入变量是否是一棵树
@param: obj - 测试对象
@return: 是否是一棵树
'''
def isTree(obj):
import types
return (type(obj).__name__ == 'dict')
'''
@description: 对树进行塌陷处理(即返回树平均值)
@param: tree - 树
@return: 树的平均值
'''
def getMean(tree):
if isTree(tree['right']): tree['right'] = getMean(tree['right'])
if isTree(tree['left']): tree['left'] = getMean(tree['left'])
return (tree['left'] + tree['right']) / 2.0
'''
@description: 后剪枝
@param: tree - 树
test - 测试集
@return: 树的平均值
'''
def prune(tree, testData):
#如果测试集为空,则对树进行塌陷处理
if np.shape(testData)[0] == 0: return getMean(tree)
#如果有左子树或者右子树,则切分数据集
if (isTree(tree['right']) or isTree(tree['left'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
#处理左子树(剪枝)
if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
#处理右子树(剪枝)
if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
#如果当前结点的左右结点为叶结点
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
#计算没有合并的误差
errorNoMerge = np.sum(np.power(lSet[:,-1] - tree['left'],2)) + np.sum(np.power(rSet[:,-1] - tree['right'],2))
#计算合并的均值
treeMean = (tree['left'] + tree['right']) / 2.0
#计算合并的误差
errorMerge = np.sum(np.power(testData[:,-1] - treeMean, 2))
#如果合并的误差小于没有合并的误差,则合并
if errorMerge < errorNoMerge:
return treeMean
else: return tree
else: return tree
if __name__ == '__main__':
print('\n剪枝前:')
train_filename = 'C:\\Users\\Administrator\\Desktop\\blog\\github\\AILearners\\data\\ml\\jqxxsz\\9.RegTrees\\ex2.txt'
train_Data = loadDataSet(train_filename)
train_Mat = np.mat(train_Data)
tree = createTree(train_Mat)
print(tree)
print('\n剪枝后:')
test_filename = 'C:\\Users\\Administrator\\Desktop\\blog\\github\\AILearners\\data\\ml\\jqxxsz\\9.RegTrees\\ex2test.txt'
test_Data = loadDataSet(test_filename)
test_Mat = np.mat(test_Data)
print(prune(tree, test_Mat))
运行结果如下如所示:
可以看到,树的大量结点已经被剪枝掉了,但没有像预期的那样剪枝成两部分,这说明后剪枝可能不如预剪枝有效。一般地,为了寻求最佳模型可以同时使用两种剪枝技术。
下节我们会讲一下模型数和树回归的一个项目案例-树回归与标准回归的比较。
AIMI-CN AI学习交流群【1015286623】 获取更多AI资料
扫码加群:
分享技术,乐享生活:我们的公众号计算机视觉这件小事每周推送“AI”系列资讯类文章,欢迎您的关注!
来源:https://blog.csdn.net/baidu_31657889/article/details/99684548