Machine Learning, Decision Tree Series: Classification with Continuous Variables
Abstract
This post implements, in Python, the computation of information entropy, conditional entropy, information gain, information gain ratio, the Gini index, and Gini index gain for continuous variables, and builds a decision-tree classification model for continuous variables (the same code also handles discrete variables). Everything is wrapped in a single class so readers can call it directly.
Computing information entropy, conditional entropy, information gain, and information gain ratio
.cal_entropy(): computes information entropy
.cal_conditional_entropy(): computes conditional entropy
.cal_entropy_gain(): computes information gain (i.e., mutual information)
.cal_entropy_gain_ratio(): computes the information gain ratio
Usage: call these as class methods on CyrusDecisionTree directly, or instantiate the class first and then call them, as in the sketch below (the full class is given in the code section that follows).
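For instance, here is a minimal usage sketch with a made-up one-feature dataset; it assumes the CyrusDecisionTree class defined in the code section below. Recall that the entropy is H(Y) = -Σ p_i log2 p_i over the class proportions p_i:

import numpy as np

# hypothetical toy data: one continuous feature and binary labels
x = np.array([1.2, 0.7, 2.5, 1.9, 0.3])
y = np.array([0, 1, 0, 1, 0])

# class methods can be called without instantiating the class
print(CyrusDecisionTree.cal_entropy(y))                  # H(y)
gain, split = CyrusDecisionTree.cal_entropy_gain(x, y)
print(gain, split)                                       # best gain and its threshold
ratio, split = CyrusDecisionTree.cal_entropy_gain_ratio(x, y)
print(ratio, split)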
Computing the Gini index and Gini index gain
.cal_gini(): computes the Gini index
.cal_gini_gain(): computes the Gini index gain
Usage: call these as class methods on CyrusDecisionTree directly, or instantiate the class first and then call them, as in the sketch below.
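And the Gini counterparts, using the same hypothetical toy arrays as above; the Gini index is Gini(Y) = 1 - Σ p_i²:

print(CyrusDecisionTree.cal_gini(y))                # Gini index of y
gain, split = CyrusDecisionTree.cal_gini_gain(x, y)
print(gain, split)                                  # best Gini gain and its threshold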
Python code
criterion is an optional parameter:
pass "C4.5" to build the tree on information gain ratio
pass "gini" to build the tree on Gini index gain
import numpy as np
import pandas as pd

class CyrusDecisionTree(object):
    X = None
    Y = None

    def __init__(self, criterion="C4.5"):
        self.criterion = criterion
        self.tree_net = None

    # 1. Information entropy
    @classmethod
    def cal_entropy(cls, y):
        y = np.array(y).reshape(-1)
        counts = np.array(pd.Series(y).value_counts())
        p = counts / y.shape[0]
        return -(p * np.log2(p)).sum()

    # 2. Conditional entropy
    @classmethod
    def cal_conditional_entropy(cls, x, y):
        """Compute the conditional entropy of y at every candidate split
        point of the continuous feature x."""
        x = np.array(x).reshape(-1)
        order = x.argsort()              # sort x and reorder y by the same
        x = x[order]                     # permutation so the pairs stay aligned
        y = np.array(y).reshape(-1)[order]
        split = []
        entropy = []
        for i in range(x.shape[0] - 1):
            # candidate threshold: midpoint of two adjacent sorted values
            split.append(0.5 * (x[i] + x[i + 1]))
            w = (i + 1) / x.shape[0]
            entropy.append(w * cls.cal_entropy(y[:i + 1])
                           + (1 - w) * cls.cal_entropy(y[i + 1:]))
        return np.array(entropy), np.array(split)

    # 3. Information gain
    @classmethod
    def cal_entropy_gain(cls, x, y):
        """Compute the best information gain of y given x and the split
        point that achieves it."""
        entropy, split = cls.cal_conditional_entropy(x, y)
        entropy_gain = cls.cal_entropy(y) - entropy
        return entropy_gain.max(), split[entropy_gain.argmax()]

    # 4. Information gain ratio
    @classmethod
    def cal_entropy_gain_ratio(cls, x, y):
        """Compute the information gain ratio of y given x. Note: this
        implementation normalizes the gain by the conditional entropy at
        the chosen split (classic C4.5 normalizes by the split information)."""
        entropy_gain, split = cls.cal_entropy_gain(x, y)
        entropy_condition = cls.cal_entropy(y) - entropy_gain
        return entropy_gain / entropy_condition, split

    # 5. Gini index
    @classmethod
    def cal_gini(cls, y):
        y = np.array(y).reshape(-1)
        counts = np.array(pd.Series(y).value_counts())
        return 1 - (((counts / y.shape[0]) ** 2).sum())

    # 6. Gini index gain
    @classmethod
    def cal_gini_gain(cls, x, y):
        """Compute the best Gini index gain of y given x and the split
        point that achieves it."""
        x = np.array(x).reshape(-1)
        order = x.argsort()
        x = x[order]
        y = np.array(y).reshape(-1)[order]
        split = []
        gini = []
        for i in range(x.shape[0] - 1):
            split.append(0.5 * (x[i] + x[i + 1]))
            w = (i + 1) / x.shape[0]
            gini.append(w * cls.cal_gini(y[:i + 1])
                        + (1 - w) * cls.cal_gini(y[i + 1:]))
        gini_gain = cls.cal_gini(y) - np.array(gini)
        split = np.array(split)
        return gini_gain.max(), split[gini_gain.argmax()]

    # Recursive tree construction
    def tree(self, x, y, net):
        if pd.Series(y).value_counts().shape[0] == 1:
            net.append(y[0])                 # leaf: all samples share one label
        else:
            x_entropy = []
            x_split = []
            for i in range(x.shape[1]):      # evaluate every feature
                if self.criterion == "C4.5":
                    entropy, split = self.cal_entropy_gain_ratio(x[:, i], y)
                else:
                    entropy, split = self.cal_gini_gain(x[:, i], y)
                x_entropy.append(entropy)
                x_split.append(split)
            rank = np.array(x_entropy).argmax()   # best feature index
            split = x_split[rank]                 # its split threshold
            # internal node layout: [feature, threshold, branch_gt, branch_le]
            net.append(rank)
            net.append(split)
            net.append([])
            net.append([])
            x_1 = []   # row indices with feature value > threshold
            x_2 = []   # row indices with feature value <= threshold
            for i in range(x.shape[0]):
                if x[i, rank] > split:
                    x_1.append(i)
                else:
                    x_2.append(i)
            return (self.tree(x[x_1, :], y[x_1], net[2]),
                    self.tree(x[x_2, :], y[x_2], net[3]))

    # Recursive prediction for a single sample
    def predict_tree(self, x, net):
        x = np.array(x).reshape(-1)
        if len(net) == 1:                # leaf node: [label]
            return net[0]
        elif x[net[0]] > net[1]:         # same comparison as in training
            return self.predict_tree(x, net[2])
        else:
            return self.predict_tree(x, net[3])

    # Model training
    def fit(self, x, y):
        self.X = np.array(x)
        self.Y = np.array(y).reshape(-1)
        self.tree_net = []
        self.tree(self.X, self.Y, self.tree_net)

    # Model prediction
    def predict(self, x):
        x = np.array(x)
        return np.array([self.predict_tree(x[i, :], self.tree_net)
                         for i in range(x.shape[0])])
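The trained tree is stored in tree_net as a nested list: an internal node is [feature_index, threshold, subtree_for_greater, subtree_for_less_or_equal] and a leaf is [label]. A small hand-built tree (purely illustrative, not produced by fit) shows how predict_tree walks this structure:

# hypothetical tree: split on feature 0 at threshold 0.5,
# values > 0.5 reach the leaf [1], values <= 0.5 the leaf [0]
toy_net = [0, 0.5, [1], [0]]

model = CyrusDecisionTree()
print(model.predict_tree([0.8], toy_net))   # 1, since 0.8 > 0.5
print(model.predict_tree([0.2], toy_net))   # 0, since 0.2 <= 0.5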
A decision-tree classification example with continuous variables
y = np.random.randint(0,10,30).reshape(-1)
x = np.random.random([30,5])
print(x)
print(y)
[[0.52533105 0.73209647 0.58700477 0.36033001 0.91586941]
[0.94308921 0.13044845 0.34348716 0.68958107 0.85397988]
[0.7242329 0.53027196 0.13577077 0.26769844 0.67871508]
[0.43056763 0.57511585 0.12568578 0.31678452 0.0067388 ]
[0.38103315 0.71300916 0.83360782 0.40604844 0.0352286 ]
[0.39538199 0.79040881 0.63293679 0.67048469 0.0743981 ]
[0.60237319 0.48057981 0.30906018 0.23632994 0.65723904]
[0.64566226 0.95529741 0.34702771 0.45110142 0.0355881 ]
[0.16776585 0.69377092 0.98103948 0.21491139 0.3792334 ]
[0.48527149 0.16346686 0.71499249 0.24499424 0.43896129]
[0.50378007 0.11929577 0.53185892 0.04572121 0.2287798 ]
[0.75859512 0.53336214 0.64378837 0.82518598 0.96073149]
[0.67140078 0.50990813 0.99593748 0.57135234 0.2955292 ]
[0.60429891 0.30828858 0.4740352 0.97094536 0.73335159]
[0.73112143 0.450134 0.66282747 0.93411235 0.27251284]
[0.45273626 0.70515434 0.79901511 0.46209148 0.75002544]
[0.75767042 0.16873059 0.81269049 0.16076081 0.6065813 ]
[0.97628975 0.14158034 0.10692558 0.56774873 0.97330805]
[0.49577763 0.52372332 0.34862 0.58616061 0.94039918]
[0.08672443 0.40289412 0.07220557 0.16319812 0.39363945]
[0.00317775 0.13165272 0.2509101 0.28256357 0.72483668]
[0.7287063 0.35129312 0.44207534 0.23099126 0.08441964]
[0.69944897 0.06905071 0.2411949 0.57971762 0.14470603]
[0.67122102 0.17006905 0.68307124 0.89004399 0.76470935]
[0.58254671 0.66576537 0.12318 0.84908671 0.84378037]
[0.42940781 0.83785544 0.96820387 0.95913632 0.78881616]
[0.62771109 0.25085264 0.91938847 0.27654677 0.95426724]
[0.02575006 0.62735923 0.85298517 0.36904279 0.25085951]
[0.05350246 0.66845444 0.74378456 0.81039401 0.40810988]
[0.11843527 0.91711057 0.01975534 0.34762297 0.05685195]]
[3 3 9 9 2 3 2 5 1 2 4 2 0 0 9 9 7 2 2 9 9 0 5 4 0 3 1 6 6 6]
model = CyrusDecisionTree(criterion="gini")
model.fit(x,y)
y_pre = model.predict(x)
print(y)
print(y_pre)
[3 3 9 9 2 3 2 5 1 2 4 2 0 0 9 9 7 2 2 9 9 0 5 4 0 3 1 6 6 6]
[3 3 9 9 2 3 2 5 1 2 4 2 0 0 9 9 7 2 2 9 9 0 5 4 0 3 1 6 6 6]
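The same data can also be fitted with the gain-ratio criterion by passing criterion="C4.5" (which is also the default). A minimal sketch; it should likewise reproduce the training labels, though divide-by-zero warnings can appear when a node is split perfectly, because this gain ratio divides by the conditional entropy:

model_c45 = CyrusDecisionTree(criterion="C4.5")
model_c45.fit(x, y)
print(model_c45.predict(x))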
by CyrusMay 2020 06 09
Happiness is not about having more, but about forgetting,
forgetting the worries and sorrows of life.
——————Mayday (Enrich your life)——————