<pre name="code" class="java">package com.haolidong.Logistic;
import java.util.ArrayList;
/**
*
* @author haolidong
* @Description: [该类主要用于保存特征信息]
* @parameter data: [主要保存特征矩阵]
*/
public class Matrix {
<span style="white-space:pre"> </span>public ArrayList<ArrayList<String>> data;
<span style="white-space:pre"> </span>public Matrix() {
<span style="white-space:pre"> </span>// TODO Auto-generated constructor stub
<span style="white-space:pre"> </span>data = new ArrayList<ArrayList<String>>();
<span style="white-space:pre"> </span>}
}
<pre name="code" class="java">package com.haolidong.Logistic;
import java.util.ArrayList;
/**
*
* @author haolidong
* @Description: [该类主要用于保存特征信息以及标签值]
* @parameter labels: [主要保存标签值]
*/
public class CreateDataSet extends Matrix {
public ArrayList<String> labels;
public CreateDataSet() {
// TODO Auto-generated constructor stub
super();
labels = new ArrayList<String>();
}
/**
* @author haolidong
* @Description: [机器学习实战逻辑回归第一个案例的数据]
*/
public void initTest() {
}
}
package com.haolidong.Logistic;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
public class Logistic {
public static void main(String[] args) {
colicTest();
}
/**
* @author haolidong
* @Description: [逻辑回归的简单测试]
*/
public static void LogisticTest() {
// TODO Auto-generated method stub
CreateDataSet dataSet = new CreateDataSet();
dataSet = readFile("I:\\machinelearninginaction\\Ch05\\testSet.txt");
ArrayList<Double> weights = new ArrayList<Double>();
weights = gradAscent1(dataSet, dataSet.labels, 150);
for (int i = 0; i < 3; i++) {
System.out.println(weights.get(i));
}
System.out.println();
}
/**
* @param inX
* @param weights
* @return
* @author haolidong
* @Description: [sigmod分类]
*/
public static String classifyVector(ArrayList<String> inX, ArrayList<Double> weights) {
ArrayList<Double> sum = new ArrayList<>();
sum.clear();
sum.add(0.0);
for (int i = 0; i < inX.size(); i++) {
sum.set(0, sum.get(0) + Double.parseDouble(inX.get(i)) * weights.get(i));
}
if (sigmoid(sum).get(0) > 0.5)
return "1";
else
return "0";
}
/**
* @author haolidong
* @Description: [预测马的疝气病的死亡率]
*/
public static void colicTest() {
CreateDataSet trainingSet = new CreateDataSet();
CreateDataSet testSet = new CreateDataSet();
trainingSet = readFile("I:\\machinelearninginaction\\Ch05\\horseColicTraining.txt");
testSet = readFile("I:\\machinelearninginaction\\Ch05\\horseColicTest.txt");
ArrayList<Double> weights = new ArrayList<Double>();
weights = gradAscent1(trainingSet, trainingSet.labels, 500);
int errorCount = 0;
for (int i = 0; i < testSet.data.size(); i++) {
if (!classifyVector(testSet.data.get(i), weights).equals(testSet.labels.get(i))) {
errorCount++;
}
System.out.println(classifyVector(testSet.data.get(i), weights) + "," + testSet.labels.get(i));
}
System.out.println(1.0 * errorCount / testSet.data.size());
}
/**
* @param inX
* @return
* @author haolidong
* @Description: [sigmod函数]
*/
public static ArrayList<Double> sigmoid(ArrayList<Double> inX) {
ArrayList<Double> inXExp = new ArrayList<Double>();
for (int i = 0; i < inX.size(); i++) {
inXExp.add(1.0 / (1 + Math.exp(-inX.get(i))));
}
return inXExp;
}
/**
* @param dataSet
* @param classLabels
* @param numberIter
* @return
* @author haolidong
* @Description: [改进的随机梯度上升算法]
*/
public static ArrayList<Double> gradAscent1(Matrix dataSet, ArrayList<String> classLabels, int numberIter) {
int m = dataSet.data.size();
int n = dataSet.data.get(0).size();
double alpha = 0.0;
int randIndex = 0;
ArrayList<Double> weights = new ArrayList<Double>();
ArrayList<Double> weightstmp = new ArrayList<Double>();
ArrayList<Double> h = new ArrayList<Double>();
ArrayList<Integer> dataIndex = new ArrayList<Integer>();
ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();
for (int i = 0; i < n; i++) {
weights.add(1.0);
weightstmp.add(1.0);
}
dataMatrixMulweights.add(0.0);
double error = 0.0;
for (int j = 0; j < numberIter; j++) {
// 产生0到99的数组
for (int p = 0; p < m; p++) {
dataIndex.add(p);
}
// 进行每一次的训练
for (int i = 0; i < m; i++) {
alpha = 4 / (1.0 + i + j) + 0.0001;
randIndex = (int) (Math.random() * dataIndex.size());
dataIndex.remove(randIndex);
double temp = 0.0;
for (int k = 0; k < n; k++) {
temp = temp + Double.parseDouble(dataSet.data.get(randIndex).get(k)) * weights.get(k);
}
dataMatrixMulweights.set(0, temp);
h = sigmoid(dataMatrixMulweights);
error = Double.parseDouble(classLabels.get(randIndex)) - h.get(0);
double tempweight = 0.0;
for (int p = 0; p < n; p++) {
tempweight = alpha * Double.parseDouble(dataSet.data.get(randIndex).get(p)) * error;
weights.set(p, weights.get(p) + tempweight);
}
}
}
return weights;
}
/**
* @param dataSet
* @param classLabels
* @return
* @author haolidong
* @Description: [随机梯度上升算法]
*/
public static ArrayList<Double> gradAscent0(Matrix dataSet, ArrayList<String> classLabels) {
int m = dataSet.data.size();
int n = dataSet.data.get(0).size();
ArrayList<Double> weights = new ArrayList<Double>();
ArrayList<Double> weightstmp = new ArrayList<Double>();
ArrayList<Double> h = new ArrayList<Double>();
double error = 0.0;
ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();
double alpha = 0.01;
for (int i = 0; i < n; i++) {
weights.add(1.0);
weightstmp.add(1.0);
}
h.add(0.0);
double temp = 0.0;
dataMatrixMulweights.add(0.0);
for (int i = 0; i < m; i++) {
temp = 0.0;
for (int k = 0; k < n; k++) {
temp = temp + Double.parseDouble(dataSet.data.get(i).get(k)) * weights.get(k);
}
dataMatrixMulweights.set(0, temp);
h = sigmoid(dataMatrixMulweights);
error = Double.parseDouble(classLabels.get(i)) - h.get(0);
double tempweight = 0.0;
for (int p = 0; p < n; p++) {
tempweight = alpha * Double.parseDouble(dataSet.data.get(i).get(p)) * error;
weights.set(p, weights.get(p) + tempweight);
}
}
return weights;
}
/**
* @param dataSet
* @param classLabels
* @return
* @author haolidong
* @Description: [全部数据的梯度上升算法]
*/
public static ArrayList<Double> gradAscent(Matrix dataSet, ArrayList<String> classLabels) {
int m = dataSet.data.size();
int n = dataSet.data.get(0).size();
ArrayList<Double> weights = new ArrayList<Double>();
ArrayList<Double> weightstmp = new ArrayList<Double>();
ArrayList<Double> h = new ArrayList<Double>();
ArrayList<Double> error = new ArrayList<Double>();
ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();
double alpha = 0.001;
int maxCycles = 500;
for (int i = 0; i < n; i++) {
weights.add(1.0);
weightstmp.add(1.0);
}
for (int i = 0; i < m; i++) {
h.add(0.0);
error.add(0.0);
dataMatrixMulweights.add(0.0);
}
double temp;
for (int i = 0; i < maxCycles; i++) {
for (int j = 0; j < m; j++) {
temp = 0.0;
for (int k = 0; k < n; k++) {
temp = temp + Double.parseDouble(dataSet.data.get(j).get(k)) * weights.get(k);
}
dataMatrixMulweights.set(j, temp);
}
h = sigmoid(dataMatrixMulweights);
for (int q = 0; q < m; q++) {
error.set(q, Double.parseDouble(classLabels.get(q)) - h.get(q));
}
double tempweight = 0.0;
for (int p = 0; p < n; p++) {
tempweight = 0.0;
for (int q = 0; q < m; q++) {
tempweight = tempweight + alpha * Double.parseDouble(dataSet.data.get(q).get(p)) * error.get(q);
}
weights.set(p, weights.get(p) + tempweight);
}
}
return weights;
}
/**
* @param fileName
* 读入的文件名
* @return
* @author haolidong
* @Description: [根据读入的文件名形成特征集以及标签]
*/
public static CreateDataSet readFile(String fileName) {
File file = new File(fileName);
BufferedReader reader = null;
CreateDataSet dataSet = new CreateDataSet();
try {
reader = new BufferedReader(new FileReader(file));
String tempString = null;
// 一次读入一行,直到读入null为文件结束
while ((tempString = reader.readLine()) != null) {
// 显示行号
String[] strArr = tempString.split("\t");
ArrayList<String> as = new ArrayList<String>();
as.add("1");
for (int i = 0; i < strArr.length - 1; i++) {
as.add(strArr[i]);
}
dataSet.data.add(as);
dataSet.labels.add(strArr[strArr.length - 1]);
}
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
}
return dataSet;
}
}