统计学习方法之信息增益算法

本文算法思路来源于李航的《统计学习方法》,本文实现其算法
网上已有很多讲解思路,这里不再对原理展开论述,下面按照书中给出的算法步骤逐一实现
①经验熵

/**
 * Computes the empirical entropy H(D) of the first {@code N} samples, treating
 * {@code agree == 1} as one class and every other value as the second class.
 *
 * @param datas samples to evaluate; only indices {@code 0 .. N-1} are read
 * @param N     number of leading samples in {@code datas} to consider
 * @return the binary entropy in bits; {@code 0} when all considered samples fall
 *         into a single class or when there are no samples at all
 */
private static double HD(Data[] datas, int N) {
    int positives = 0;
    for (int i = 0; i < N; i++) {
        if (datas[i].agree == 1) {
            positives++;
        }
    }
    // A pure set has zero entropy. Both pure cases must be short-circuited:
    // otherwise one probability is 0 and 0 * log2(0) = 0 * -Infinity = NaN.
    if (positives == 0 || positives == N) {
        return 0;
    }
    double p1 = (double) positives / N;
    double p2 = 1 - p1;
    return -p1 * logto2(p1) - p2 * logto2(p2);
}

/**
 * Returns the base-2 logarithm of {@code i}, obtained from the natural
 * logarithm by change of base.
 *
 * @param i the value whose logarithm is taken
 * @return {@code ln(i) / ln(2)}
 */
private static double logto2(double i) {
    final double naturalLog = Math.log(i);
    final double naturalLogOfTwo = Math.log(2);
    return naturalLog / naturalLogOfTwo;
}

②分别算出信息增益(由经验熵减去条件熵)
因为每个特征的取值情况都不一样,为了方便,用一个参数 i 来表示选择哪一个特征(1:year,2:haveJob,3:haveHome,其他:credit)

/**
 * Computes the information gain g(D, A) = H(D) - H(D|A) of one feature over the
 * first {@code N} samples of the class-level array {@code datas}.
 *
 * <p>The feature is chosen by {@code i}:
 * <ul>
 *   <li>{@code 1} - {@code year}, split into three groups: {@code == 1}, {@code == 2}, anything else</li>
 *   <li>{@code 2} - {@code haveJob}, split into two groups: {@code == 1}, anything else</li>
 *   <li>{@code 3} - {@code haveHome}, split into two groups: {@code == 1}, anything else</li>
 *   <li>any other value - {@code credit}, split into three groups: {@code == 1}, {@code == 2}, anything else</li>
 * </ul>
 *
 * @param i feature selector as described above
 * @return the empirical entropy minus the conditional entropy given the selected
 *         feature; {@code 0} when there are no samples
 */
private static double GDA(int i) {
    // With no samples every weight below would be 0/0 = NaN; nothing can be gained.
    if (N <= 0) {
        return 0;
    }
    double empiricalEntropy = HD(datas, N);

    // No feature splits into more than three groups; binary features simply
    // leave the last group empty.
    final int maxGroups = 3;
    Data[][] groups = new Data[maxGroups][N];
    int[] groupSizes = new int[maxGroups];
    for (int j = 0; j < N; j++) {
        int g = gdaGroupIndex(datas[j], i);
        groups[g][groupSizes[g]++] = datas[j];
    }

    // H(D|A) = sum over groups of |Di| / |D| * H(Di)
    double conditionalEntropy = 0;
    for (int g = 0; g < maxGroups; g++) {
        if (groupSizes[g] == 0) {
            continue; // an empty group carries zero weight
        }
        conditionalEntropy += (double) groupSizes[g] / N * HD(groups[g], groupSizes[g]);
    }
    return empiricalEntropy - conditionalEntropy;
}

/** Maps a sample to the zero-based index of its group for the feature selected by {@code i}. */
private static int gdaGroupIndex(Data sample, int i) {
    if (i == 1) {
        if (sample.year == 1) {
            return 0;
        }
        return sample.year == 2 ? 1 : 2;
    } else if (i == 2) {
        return sample.haveJob == 1 ? 0 : 1;
    } else if (i == 3) {
        return sample.haveHome == 1 ? 0 : 1;
    } else {
        if (sample.credit == 1) {
            return 0;
        }
        return sample.credit == 2 ? 1 : 2;
    }
}
发布了30 篇原创文章 · 获赞 62 · 访问量 3085

猜你喜欢

转载自blog.csdn.net/weixin_43981664/article/details/90729418