Python编程实现基于基尼指数进行划分选择的决策树(CART决策树)算法

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/john_bian/article/details/100586245

本文是周志华老师的《机器学习》一书中第4章 决策树 的课后题第4.4题的实现。原题是:

试编程实现基于基尼指数进行划分选择的决策树算法,为表4.2中的数据生成预剪枝、后剪枝决策树,并与未剪枝决策树进行比较。

与ID3算法选择信息增益作为选择最优属性的标准不同,CART决策树选择使划分后基尼指数(Gini index)最小的属性作为最优划分属性。假设当前样本集合D中第k类样本所占的比例为p_{k} (k=1, 2, ...,\left | y \right |),则D的纯度可以用基尼值来度量:

Gini(D)=\sum^{|y|}_{k=1}\sum_{k^{'}\neq k}{p_{k}p_{k^{'}}}=1-\sum^{|y|}_{k=1}{p^{2}_{k}}

Gini(D)反映了从数据集D中随机抽取两个样本,这两个样本不属于同一类的概率,因此Gini(D)越小,则数据集D的纯度越高。

假定离散的属性aV个可能的取值\{a^{1}, a^{2}, ..., a^{V}\},若使用a来对样本集D来进行划分,则会产生V个分支结点,其中第v个分支结点包含了D中所有在属性a上取值为a^{v}的样本,记为D^{v},则属性a的基尼指数定义为

Gini\_index(D, a)=\sum^{V}_{v=1}{\frac{D^{v}}{D}}Gini(D^{v})

如果数据集中有取值范围是连续数值的属性,我们仍然需要使用二分法来寻找最佳的分隔点。

西瓜数据集2.0的可用版本如下所示

def watermelon2():
    train_data = [
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'],
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '是'],
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '否'],
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '否'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '否'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '否'],
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '否']
    ]

    test_data = [
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '是'],
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '否'],
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '否'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '否'],
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '否'],
    ]

    labels = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']

    return train_data, test_data, labels

我在实现CART决策树的时候,使用了和之前Python编程实现基于信息熵进行划分选择的决策树算法中相同的决策树结点结构TreeNode:

扫描二维码关注公众号,回复: 7200233 查看本文章
class TreeNode:
    """
    决策树结点类
    """
    current_index = 0

    def __init__(self, parent=None, attr_name=None, children=None, judge=None, split=None, data_index=None,
                 attr_value=None, rest_attribute=None):
        """
        决策树结点类初始化方法
        :param parent: 父节点
        """
        self.parent = parent  # 父节点,根节点的父节点为 None
        self.attribute_name = attr_name  # 本节点上进行划分的属性名
        self.attribute_value = attr_value  # 本节点上划分属性的值,是与父节点的划分属性名相对应的
        self.children = children  # 孩子结点列表
        self.judge = judge  # 如果是叶子结点,需要给出判断
        self.split = split  # 如果是使用连续属性进行划分,需要给出分割点
        self.data_index = data_index  # 对应训练数据集的训练索引号
        self.index = TreeNode.current_index  # 当前结点的索引号,方便输出时查看
        self.rest_attribute = rest_attribute  # 尚未使用的属性列表
        TreeNode.current_index += 1

    def to_string(self):
        """用一个字符串来描述当前结点信息"""
        this_string = 'current index : ' + str(self.index) + ";\n"
        if not (self.parent is None):
            parent_node = self.parent
            this_string = this_string + 'parent index : ' + str(parent_node.index) + ";\n"
            this_string = this_string + str(parent_node.attribute_name) + " : " + str(self.attribute_value) + ";\n"
        this_string = this_string + "data : " + str(self.data_index) + ";\n"
        if not(self.children is None):
            this_string = this_string + 'select attribute is : ' + str(self.attribute_name) + ";\n"
            child_list = []
            for child in self.children:
                child_list.append(child.index)
            this_string = this_string + 'children : ' + str(child_list)
        if not (self.judge is None):
            this_string = this_string + 'label : ' + self.judge
        return this_string

以下是不进行剪枝的CART决策树的主要实现代码cart.py:

# CART决策树,使用基尼指数(Gini index)来选择划分属性
# 分别实现预剪枝、后剪枝和不进行剪枝的实现

import math
from Ch04DecisionTree import TreeNode
from Ch04DecisionTree import Dataset


def is_number(s):
    """判断一个字符串是否为数字,如果是数字,我们认为这个属性的值是连续的"""
    try:
        float(s)
        return True
    except ValueError:
        pass
    return False


def gini(labels=[]):
    """
    计算数据集的基尼值
    :param labels: 数据集的类别标签
    :return:
    """
    data_count = {}
    for label in labels:
        if data_count.__contains__(label):
            data_count[label] += 1
        else:
            data_count[label] = 1

    n = len(labels)
    if n == 0:
        return 0

    gini_value = 1
    for key, value in data_count.items():
        gini_value = gini_value - (value/n)*(value/n)

    return gini_value


def gini_index_basic(n, attr_labels={}):
    gini_value = 0
    for attribute, labels in attr_labels.items():
        gini_value = gini_value + len(labels) / n * gini(labels)

    return gini_value


def gini_index(attributes=[], labels=[], is_value=False):
    """
    计算一个属性的基尼指数
    :param attributes: 当前数据在该属性上的属性值列表
    :param labels:
    :param is_value:
    :return:
    """
    n = len(labels)
    attr_labels = {}
    gini_value = 0  # 最终要返回的结果
    split = None  #

    if is_value:  # 属性值是连续的数值
        sorted_attributes = attributes.copy()
        sorted_attributes.sort()
        split_points = []
        for i in range(0, n-1):
            split_points.append((sorted_attributes[i+1]+sorted_attributes[i])/2)

        split_evaluation = []
        for current_split in split_points:
            low_labels = []
            up_labels = []
            for i in range(0, n):
                if attributes[i] <= current_split:
                    low_labels.append(labels[i])
                else:
                    up_labels.append(labels[i])
            attr_labels = {'small': low_labels, 'large': up_labels}
            split_evaluation.append(gini_index_basic(n, attr_labels=attr_labels))

        gini_value = min(split_evaluation)
        split = split_points[split_evaluation.index(gini_value)]

    else:  # 属性值是离散的词
        for i in range(0, n):
            if attr_labels.__contains__(attributes[i]):
                temp_list = attr_labels[attributes[i]]
                temp_list.append(labels[i])
            else:
                temp_list = []
                temp_list.append(labels[i])
                attr_labels[attributes[i]] = temp_list

        gini_value = gini_index_basic(n, attr_labels=attr_labels)

    return gini_value, split


def finish_node(current_node=TreeNode.TreeNode(), data=[], label=[]):
    """
    完成一个结点上的计算
    :param current_node: 当前计算的结点
    :param data: 数据集
    :param label: 数据集的 label
    :return:
    """
    n = len(label)

    # 判断当前结点中的数据是否属于同一类
    one_class = True
    this_data_index = current_node.data_index

    for i in this_data_index:
        for j in this_data_index:
            if label[i] != label[j]:
                one_class = False
                break
        if not one_class:
            break
    if one_class:
        current_node.judge = label[this_data_index[0]]
        return

    rest_title = current_node.rest_attribute  # 候选属性
    if len(rest_title) == 0:  # 如果候选属性为空,则是个叶子结点。需要选择最多的那个类作为该结点的类
        label_count = {}
        temp_data = current_node.data_index
        for index in temp_data:
            if label_count.__contains__(label[index]):
                label_count[label[index]] += 1
            else:
                label_count[label[index]] = 1
        final_label = max(label_count)
        current_node.judge = final_label
        return

    title_gini = {}  # 记录每个属性的基尼指数
    title_split_value = {}  # 记录每个属性的分隔值,如果是连续属性则为分隔值,如果是离散属性则为None
    for title in rest_title:
        attr_values = []
        current_label = []
        for index in current_node.data_index:
            this_data = data[index]
            attr_values.append(this_data[title])
            current_label.append(label[index])
        temp_data = data[0]
        this_gain, this_split_value = gini_index(attr_values, current_label, is_number(temp_data[title]))  # 如果属性值为数字,则认为是连续的
        title_gini[title] = this_gain
        title_split_value[title] = this_split_value

    best_attr = min(title_gini, key=title_gini.get)  # 基尼指数最小的属性名
    current_node.attribute_name = best_attr
    current_node.split = title_split_value[best_attr]
    rest_title.remove(best_attr)

    a_data = data[0]
    if is_number(a_data[best_attr]):  # 如果是该属性的值为连续数值
        split_value = title_split_value[best_attr]
        small_data = []
        large_data = []
        for index in current_node.data_index:
            this_data = data[index]
            if this_data[best_attr] <= split_value:
                small_data.append(index)
            else:
                large_data.append(index)
        small_str = '<=' + str(split_value)
        large_str = '>' + str(split_value)
        small_child = TreeNode.TreeNode(parent=current_node, data_index=small_data, attr_value=small_str,
                               rest_attribute=rest_title.copy())
        large_child = TreeNode.TreeNode(parent=current_node, data_index=large_data, attr_value=large_str,
                               rest_attribute=rest_title.copy())
        current_node.children = [small_child, large_child]

    else:  # 如果该属性的值是离散值
        best_titlevalue_dict = {}  # key是属性值的取值,value是个list记录所包含的样本序号
        for index in current_node.data_index:
            this_data = data[index]
            if best_titlevalue_dict.__contains__(this_data[best_attr]):
                temp_list = best_titlevalue_dict[this_data[best_attr]]
                temp_list.append(index)
            else:
                temp_list = [index]
                best_titlevalue_dict[this_data[best_attr]] = temp_list

        children_list = []
        for key, index_list in best_titlevalue_dict.items():
            a_child = TreeNode.TreeNode(parent=current_node, data_index=index_list, attr_value=key,
                               rest_attribute=rest_title.copy())
            children_list.append(a_child)
        current_node.children = children_list

    # print(current_node.to_string())
    for child in current_node.children:  # 递归
        finish_node(child, data, label)


def cart_tree(Data, title, label):
    """
    生成一颗 CART 决策树
    :param Data: 数据集,每个样本是一个 dict(属性名:属性值),整个 Data 是个大的 list
    :param title:   每个属性的名字,如 色泽、含糖率等
    :param label: 存储的是每个样本的类别
    :return:
    """
    n = len(Data)
    rest_title = title.copy()
    root_data = []
    for i in range(0, n):
        root_data.append(i)

    root_node = TreeNode.TreeNode(data_index=root_data, rest_attribute=title.copy())
    finish_node(root_node, Data, label)

    return root_node


def print_tree(root=TreeNode.TreeNode()):
    """
    打印输出一颗树
    :param root: 根节点
    :return:
    """
    node_list = [root]
    while(len(node_list)>0):
        current_node = node_list[0]
        print('--------------------------------------------')
        print(current_node.to_string())
        print('--------------------------------------------')
        children_list = current_node.children
        if not (children_list is None):
            for child in children_list:
                node_list.append(child)
        node_list.remove(current_node)


def classify_data(decision_tree=TreeNode.TreeNode(), x={}):
    """
    使用决策树判断一个数据样本的类别标签
    :param decision_tree: 决策树的根节点
    :param x: 要进行判断的样本
    :return:
    """
    current_node = decision_tree
    while current_node.judge is None:
        if current_node.split is None:  # 离散属性
            can_judge = False  # 如果训练数据集不够大,测试数据集中可能会有在训练数据集中没有出现过的属性值
            for child in current_node.children:
                if child.attribute_value == x[current_node.attribute_name]:
                    current_node = child
                    can_judge = True
                    break
            if not can_judge:
                return None
        else:
            child_list = current_node.children
            if x[current_node.attribute_name] <= current_node.split:
                current_node = child_list[0]
            else:
                current_node = child_list[1]

    return current_node.judge


def run_test():
    train_watermelon, test_watermelon, title = Dataset.watermelon2()

    # 先处理数据
    train_data = []
    test_data = []
    train_label = []
    test_label = []
    for melon in train_watermelon:
        a_dict = {}
        dim = len(melon) - 1
        for i in range(0, dim):
            a_dict[title[i]] = melon[i]
        train_data.append(a_dict)
        train_label.append(melon[dim])
    for melon in test_watermelon:
        a_dict = {}
        dim = len(melon) - 1
        for i in range(0, dim):
            a_dict[title[i]] = melon[i]
        test_data.append(a_dict)
        test_label.append(melon[dim])

    decision_tree = cart_tree(train_data, title, train_label)
    print('训练的决策树是:')
    print_tree(decision_tree)
    print('\n')

    test_judge = []
    for melon in test_data:
        test_judge.append(classify_data(decision_tree, melon))
    print('决策树在测试数据集上的分类结果是:', test_judge)
    print('测试数据集的正确类别信息应该是:  ', test_label)

    accuracy = 0
    for i in range(0, len(test_label)):
        if test_label[i] == test_judge[i]:
            accuracy += 1
    accuracy /= len(test_label)
    print('决策树在测试数据集上的分类正确率为:'+str(accuracy*100)+"%")


if __name__ == '__main__':
    run_test()

在西瓜数据集2.0上的运行结果如下所示:

训练的决策树是:
--------------------------------------------
current index : 3;
data : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
select attribute is : 色泽;
children : [4, 5, 6]
--------------------------------------------
--------------------------------------------
current index : 4;
parent index : 3;
色泽 : 青绿;
data : [0, 3, 5, 9];
select attribute is : 敲声;
children : [7, 8, 9]
--------------------------------------------
--------------------------------------------
current index : 5;
parent index : 3;
色泽 : 乌黑;
data : [1, 2, 4, 7];
select attribute is : 根蒂;
children : [10, 11]
--------------------------------------------
--------------------------------------------
current index : 6;
parent index : 3;
色泽 : 浅白;
data : [6, 8];
label : 否
--------------------------------------------
--------------------------------------------
current index : 7;
parent index : 4;
敲声 : 浊响;
data : [0, 3];
label : 是
--------------------------------------------
--------------------------------------------
current index : 8;
parent index : 4;
敲声 : 清脆;
data : [5];
label : 否
--------------------------------------------
--------------------------------------------
current index : 9;
parent index : 4;
敲声 : 沉闷;
data : [9];
label : 否
--------------------------------------------
--------------------------------------------
current index : 10;
parent index : 5;
根蒂 : 蜷缩;
data : [1, 2];
label : 是
--------------------------------------------
--------------------------------------------
current index : 11;
parent index : 5;
根蒂 : 稍蜷;
data : [4, 7];
select attribute is : 纹理;
children : [12, 13]
--------------------------------------------
--------------------------------------------
current index : 12;
parent index : 11;
纹理 : 稍糊;
data : [4];
label : 是
--------------------------------------------
--------------------------------------------
current index : 13;
parent index : 11;
纹理 : 清晰;
data : [7];
label : 否
--------------------------------------------


决策树在测试数据集上的分类结果是: ['否', '否', '否', '是', '否', '否', '是']
测试数据集的正确类别信息应该是:   ['是', '是', '是', '否', '否', '否', '否']
决策树在测试数据集上的分类正确率为:28.57142857142857%

预剪枝的CART决策树实现可见 Python编程实现预剪枝的CART决策树

猜你喜欢

转载自blog.csdn.net/john_bian/article/details/100586245
今日推荐