请选择 进入手机版 | 继续访问电脑版

技术控

    今日:0| 主题:61300
收藏本版 (1)
最新软件应用技术尽在掌握

[其他] 树回归

[复制链接]
﹏戒蔔掉旳愛 发表于 2016-10-5 19:38:03
170 2
树回归
  优点:可以对复杂和非线性的数据建模
  缺点:结果不易理解
  适用数据类型: 数值型和标称型数据。
  CART算法实现

  binSplitDataSet()函数,有三个参数:数据集合,待切分的特征和该特征的某个值。在给定特征和特征值的情况下,该函数通过数组过滤方式将上述数据集合切分得到两个子集并返回。
  1. def binSplitDataSet(dataSet, feature, value):
  2.     mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
  3.     mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
  4.     return mat0, mat1
复制代码
createTree()函数,有4个参数,数据集和其他三个可选参数,这些可选参数决定了树的类型: leafType给出了建立叶节点的函数,errorType代表误差计算函数,ops是一个包含树构建所需其它参数的元祖。
  1. def createTree(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):
  2.     feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
  3.     if feat == None: return val
  4.     retTree = {}
  5.     retTree['spInd'] = feat
  6.     retTree['spVal'] = val
  7.     lSet, rSet = binSplitDataSet(dataSet, feat, val)
  8.     retTree['left'] = createTree(lSet, leafType, errType, ops)
  9.     retTree['right'] = createTree(rSet, leafType, errType, ops)
  10.     return retTree
复制代码
chooseBestSplit()函数,给定某个误差计算方法,该函数会找到数据集上的最佳二元切分方式。该函数需要完成两件事:用最佳方式切分数据集和生成相应的叶节点。
  伪代码如下:
  1. 对每个特征:
  2.     对每个特征值:
  3.         将数据集切分成两份
  4.         计算切分的误差
  5.         如果当前误差小于当前最小误差,将当前切分设定为最佳切分并更新最小误差
  6. 返回最佳切分的特征和阀值
复制代码
coding:
  1. def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
  2.     tolS = ops[0]; tolN = ops[1]
  3.     #if all the target variables are the same value: quit and return value
  4.     if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
  5.         return None, leafType(dataSet)
  6.     m,n = shape(dataSet)
  7.     #the choice of the best feature is driven by Reduction in RSS error from mean
  8.     S = errType(dataSet)
  9.     bestS = inf; bestIndex = 0; bestValue = 0
  10.     for featIndex in range(n-1):
  11.         for splitVal in set(dataSet[:,featIndex]):
  12.             mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
  13.             if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
  14.             newS = errType(mat0) + errType(mat1)
  15.             if newS < bestS:
  16.                 bestIndex = featIndex
  17.                 bestValue = splitVal
  18.                 bestS = newS
  19.     #if the decrease (S-bestS) is less than a threshold don't do the split
  20.     if (S - bestS) < tolS:
  21.         return None, leafType(dataSet) #exit cond 2
  22.     mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
  23.     if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  #exit cond 3
  24.         return None, leafType(dataSet)
  25.     return bestIndex,bestValue#returns the best feature to split on
  26.                               #and the value used for that split
复制代码
函数chooseBestSplit()一开始为ops设定了tolS和tolN这两个值,它们是用户指定的参数,用于控制函数的停止时机。其中tolS是容许的误差下降值,tolN是切分的最少样本数。通过对当前所有目标变量建立一个集合,然后统计不同剩余特征值的数目,如果该数目为1,那么不需要再切分直接返回,然后函数计算了当前数据集的大小和误差,该误差S将用于和新切分误差进行对比,检查新切分能否降低误差。 如果切分数据集后效果提升不够大,那么就不进行切分操作而直接创建叶节点。
  另外,还需要检验两个切分后的子集大小,如果某个子集大小小于用户定义的参数tolN,那么也不进行切分。
  运行结果:
   

树回归

树回归-1-技术控-切分,dataSet,误差,数据,tree

  后剪枝

  伪代码:
  1. 基于已有的树切分测试数据:
  2.     如果存在任一子集是一棵树,则在该子集递归剪枝过程
  3.     计算将当前两个节点合并后的误差
  4.     计算不合并的误差
  5.     如果合并会降低误差的话,就将叶节点合并。
复制代码
coding:
  1. def prune(tree, testData):
  2.     if shape(testData)[0] == 0: return getMean(tree)
  3.     if (isTree(tree['right']) or isTree(tree['left'])):
  4.         lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])
  5.     if isTree(tree['left']): tree['left'] = prune(tree['left'],lSet)
  6.     if isTree(tree['right']): tree['right'] = prune(tree['right'],rSet)
  7.     if not isTree(tree['left']) and not isTree(tree['right']):
  8.         lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
  9.         errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) + \
  10.         sum(power(rSet[:,-1] - tree['right'],2))
  11.         treeMean = (tree['left'] + tree['right']) / 2.0
  12.         errorMerge = sum(power(testData[:,-1] - treeMean,2))
  13.         if errorMerge < errorNoMerge:
  14.             print "merging"
  15.             return treeMean
  16.         else : return tree
  17.     else:return tree
复制代码
函数prune()有两个参数,待剪枝的树与剪枝所需的测试集。首先确认测试集是否为空,如果非空,则反复递归调用prune()对测试数据进行切分。检查某个分支到底是子树还是节点。如果是子树,就调用函数来对子树进行剪枝。再对左右两个分支完成剪枝后,还需要检查它们是否仍然还是子树,如果已经不再是子树,那么就可以进行合并。具体做法是对合并前后的误差进行比较,如果合并后的误差比不合并的小就进行合并操作,反之不合并直接返回。
  运行结果:
   

树回归

树回归-2-技术控-切分,dataSet,误差,数据,tree

  模型树

  用树来对数据进行建模,除了把叶节点设定为常数值外,还可以将其设定为分段线性函数,分段线性(piecewise linear)即模型由多个线性片段组成。
   

树回归

树回归-3-技术控-切分,dataSet,误差,数据,tree

  可以设计两条分别从0.0-0.3、从0.3~1.0的直线,得到两个线性模型,即分段线性模型。
  两条直线比很多节点组成一颗大树更容易理解。模型树的可解释性是它优于回归树的特点之一。模型树也具有更高的预测准确度。利用树生成算法对数据进行切分,且每份切分数据都能很容易被线性模型所表示,关键在于找到最佳切分。
  1. def linearSolve(dataSet):
  2.     m, n = shape(dataSet)
  3.     X = mat(ones((m,n)));
  4.     Y = mat(ones((m,1)))
  5.     X[:,1:n] = dataSet[:,0:n-1];
  6.     Y = dataSet[:,-1]
  7.     XTX = X.T * X
  8.     if linalg.det(XTX) == 0.0:
  9.         raise NameError('This matrix is singular, cannot do inverse, \n\
  10.         try increasing the second value of ops')
  11.     ws = XTX.I * (X.T * Y)
  12.     return ws, X, Y
  13. def modelLeaf(dataSet):
  14.     ws, X,Y = linearSolve(dataSet)
  15.     return ws
  16. def modelErr(dataSet):
  17.     ws, X,Y = linearSolve(dataSet)
  18.     yHat = X * ws
  19.     return sum(power(Y - yHat, 2))
复制代码
运行结果:
   

树回归

树回归-4-技术控-切分,dataSet,误差,数据,tree

  可以看到,该代码以0.285477为界创建了两个模型,而原图中的数据实际在0.3处分段,createTree()生成的这两个线性模型分别为:y = 3.468 + 1.1852x和0.0016985 + 11.96477x,与用于生成该睡的真是模型非常接近。
  该数据实际是由模型y = 3.5 + 1.0x 和 y = 0 + 12x再加上高斯噪声生成的。
   完整代码地址: https://github.com/JLUNeverMore/Tree_regression
奇葩朵朵向阳开 发表于 2016-10-15 14:34:40
突然觉得﹏戒蔔掉旳愛说的很有道理,赞一个!
回复 支持 反对

使用道具 举报

日付20网赚网 发表于 2016-10-22 00:47:36
鸳鸳相抱何时了,鸯在一边看热闹。  
回复 支持 反对

使用道具 举报

我要投稿

回页顶回复上一篇下一篇回列表
手机版/c.CoLaBug.com ( 粤ICP备05003221号 | 文网文[2010]257号 | 粤公网安备 44010402000842号 )

© 2001-2017 Comsenz Inc.

返回顶部 返回列表