1 Basic Algorithm Flow
Zhou Zhihua's *Machine Learning* describes the algorithm very clearly and in detail (the relevant page appeared as a screenshot in the original post).
The recursion returns in three cases:

- All samples in the current dataset belong to the same class, so no split is needed.
- The attribute set is empty, or all samples take identical values on every remaining attribute, so no split is possible.
- The sample set at the current node is empty, so no split can be made.
2 Component Functions
2.1 Computing Shannon Entropy
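For reference, the function below computes the standard Shannon entropy of a dataset $D$, where $p_k$ is the proportion of samples belonging to class $k$:

$$\operatorname{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k$$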
```python
from math import log

# Compute the Shannon entropy of the given dataset
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}  # dictionary mapping class label -> count
    for featVec in dataSet:
        currentLabel = featVec[-1]  # the class label is in the last column
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
```
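A quick sanity check. The `myDat` toy set below follows the format of the example dataset in *Machine Learning in Action*: two binary features plus a class label in the last column.

```python
# Toy dataset: two features plus a class label in the last column
myDat = [[1, 1, 'yes'],
         [1, 1, 'yes'],
         [1, 0, 'no'],
         [0, 1, 'no'],
         [0, 0, 'no']]

print(calcShannonEnt(myDat))  # about 0.971 for this 2-'yes' / 3-'no' split
```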
2.2 Splitting the Dataset
```python
# Split the dataset: keep the rows where feature `axis` equals `value`,
# with that feature column removed
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet  # all rows matching this feature value, minus that column
```
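Continuing with the toy dataset above, splitting on feature 0 keeps the matching rows and drops that column:

```python
print(splitDataSet(myDat, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitDataSet(myDat, 0, 0))  # [[1, 'no'], [0, 'no']]
```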
2.3 Choosing the Best Split
Information gain is the reduction in entropy obtained by splitting dataset $D$ on attribute $a$, where $D^v$ is the subset of samples taking value $v$ on $a$:

$$\operatorname{Gain}(D, a) = \operatorname{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\operatorname{Ent}(D^v)$$

ID3 picks the attribute with the largest gain.
```python
# Choose the best feature to split on (largest information gain, ID3)
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestFeature = -1
    bestInfoGain = 0.0
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # information gain of feature i
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
```
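On the toy dataset the two features happen to tie on information gain (about 0.42 bits each), and the strict `>` test keeps the first winner, so index 0 is returned:

```python
print(chooseBestFeatureToSplit(myDat))  # 0 (ties favor the earlier feature)
```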
3 Building the Decision Tree
```python
import operator

# Create a tree recursively
def creatTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]  # class labels (last column)
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # case 1: all samples share the same label
    if len(dataSet[0]) == 1:
        # case 2: all features used up but the set is still mixed;
        # return the class that occurs most often
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # best feature to split on
    bestFeatLabel = labels[bestFeat]              # its human-readable label
    myTree = {bestFeatLabel: {}}  # the tree is a nested dictionary
    del(labels[bestFeat])         # remove the label already stored in the tree
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)  # all values of the best feature
    for value in uniqueVals:      # recurse on each subset
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = creatTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
```
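`creatTree` calls a `majorityCnt` helper that this section never defines (the `operator` import at the top belongs to it). Below is a sketch consistent with the book's version, followed by a run on the toy dataset; the `labels` list is hypothetical feature names matching the book's example.

```python
# Return the class label that occurs most often in classList
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

labels = ['no surfacing', 'flippers']  # feature names for the toy dataset
myTree = creatTree(myDat, labels[:])   # pass a copy: creatTree mutates labels
print(myTree)
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
```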
4 Plotting the Tree
The nodes and the arrows between them are drawn with matplotlib's annotate() (see the matplotlib documentation for its parameters).
4.1 Getting the Number of Leaf Nodes and the Tree Depth
```python
# Get the number of leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    # unlike the Python 2 code in Machine Learning in Action,
    # Python 3 needs list() around dict.keys() before indexing
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # a child that is itself a dict is an internal node
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Get the depth (number of layers) of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # a child that is itself a dict is an internal node
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
```
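For the example tree built in section 3, these helpers report 3 leaves and a depth of 2:

```python
myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(getNumLeafs(myTree))   # 3
print(getTreeDepth(myTree))  # 2
```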
4.2 Inserting Text Between Parent and Child Nodes
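The `plotTree`/`createPlot` code below calls `plotNode` and `plotMidText` and uses the node-style constants `decisionNode` and `leafNode`, none of which are shown in this section. Here is a sketch of those helpers following the book's conventions, built on matplotlib's annotate():

```python
import matplotlib.pyplot as plt

# Node and arrow styles used by the plotting functions
decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')

# Draw a single node with an arrow pointing from parentPt to centerPt
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)

# Write text midway between a child and its parent (the edge label)
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
```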
```python
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xoff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yoff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # recursively draw the subtree
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # advance the x offset; leaves sit 1/plotTree.totalW apart
            plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
            # draw the leaf node
            plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff),
                     cntrPt, leafNode)
            # label the arrow with the feature value
            plotMidText((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
    plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD

# Draw the decision tree; inTree looks like
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # new figure, white background
    fig.clf()                               # clear the figure
    axprops = dict(xticks=[], yticks=[])
    # create a 1x1 grid and attach its Axes to createPlot.ax1; storing it as a
    # function attribute makes it act like a global that plotNode can use
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))   # total number of leaves
    plotTree.totalD = float(getTreeDepth(inTree))  # total depth
    # start half a slot to the left so the leaves, each taking a slot of
    # width 1/plotTree.totalW on the unit-length x axis, end up centered
    plotTree.xoff = -0.5 / plotTree.totalW
    plotTree.yoff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
```
4.3 Result

Running createPlot(myTree) on the example tree produces the tree diagram (figure omitted).
5 Classification
5.1 Classification Function
```python
# Decision tree classification function
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    # index() returns the position of the feature label,
    # raising ValueError if it is not found
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel
```
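A quick check with the toy tree; note that featLabels must be the original, unmutated feature-name list (creatTree deletes entries from the list it is given):

```python
labels = ['no surfacing', 'flippers']
print(classify(myTree, labels, [1, 0]))  # 'no'
print(classify(myTree, labels, [1, 1]))  # 'yes'
```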
5.2 Storing the Decision Tree
Serializing an object lets us save it to disk, so the tree does not have to be rebuilt for every classification; here Python's pickle module does the serialization.
```python
import pickle

# Store the decision tree on disk
def storeTree(inputTree, filename):
    fw = open(filename, 'wb')  # pickle requires binary mode in Python 3
    pickle.dump(inputTree, fw)
    fw.close()

# Load a stored decision tree
def grabTree(filename):
    fr = open(filename, 'rb')
    return pickle.load(fr)
```
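A round-trip check; the filename is arbitrary and is written to the working directory:

```python
storeTree(myTree, 'classifierStorage.txt')
print(grabTree('classifierStorage.txt'))
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
```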