决策树分类算法:C4.5算法
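【每次选取信息增益率最大的特征 $A_i$ 作为节点来建立决策树。设特征 $A$ 的各取值 $v$ 把数据集 $D$ 划分为子集 $D_v$,则信息增益率为 $\mathrm{GainRatio}(D,A)=\dfrac{H(D)-\sum_v \frac{|D_v|}{|D|}H(D_v)}{-\sum_v \frac{|D_v|}{|D|}\log_2\frac{|D_v|}{|D|}}$,其中 $H$ 表示类别标签的香农熵,分母称为分裂信息(split information)。】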
【决策树算法思路参考】
决策树分类算法公共基类
```java
package base;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
 * Common base class for decision-tree classifiers (ID3, C4.5, ...).
 *
 * <p>A data set is a {@code Matrix_2D<String>} in which each row is one sample, the last column
 * is the class label and every other column is a discrete feature value. Subclasses only decide
 * which feature column to split on; tree construction, classification and evaluation live here.
 */
public abstract class GeneralDecTreeHandler {

    /** Result of {@link #classification} when a sample cannot be routed to a leaf. */
    private static final String NO_SUCH_CLASSIFICATION = "no such classification";

    /**
     * Chooses the feature column to split {@code dataSet} on.
     *
     * @param dataSet samples whose last column is the class label
     * @return index of the chosen feature column; a negative value means no feature is worth
     *     splitting on, in which case {@link #creaDecTree} creates a majority-label leaf
     */
    protected abstract int chooseBestFeatureToSplit(Matrix_2D<String> dataSet);

    /**
     * Splits the data set on one feature value: keeps the rows whose column {@code index}
     * equals {@code val} and removes that column from each kept row.
     *
     * @param dataSet samples to filter (not modified)
     * @param index column to test and drop
     * @param val value that column must equal
     * @return a new matrix with the matching rows, each one column shorter
     */
    protected Matrix_2D<String> splitDataSet(Matrix_2D<String> dataSet, int index, String val) {
        Matrix_2D<String> retDataSet = new Matrix_2D<String>();
        final int row = dataSet.getRowDimension();
        for (int i = 0; i < row; ++i) {
            ArrayList<String> line = dataSet.get(i);
            if (line.get(index).equals(val)) {
                // Copy first so the caller's row is left untouched.
                ArrayList<String> tempLine = new ArrayList<String>(line);
                tempLine.remove(index);
                retDataSet.putLine(tempLine);
            }
        }
        return retDataSet;
    }

    /**
     * Returns the most frequent label. Ties are broken by {@link HashMap} iteration order, and
     * an empty array yields {@code ""}.
     *
     * @param labels class labels to vote over
     * @return the majority label
     */
    private String majorityClassificationCount(String[] labels) {
        Map<String, Integer> labelCount = new HashMap<String, Integer>();
        for (String s : labels) {
            Integer c = labelCount.get(s);
            labelCount.put(s, c == null ? 1 : c + 1);
        }
        int count = -1;
        String t = "";
        for (Map.Entry<String, Integer> e : labelCount.entrySet()) {
            if (e.getValue() > count) {
                count = e.getValue();
                t = e.getKey();
            }
        }
        return t;
    }

    /**
     * Recursively builds a decision tree.
     *
     * @param dataSet training samples; the last column is the class label
     * @param features names of the feature columns, in column order (one per non-label column)
     * @return the root node: a leaf holding a class label, or an internal node holding a feature
     *     name whose children are keyed by that feature's values
     */
    public TreeNode creaDecTree(Matrix_2D<String> dataSet, String[] features) {
        final int row = dataSet.getRowDimension(), col = dataSet.getColDimension();
        String[] labelsList = new String[row];
        for (int i = 0; i < row; ++i) {
            labelsList[i] = dataSet.get(i, col - 1);
        }
        // Pure node: every sample carries the same label.
        int num = 0;
        for (String s : labelsList) {
            if (s.equals(labelsList[0])) ++num;
        }
        if (num == labelsList.length) return new TreeNode(labelsList[0], null);
        // Only the label column is left: predict the majority label.
        if (col == 1) return new TreeNode(majorityClassificationCount(labelsList), null);
        int bestFeature = chooseBestFeatureToSplit(dataSet);
        // No usable split (e.g. identical feature values with conflicting labels): fall back
        // to a majority leaf instead of indexing features[-1].
        if (bestFeature < 0) {
            return new TreeNode(majorityClassificationCount(labelsList), null);
        }
        String bestFeatureLabel = features[bestFeature];
        // Feature names for the child data sets, which no longer contain this column.
        String[] subFeatures = subArray(features, bestFeatureLabel);
        // Distinct values of the chosen feature (unordered).
        Set<String> uniqFeatureVals = new HashSet<String>();
        for (int i = 0; i < row; ++i) uniqFeatureVals.add(dataSet.get(i, bestFeature));
        Map<String, TreeNode> child = new HashMap<String, TreeNode>();
        for (String s : uniqFeatureVals) {
            child.put(s, creaDecTree(splitDataSet(dataSet, bestFeature, s), subFeatures));
        }
        return new TreeNode(bestFeatureLabel, child);
    }

    /**
     * Returns a copy of {@code original} without {@code str}. The result has length
     * {@code original.length - 1}, so {@code str} is expected to occur exactly once.
     */
    private String[] subArray(String[] original, String str) {
        String[] subArray = new String[original.length - 1];
        int k = 0;
        for (String s : original) {
            if (!s.equals(str)) subArray[k++] = s;
        }
        return subArray;
    }

    /**
     * Classifies one sample by walking the tree from the root.
     *
     * @param tree root of a tree built by {@link #creaDecTree}
     * @param features feature names, in the column order of {@code sample}
     * @param sample one row of feature values (a trailing label column is ignored)
     * @return the predicted label, or {@code "no such classification"} when a node's feature is
     *     unknown, the sample is too short, or a feature value was never seen in training
     */
    public String classification(TreeNode tree, String[] features, ArrayList<String> sample) {
        while (tree != null && tree.getChildren() != null) {
            int index = getIndex(features, tree.getElement());
            if (sample == null || index < 0 || index >= sample.size()) {
                return NO_SUCH_CLASSIFICATION;
            }
            // An unseen feature value maps to null and ends the walk below.
            tree = tree.getChildren().get(sample.get(index));
        }
        if (tree == null) return NO_SUCH_CLASSIFICATION;
        return tree.getElement();
    }

    /** Returns the first index of {@code s} in {@code arr}, or -1 if absent. */
    private int getIndex(String[] arr, String s) {
        for (int i = 0; i < arr.length; ++i) {
            if (arr[i].equals(s)) return i;
        }
        return -1;
    }

    /**
     * Reads a comma-separated data file (one sample per line) and returns its rows in random
     * order. Blank lines are skipped.
     *
     * @param path file to read
     * @return the shuffled rows, or {@code null} (after printing the absolute path) when
     *     {@code path} is not an existing regular file
     * @throws IOException if reading fails
     */
    public static Matrix_2D<String> readDataFile(String path) throws IOException {
        ArrayList<ArrayList<String>> trainingSet = new ArrayList<ArrayList<String>>();
        File file = new File(path);
        if (!file.exists() || !file.isFile()) {
            System.out.println(file.getAbsolutePath());
            return null;
        }
        // try-with-resources closes the reader even if readLine() throws.
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(new FileInputStream(file)))) {
            String str;
            while ((str = reader.readLine()) != null) {
                // A blank line would become a one-column row and break the matrix shape.
                if (str.trim().isEmpty()) continue;
                String[] tokenizer = str.split(",");
                ArrayList<String> s = new ArrayList<String>(tokenizer.length);
                for (String token : tokenizer) s.add(token);
                trainingSet.add(s);
            }
        }
        // Shuffle: each step moves a random not-yet-moved row (from the front part) to the end.
        final int size = trainingSet.size();
        for (int i = 0; i < size; ++i) {
            int t = (int) ((size - i) * Math.random());
            trainingSet.add(trainingSet.remove(t));
        }
        return new Matrix_2D<String>(trainingSet);
    }

    /**
     * Returns the fraction of {@code samples} whose predicted label equals their last column
     * (NaN when {@code samples} is empty).
     */
    private double report(TreeNode tree, String[] features, Matrix_2D<String> samples) {
        int num = 0;
        for (int i = 0; i < samples.getRowDimension(); ++i) {
            String expected = samples.get(i, samples.getColDimension() - 1);
            if (classification(tree, features, samples.get(i)).equals(expected)) ++num;
        }
        return num / (double) samples.getRowDimension();
    }

    /**
     * Builds a tree and reports its accuracy.
     *
     * @param features feature names
     * @param dataSet all labelled samples
     * @param self {@code true}: train and test on {@code dataSet}; {@code false}: put each row
     *     into the training or the test set with probability 0.5 (roughly a 1:1 split)
     * @return classification accuracy on the test samples
     */
    public double reportModel(String[] features, Matrix_2D<String> dataSet, boolean self) {
        if (self) return report(creaDecTree(dataSet, features), features, dataSet);
        Matrix_2D<String> training = new Matrix_2D<String>();
        Matrix_2D<String> test = new Matrix_2D<String>();
        for (int i = 0; i < dataSet.getRowDimension(); ++i) {
            if (Math.random() >= 0.50) training.putLine(dataSet.get(i));
            else test.putLine(dataSet.get(i));
        }
        return report(creaDecTree(training, features), features, test);
    }
}
```
C4.5算法选择特征项及测试
```java
package c4_5;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import base.GeneralDecTreeHandler;
import base.Matrix_2D;
import base.TreeNode;
/**
 * C4.5 decision-tree classifier: every node splits on the feature with the largest
 * information gain ratio.
 */
public class C4_5 extends GeneralDecTreeHandler {

    /**
     * Returns the feature column with the largest information gain ratio.
     *
     * @param dataSet samples whose last column is the class label
     * @return index of the best feature, or -1 when no feature has a positive gain ratio
     */
    @Override
    protected int chooseBestFeatureToSplit(Matrix_2D<String> dataSet) {
        final int col = dataSet.getColDimension();
        double bestIGR = 0.0;
        int bestFeature = -1;
        for (int i = 0; i < col - 1; ++i) {
            double infoGainRation = calInfoGainRation(dataSet, i);
            if (infoGainRation > bestIGR) {
                // Keep the best score so far so later features must beat it.
                bestIGR = infoGainRation;
                bestFeature = i;
            }
        }
        return bestFeature;
    }

    /**
     * Computes the information gain ratio of column {@code index}: (label entropy minus
     * weighted label entropy after splitting on the column) divided by the column's split
     * information.
     *
     * @return the gain ratio, or 0 when the column has a single value (zero split information)
     */
    private double calInfoGainRation(Matrix_2D<String> ds, int index) {
        final int row = ds.getRowDimension();
        final double baseEntropy = calShannonEntropy(ds);
        Set<String> featureSet = new HashSet<String>();
        for (int i = 0; i < row; ++i) featureSet.add(ds.get(i, index));
        double newEntropy = 0.0;
        for (String s : featureSet) {
            Matrix_2D<String> retData = splitDataSet(ds, index, s);
            double pro = retData.getRowDimension() / (double) row;
            newEntropy += pro * calShannonEntropy(retData);
        }
        final double splitInfo = calSplitInformation(ds, index);
        // A single-valued column has zero gain and zero split info; avoid 0/0 = NaN.
        if (splitInfo <= 0.0) return 0.0;
        return (baseEntropy - newEntropy) / splitInfo;
    }

    /**
     * Computes the split information of column {@code index}, i.e. the base-2 entropy of that
     * column's value distribution.
     */
    private double calSplitInformation(Matrix_2D<String> ds, int index) {
        final int m = ds.getRowDimension();
        double splitInfo = 0.0;
        HashMap<String, Integer> valueCounts = new HashMap<String, Integer>();
        // Count occurrences of each value of the column.
        for (int i = 0; i < m; i++) {
            String currentValue = ds.get(i, index);
            Integer c = valueCounts.get(currentValue);
            valueCounts.put(currentValue, c == null ? 1 : c + 1);
        }
        // Entropy of the column's value distribution.
        for (Integer count : valueCounts.values()) {
            double rate = count / (double) m;
            splitInfo -= rate * Math.log(rate) / Math.log(2.0);
        }
        return splitInfo;
    }

    /**
     * Computes the base-2 Shannon entropy of the class labels (last column) of {@code ds}.
     *
     * @param ds non-empty data set
     * @return entropy in bits
     */
    public static double calShannonEntropy(Matrix_2D<String> ds) {
        final int m = ds.getRowDimension();
        final int n = ds.getColDimension();
        double shannonEnt = 0;
        HashMap<String, Integer> labelCounts = new HashMap<String, Integer>();
        // Count occurrences of each class label.
        for (int i = 0; i < m; i++) {
            String currentLabel = ds.get(i, n - 1);
            Integer c = labelCounts.get(currentLabel);
            labelCounts.put(currentLabel, c == null ? 1 : c + 1);
        }
        // Entropy of the label distribution; divide in double precision.
        for (Integer count : labelCounts.values()) {
            double rate = count / (double) m;
            shannonEnt -= rate * Math.log(rate) / Math.log(2);
        }
        return shannonEnt;
    }

    /**
     * Demo: builds a tree from the whole (shuffled) data file and prints its training-set
     * accuracy, then runs 30 random 1:1 train/test evaluations and prints their mean accuracy.
     *
     * @param args optional path of the data file; defaults to {@code AutismAdultDataPlus.txt}
     */
    public static void main(String[] args) throws IOException {
        // Other data files used with this demo: divorce.txt, StudentAcademicsPerformance.txt
        final String path = args.length > 0 ? args[0] : "AutismAdultDataPlus.txt";
        C4_5 tool = new C4_5();
        Matrix_2D<String> trainingSet = C4_5.readDataFile(path);
        if (trainingSet == null || trainingSet.getRowDimension() == 0) {
            System.err.println("无法读取数据或数据为空: " + path);
            return;
        }
        String[] features = new String[trainingSet.getColDimension() - 1];
        for (int i = 0; i < features.length; ++i) {
            features[i] = "特征" + String.valueOf(i);
        }
        TreeNode tree = tool.creaDecTree(trainingSet, features);
        int num = 0;
        final int row = trainingSet.getRowDimension(), col = trainingSet.getColDimension();
        for (int i = 0; i < row; ++i) {
            String tmp = tool.classification(tree, features, trainingSet.get(i));
            if (tmp.equals(trainingSet.get(i).get(col - 1))) {
                ++num;
            }
        }
        System.out.println("测试实例数:" + row + ",分类正确数:" + num + ",分类精度:" + (num / (double) row));
        System.out.println("*30次1:1模型测试:");
        double pp = 0.0;
        for (int i = 1; i <= 30; ++i) {
            System.out.print("第" + i + "次:");
            double p0 = tool.reportModel(features, trainingSet, false);
            pp += p0;
            System.out.println(p0);
        }
        System.out.println("30次均值:" + (pp / 30));
    }
}
```
节点类:TreeNode.java
```java
package base;
import java.util.Map;
/**
 * A node of a decision tree.
 *
 * <p>As built by {@code GeneralDecTreeHandler.creaDecTree}, an internal node holds a feature
 * name and maps each value of that feature to a child node, while a leaf holds a class label
 * and has {@code null} children.
 */
public class TreeNode {
    /** Feature name (internal node) or class label (leaf). */
    private final String element;
    /** Children keyed by feature value; {@code null} for a leaf. Stored as given, not copied. */
    private final Map<String, TreeNode> children;

    /** Creates a node with neither an element nor children. */
    public TreeNode() {
        this(null, null);
    }

    /**
     * Creates a node.
     *
     * @param e feature name or class label
     * @param c children keyed by feature value, or {@code null} for a leaf
     */
    public TreeNode(String e, Map<String, TreeNode> c) {
        this.element = e;
        this.children = c;
    }

    /** Returns the children map (the same instance passed in), or {@code null} for a leaf. */
    public Map<String, TreeNode> getChildren() {
        return this.children;
    }

    /** Returns the feature name or class label stored in this node. */
    public String getElement() {
        return this.element;
    }
}
```
工具类:Matrix_2D.java
```java
package base;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
 * Minimal row-major 2-D table backed by nested {@link ArrayList}s, plus a few static helpers
 * used by the decision-tree code. Rows are copied on insertion; rows need not have equal length.
 *
 * @param <T> element type
 */
public class Matrix_2D<T> {
    /** The rows of the table. */
    ArrayList<ArrayList<T>> data;

    /** Creates an empty matrix. */
    public Matrix_2D() {
        data = new ArrayList<ArrayList<T>>();
    }

    /**
     * Creates a matrix holding a copy of every row of {@code d}; later changes to {@code d} or
     * its rows do not affect this matrix.
     */
    public Matrix_2D(ArrayList<ArrayList<T>> d) {
        data = new ArrayList<ArrayList<T>>(d.size());
        // Copy directly rather than calling the overridable putLine from a constructor.
        for (ArrayList<T> val : d) data.add(new ArrayList<T>(val));
    }

    /** Appends a copy of {@code line} as the new last row. */
    public void putLine(ArrayList<T> line) {
        data.add(new ArrayList<T>(line));
    }

    /** Returns the number of rows. */
    public int getRowDimension() {
        return data.size();
    }

    /** Returns the length of the first row, or 0 when the matrix has no rows. */
    public int getColDimension() {
        return data.isEmpty() ? 0 : data.get(0).size();
    }

    /** Returns row {@code i} itself (not a copy); changes to it are visible in the matrix. */
    public ArrayList<T> get(int i) {
        return data.get(i);
    }

    /** Returns the element at row {@code i}, column {@code j}. */
    public T get(int i, int j) {
        return data.get(i).get(j);
    }

    /** Removes and returns the element at row {@code i}, column {@code j}, shortening that row. */
    public T remove(int i, int j) {
        return data.get(i).remove(j);
    }

    /** Removes and returns row {@code index}. */
    public ArrayList<T> remove(int index) {
        return data.remove(index);
    }

    /**
     * Returns a copy of {@code original} without the elements equal to {@code str}. The result
     * always has length {@code original.length - 1}, so {@code str} is expected to occur exactly
     * once: if it is absent an {@link ArrayIndexOutOfBoundsException} is thrown, and extra
     * occurrences leave trailing {@code null} slots.
     */
    public static String[] subArray(String[] original, String str) {
        String[] subArray = new String[original.length - 1];
        int k = 0;
        for (String s : original) {
            if (!s.equals(str)) subArray[k++] = s;
        }
        return subArray;
    }

    /** Returns a new, independent {@link ArrayList} with the same elements as {@code data}. */
    public static ArrayList<String> copyArrayList(ArrayList<String> data) {
        return new ArrayList<String>(data);
    }

    /**
     * Returns the most frequent element of {@code labels}. Ties are broken by {@link HashMap}
     * iteration order, and an empty list yields {@code ""}.
     */
    public static String majority(ArrayList<String> labels) {
        Map<String, Integer> labelCount = new HashMap<String, Integer>();
        for (String s : labels) {
            Integer c = labelCount.get(s);
            labelCount.put(s, c == null ? 1 : c + 1);
        }
        int count = -1;
        String t = "";
        for (Map.Entry<String, Integer> e : labelCount.entrySet()) {
            if (e.getValue() > count) {
                count = e.getValue();
                t = e.getKey();
            }
        }
        return t;
    }
}
```

参考文章:
- 机器学习算法:18大数据挖掘的经典算法以及代码Java实现
- ID3、C4.5算法介绍以及java代码实现
- 归纳决策树ID3的实现