感知机 java

/**
 * 感知机 解决二分类问题 1,-1
 * @author ysh 1208706282
 *
 */
public class Perceptron {
    double weight[];
    List<Sample> samples;
    static class Sample{
        Double label;
        List<Double> feature;
    }
    
    public  void loadData(String path,String regex) throws Exception{
        samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        String line = null;
        String splits[] = null;
        Sample sample = null;
        while(null != (line=reader.readLine())){
            splits = line.split(regex);
            sample = new Sample();
            sample.feature = new ArrayList<Double>(splits.length-1);
            for(int i=0;i<splits.length-1;i++){
                sample.feature.add(new Double(splits[i]));
            }
            sample.label = Double.valueOf(splits[splits.length-1]);
            if(sample.label == 0){   //标签为0改为-1
                sample.label = -1.0;
            }
            samples.add(sample);
        }
        reader.close();
    }
    public double classify(Sample sample,double weight[]){
        double ret = 0;
        for(int i=0;i<sample.feature.size();i++){
            ret += sample.feature.get(i)*weight[i];
        }
        ret += weight[weight.length-1];//偏置
        return ret;
    }
    public void updateWeight(Sample sample,double weight[],double eta){
        for(int i=0;i<sample.feature.size();i++){
            weight[i] += eta*sample.label*sample.feature.get(i);
        }
        weight[weight.length-1] += eta*sample.label;
    }
    public void train(int iters,double eta){
        int len = samples.get(0).feature.size();
        weight = new double[len+1];
        for(int i=0;i<weight.length;i++){
            weight[i] = 0;
        }
        for(int iter=0;iter<iters;iter++){
            int count = 0;
            for(Sample sample:samples){
                if(sample.label*classify(sample,weight) <= 0){
                    updateWeight(sample,weight,eta);
                    count++;
                }
            }
            if(count == 0){
                System.out.println("already complete");
                break;
            }
            System.out.println("iter "+iter+" count "+count);
        }
    }
    
    public void test(){
        int count = 0;
        for(Sample sample:samples){
            double value = classify(sample,weight);
            System.out.println(value+","+sample.label);
            if(sample.label>0){
                if(value>=0){
                    count++;
                }
            }else{
                if(value<0){
                    count++;
                }
            }
        }
        
        System.out.println("right rate: "+count*1.0/samples.size());
    }
    /**
     * @param args
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        // TODO Auto-generated method stub
        Perceptron per = new Perceptron();
        per.loadData("F:/contest/iris.csv",",");
        per.train(100,0.1);
        per.test();
    }
}

猜你喜欢

转载自blog.csdn.net/ysh126/article/details/53073703