Machine Learning: Decision Trees for Classification with Continuous Variables

Abstract

This article implements, in Python, the computation of information entropy, conditional entropy, information gain, information gain ratio, the Gini index, and Gini gain for continuous variables, and builds a decision-tree classification model on top of them. The same model also handles discrete variables, and the code is wrapped in a class so readers can call it directly.

Computing information entropy, conditional entropy, information gain, and information gain ratio

.cal_entropy(): computes the information entropy
.cal_conditional_entropy(): computes the conditional entropy
.cal_entropy_gain(): computes the information gain (mutual information)
.cal_entropy_gain_ratio(): computes the information gain ratio
Usage: call these methods directly on the CyrusDecisionTree class, or instantiate the class first and then call them (see the code below); a short sketch follows this list.
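
For a label vector y with class proportions p_k, the entropy is H(y) = -Σ p_k·log2(p_k). A minimal usage sketch, assuming the CyrusDecisionTree class from the Python code section below has already been defined (the toy arrays are made up for illustration):

print(CyrusDecisionTree.cal_entropy([0, 0, 1, 1]))
# -> 1.0 for a balanced binary label vector: -2 * 0.5 * log2(0.5)
print(CyrusDecisionTree.cal_entropy_gain([1, 2, 3, 4], [0, 0, 1, 1]))
# -> gain 1.0 at threshold 2.5: that split separates the classes
#    perfectly, so the gain equals the full entropy H(y)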

Computing the Gini index and Gini gain

.cal_gini(): computes the Gini index
.cal_gini_gain(): computes the Gini gain
Usage: call these methods directly on the CyrusDecisionTree class, or instantiate the class first and then call them; a short sketch follows this list.
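
The Gini index of y is 1 - Σ p_k², and the Gini gain of a split is the parent's Gini index minus the size-weighted Gini indexes of the two children. A matching sketch under the same assumption (class defined below, toy data):

print(CyrusDecisionTree.cal_gini([0, 0, 1, 1]))
# -> 0.5, i.e. 1 - (0.5**2 + 0.5**2)
print(CyrusDecisionTree.cal_gini_gain([1, 2, 3, 4], [0, 0, 1, 1]))
# -> gain 0.5 at threshold 2.5: both children are pure,
#    so the gain equals the parent's Gini index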

Python code

criterion is an optional parameter:
passing "C4.5" gives a decision tree based on the information gain ratio;
passing "gini" gives a decision tree based on the Gini gain, as in the example after this list.
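
For example (the variable names here are illustrative):

model_c45 = CyrusDecisionTree(criterion="C4.5")   # splits chosen by gain ratio
model_gini = CyrusDecisionTree(criterion="gini")  # splits chosen by Gini gain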

import numpy as np
import pandas as pd
class CyrusDecisionTree(object):
    X = None
    Y = None
    def __init__(self, criterion="C4.5"):
        self.criterion = criterion
        self.tree_net = None

    # 1. Information entropy
    @classmethod
    def cal_entropy(cls, y):
        y = np.array(y).reshape(-1)
        counts = np.array(pd.Series(y).value_counts())
        return -((counts / y.shape[0]) * np.log2(counts / y.shape[0])).sum()
    # 2. Conditional entropy
    @classmethod
    def cal_conditional_entropy(cls, x, y):
        """
        Compute the conditional entropy of y under each candidate split of x.
        Returns one entropy per candidate threshold, plus the thresholds.
        """
        x = np.array(x).reshape(-1)
        y = np.array(y).reshape(-1)
        order = x.argsort()  # sort y together with x so labels stay aligned
        x, y = x[order], y[order]
        split = []
        entropy = []
        for i in range(x.shape[0] - 1):
            split.append(0.5 * (x[i] + x[i + 1]))  # midpoint between neighbors
            p = (i + 1) / x.shape[0]
            entropy.append(p * cls.cal_entropy(y[:i + 1]) + (1 - p) * cls.cal_entropy(y[i + 1:]))
        return (np.array(entropy), np.array(split))
    # 3. Information gain
    @classmethod
    def cal_entropy_gain(cls, x, y):
        """
        Compute the best information gain of y given x and the threshold achieving it.
        """
        entropy, split = cls.cal_conditional_entropy(x, y)
        entropy_gain = cls.cal_entropy(y) - entropy
        return entropy_gain.max(), split[entropy_gain.argmax()]
    # 4. Information gain ratio
    @classmethod
    def cal_entropy_gain_ratio(cls, x, y):
        """
        Compute the gain ratio of y given x. Note that the gain is divided by the
        conditional entropy at the chosen split, not by the split information of
        textbook C4.5, and the denominator is zero for a perfectly pure split.
        """
        entropy_gain, split = cls.cal_entropy_gain(x, y)
        entropy_condition = cls.cal_entropy(y) - entropy_gain
        return entropy_gain / entropy_condition, split
    # 5. Gini index
    @classmethod
    def cal_gini(cls, y):
        y = np.array(y).reshape(-1)
        counts = np.array(pd.Series(y).value_counts())
        return 1 - (((counts / y.shape[0]) ** 2).sum())
    # 6. Gini gain
    @classmethod
    def cal_gini_gain(cls, x, y):
        """
        Compute the best Gini gain of y given x and the threshold achieving it.
        """
        x = np.array(x).reshape(-1)
        y = np.array(y).reshape(-1)
        order = x.argsort()  # sort y together with x so labels stay aligned
        x, y = x[order], y[order]
        split = []
        gini = []
        for i in range(x.shape[0] - 1):
            split.append(0.5 * (x[i] + x[i + 1]))  # midpoint between neighbors
            p = (i + 1) / x.shape[0]
            gini.append(p * cls.cal_gini(y[:i + 1]) + (1 - p) * cls.cal_gini(y[i + 1:]))
        gini_gain = cls.cal_gini(y) - np.array(gini)
        split = np.array(split)
        return gini_gain.max(), split[gini_gain.argmax()]
    # Recursive tree-building function
    def tree(self, x, y, net):
        if pd.Series(y).value_counts().shape[0] == 1:
            net.append(y[0])  # pure node becomes a leaf: [label]
        else:
            x_entropy = []
            x_split = []
            for i in range(x.shape[1]):
                if self.criterion == "C4.5":
                    entropy, split = self.cal_entropy_gain_ratio(x[:, i], y)
                else:
                    entropy, split = self.cal_gini_gain(x[:, i], y)
                x_entropy.append(entropy)
                x_split.append(split)
            # split on the feature with the largest criterion value
            rank = np.array(x_entropy).argmax()
            split = x_split[rank]
            # internal node: [feature_index, threshold, subtree_gt, subtree_le]
            net.append(rank)
            net.append(split)
            net.append([])
            net.append([])
            x_1 = []
            x_2 = []
            for i in range(x.shape[0]):
                if x[i, rank] > split:
                    x_1.append(i)
                else:
                    x_2.append(i)
            x1, y1 = x[x_1, :], y[x_1]
            x2, y2 = x[x_2, :], y[x_2]
            return self.tree(x1, y1, net[2]), self.tree(x2, y2, net[3])
    # Recursive prediction for a single sample
    def predict_tree(self, x, net):
        x = np.array(x).reshape(-1)
        if len(net) == 1:
            return net[0]  # leaf: return the label itself, not the list
        else:
            if x[net[0]] > net[1]:  # same comparison as in tree()
                return self.predict_tree(x, net[2])
            else:
                return self.predict_tree(x, net[3])
    # Model training
    def fit(self, x, y):
        self.X = np.array(x)
        self.Y = np.array(y).reshape(-1)
        self.tree_net = []
        self.tree(self.X, self.Y, self.tree_net)
    # Model prediction
    def predict(self, x):
        x = np.array(x)
        pre_y = []
        for i in range(x.shape[0]):
            pre_y.append(self.predict_tree(x[i, :], self.tree_net))
        return np.array(pre_y)
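
The fitted tree is stored in tree_net as a nested list: a leaf is a one-element list [label], and an internal node is [feature_index, threshold, subtree_for_samples_above_the_threshold, subtree_for_the_rest]. A minimal sketch on made-up one-feature data (demo is a hypothetical name; the expected structure follows from the Gini example above):

demo = CyrusDecisionTree(criterion="gini")
demo.fit([[1], [2], [3], [4]], [0, 0, 1, 1])
print(demo.tree_net)                 # feature 0, threshold 2.5,
                                     # label 1 above it, label 0 below it
print(demo.predict([[1.5], [3.7]]))  # -> [0 1]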

A decision-tree classification example with continuous variables

y = np.random.randint(0,10,30).reshape(-1)
x = np.random.random([30,5])
print(x)
print(y)
[[0.52533105 0.73209647 0.58700477 0.36033001 0.91586941]
 [0.94308921 0.13044845 0.34348716 0.68958107 0.85397988]
 [0.7242329  0.53027196 0.13577077 0.26769844 0.67871508]
 [0.43056763 0.57511585 0.12568578 0.31678452 0.0067388 ]
 [0.38103315 0.71300916 0.83360782 0.40604844 0.0352286 ]
 [0.39538199 0.79040881 0.63293679 0.67048469 0.0743981 ]
 [0.60237319 0.48057981 0.30906018 0.23632994 0.65723904]
 [0.64566226 0.95529741 0.34702771 0.45110142 0.0355881 ]
 [0.16776585 0.69377092 0.98103948 0.21491139 0.3792334 ]
 [0.48527149 0.16346686 0.71499249 0.24499424 0.43896129]
 [0.50378007 0.11929577 0.53185892 0.04572121 0.2287798 ]
 [0.75859512 0.53336214 0.64378837 0.82518598 0.96073149]
 [0.67140078 0.50990813 0.99593748 0.57135234 0.2955292 ]
 [0.60429891 0.30828858 0.4740352  0.97094536 0.73335159]
 [0.73112143 0.450134   0.66282747 0.93411235 0.27251284]
 [0.45273626 0.70515434 0.79901511 0.46209148 0.75002544]
 [0.75767042 0.16873059 0.81269049 0.16076081 0.6065813 ]
 [0.97628975 0.14158034 0.10692558 0.56774873 0.97330805]
 [0.49577763 0.52372332 0.34862    0.58616061 0.94039918]
 [0.08672443 0.40289412 0.07220557 0.16319812 0.39363945]
 [0.00317775 0.13165272 0.2509101  0.28256357 0.72483668]
 [0.7287063  0.35129312 0.44207534 0.23099126 0.08441964]
 [0.69944897 0.06905071 0.2411949  0.57971762 0.14470603]
 [0.67122102 0.17006905 0.68307124 0.89004399 0.76470935]
 [0.58254671 0.66576537 0.12318    0.84908671 0.84378037]
 [0.42940781 0.83785544 0.96820387 0.95913632 0.78881616]
 [0.62771109 0.25085264 0.91938847 0.27654677 0.95426724]
 [0.02575006 0.62735923 0.85298517 0.36904279 0.25085951]
 [0.05350246 0.66845444 0.74378456 0.81039401 0.40810988]
 [0.11843527 0.91711057 0.01975534 0.34762297 0.05685195]]
[3 3 9 9 2 3 2 5 1 2 4 2 0 0 9 9 7 2 2 9 9 0 5 4 0 3 1 6 6 6]
model = CyrusDecisionTree(criterion="gini")
model.fit(x,y)
y_pre = model.predict(x)
print(y)
print(y_pre)
[3 3 9 9 2 3 2 5 1 2 4 2 0 0 9 9 7 2 2 9 9 0 5 4 0 3 1 6 6 6]
[3 3 9 9 2 3 2 5 1 2 4 2 0 0 9 9 7 2 2 9 9 0 5 4 0 3 1 6 6 6]

Because the tree is grown until every leaf is pure and the predictions are made on the very samples used for training, the predicted labels reproduce the true labels exactly. This checks that the implementation is self-consistent, not that it generalizes.

by CyrusMay 2020 06 09

Happiness is not having more, but being able to forget,
to forget life's troubles and sorrows.
—————— Mayday (Enrich your life) ——————

Reposted from blog.csdn.net/Cyrus_May/article/details/106631283