Machine Learning in Action — Decision Trees (Complete Code)

Note: these are study notes for the decision-tree example in *Machine Learning in Action* by Peter Harrington, combined with theory from the "watermelon book" (Zhou Zhihua, *Machine Learning*). The code uses Python 3, so it differs from the book in a few places.

Coding: Jungle
Notation: D is the sample set; p_k is the proportion of samples belonging to class k; splitting D on attribute a produces V branch nodes; Ent(D) is the information entropy and Gain(D, a) the information gain (formulas below).

The example data: five animals, two features, and a class label.

#  Can survive without surfacing  Has flippers  Is a fish
1  yes                            yes           yes
2  yes                            yes           yes
3  yes                            no            no
4  no                             yes           no
5  no                             yes           no
1. Computing the entropy of a given dataset
# trees.py
from math import log

def calShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    # build a dictionary of all possible class labels
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        # accumulate the entropy: first the probability p, then -p*log2(p)
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
2. Building the dataset
The two quantities defined above are computed as:

$$\mathrm{Ent}(D) = -\sum_{k=1}^{|y|} p_k \log_2 p_k$$

$$\mathrm{Gain}(D, a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)$$
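These formulas can be sanity-checked with a direct transcription in plain Python. The helper names `entropy` and `info_gain` are mine, not from the book:

```python
# Direct transcription of Ent(D) and Gain(D, a) using collections.Counter.
from collections import Counter
from math import log2

def entropy(rows):
    """Ent(D) = -sum_k p_k * log2(p_k), p_k = proportion of class k."""
    n = len(rows)
    return -sum(c / n * log2(c / n) for c in Counter(r[-1] for r in rows).values())

def info_gain(rows, axis):
    """Gain(D, a) = Ent(D) - sum_v |D^v|/|D| * Ent(D^v) for feature `axis`."""
    n = len(rows)
    remainder = 0.0
    for v in set(r[axis] for r in rows):
        subset = [r for r in rows if r[axis] == v]
        remainder += len(subset) / n * entropy(subset)
    return entropy(rows) - remainder

data = [[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(entropy(data))       # ≈ 1.3710 (three classes)
print(info_gain(data, 0))  # ≈ 0.4200
```

The numbers agree with what the book's `calShannonEnt` produces on the same data below.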
def creatDataSet():
    dataSet = [[1, 1, 'maybe'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

myData, labels = creatDataSet()
print("Dataset: {}\nLabels: {}".format(myData, labels))
print("Shannon entropy of this dataset: {}".format(calShannonEnt(myData)))

Dataset: [[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
Labels: ['no surfacing', 'flippers']
Shannon entropy of this dataset: 1.3709505944546687
With the same number of samples, reducing the number of distinct class labels lowers the entropy:

myData[0][-1] = 'yes'
print("Data: {}\nShannon entropy of this dataset: {}".format(myData, calShannonEnt(myData)))

Data: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
Shannon entropy of this dataset: 0.9709505944546686
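The drop from ≈1.371 to ≈0.971 reflects a general property: entropy is zero for a pure set and maximal, log2(n), for n equally likely classes. A quick check (the helper `ent` is mine):

```python
# Boundary cases of Shannon entropy over a list of class labels.
from math import log2

def ent(labels):
    n = len(labels)
    return -sum(labels.count(c) / n * log2(labels.count(c) / n) for c in set(labels))

assert ent(['no'] * 5) == 0              # pure set: no uncertainty at all
assert ent(['a', 'b', 'c', 'd']) == 2.0  # uniform over 4 classes: log2(4)
```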
3. Splitting the dataset
# split the dataset on a given feature and feature value
def splitDataSet(dataSet, axis, value):
    '''dataSet : the dataset to split
    axis    : index of the feature to split on
    value   : the feature value to match'''
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # keep every column except the one we split on
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
print("Dataset before splitting: {}\n\nSplitting on 'can survive without surfacing', the sub-datasets for the next level are:\n{}--------{}".format(myData, splitDataSet(myData, 0, 0), splitDataSet(myData, 0, 1)))

Dataset before splitting: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

Splitting on 'can survive without surfacing', the sub-datasets for the next level are:
[[1, 'no'], [1, 'no']]--------[[1, 'yes'], [1, 'yes'], [0, 'no']]
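The same split can be written as a one-line list comprehension; `splitDataSetLC` is my name for this variant, not from the book:

```python
# List-comprehension equivalent of the book's splitDataSet.
def splitDataSetLC(dataSet, axis, value):
    # keep rows matching `value` at `axis`, with that column removed
    return [fv[:axis] + fv[axis + 1:] for fv in dataSet if fv[axis] == value]

myData = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(splitDataSetLC(myData, 0, 1))   # [[1, 'yes'], [1, 'yes'], [0, 'no']]
```

A side benefit: concatenation with `+` builds fresh lists, so the original rows are never mutated.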
# choose the best way to split, i.e. pick the split feature by information gain
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calShannonEnt(dataSet)
    bestInfoGain, bestFeature = 0, -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        # entropy of this split, weighted by subset size
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

chooseBestFeatureToSplit(myData)
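Tracing the information gain of each feature shows why index 0 is returned. The helper `ent` and this trace are mine, but the numbers match what `chooseBestFeatureToSplit` computes internally:

```python
# Per-feature information-gain trace for the two-class toy dataset.
from math import log2

def ent(rows):
    n = len(rows)
    counts = {}
    for r in rows:
        counts[r[-1]] = counts.get(r[-1], 0) + 1
    return -sum(c / n * log2(c / n) for c in counts.values())

myData = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
base = ent(myData)
gains = []
for i in range(2):
    remainder = 0.0
    for v in set(r[i] for r in myData):
        subset = [r for r in myData if r[i] == v]
        remainder += len(subset) / len(myData) * ent(subset)
    gains.append(base - remainder)
print([round(g, 4) for g in gains])   # feature 0 has the larger gain
```

Feature 0 ('no surfacing') gains ≈ 0.420 bits versus ≈ 0.171 for feature 1, so it becomes the root of the tree.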
4. Recursively building the decision tree
# return the class label that occurs most often
import operator

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(
        classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
# build the tree recursively
def creatTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # stop splitting if all samples belong to the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # no features left: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        sublabels = labels[:]
        myTree[bestFeatLabel][value] = creatTree(
            splitDataSet(dataSet, bestFeat, value), sublabels)
    return myTree
creatTree(myData,labels)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
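With the tree built, classification is a walk from the root to a leaf. The book provides an equivalent `classify` function later on; this is a minimal version for the dict structure above (variable names are mine):

```python
# Classify a feature vector by descending the nested-dict tree.
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]      # root feature label
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)    # map label back to its column
    branch = secondDict[testVec[featIndex]]
    if isinstance(branch, dict):              # internal node: recurse
        return classify(branch, featLabels, testVec)
    return branch                             # leaf: class label

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(classify(tree, ['no surfacing', 'flippers'], [1, 1]))   # 'yes'
print(classify(tree, ['no surfacing', 'flippers'], [1, 0]))   # 'no'
```

Note that `creatTree` deletes entries from `labels` as it recurses, so pass it a copy if you still need the full label list for classification afterwards.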
# treePlotter.py
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)

def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode(U'Decision Node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode(U'Leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

createPlot()
[Figure: output of createPlot() — a decision node and a leaf node drawn with annotation arrows (output_19_0.png)]
# count the number of leaf nodes
def getNumLeaves(myTree):
    numLeafs = 0
    # the root feature label is the dict's only key
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeaves(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
# compute the depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
# test the depth and leaf-count functions on canned trees
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

mytree = retrieveTree(1)
getNumLeaves(mytree)
getTreeDepth(mytree)

3
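Because the finished tree is a plain dict, it can be serialised to disk; the book's chapter 3 does this with `pickle` in `storeTree`/`grabTree`. A Python 3 sketch of the same idea, writing to a temporary file:

```python
# Persist and reload a decision tree with pickle.
import os
import pickle
import tempfile

def storeTree(inputTree, filename):
    # binary mode is required for pickle under Python 3
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

def grabTree(filename):
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
path = os.path.join(tempfile.mkdtemp(), 'classifierStorage.txt')
storeTree(tree, path)
print(grabTree(path) == tree)   # True
```

Storing the tree means it only has to be built once, which matters because building it is the expensive step.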
