大家都知道,机器学习算法与 Python 的契合度是最高的,但很多公司的在售产品是用 Java 写的,想集成机器学习算法,又不能把已有的代码全部推倒重写,所以在 Java 中调用 Python 方法迫在眉睫。本方案就是通过 Java 调用 Python 程序训练并导出模型,之后可以随时通过 Java 调用 Python 导入模型进行预测,应用到实际场景中。我是小順,请大家关注我,我会给大家发更多的工具。
prop.properties
# Path prefix for the model file: passed to firstMethod.py as the save location and
# concatenated with modelName1 to form the model path given to twoMethod.py.
modelUrl=/soc/
# Directory of training documents; passed to both Python scripts.
trainDataUrl=/soc/data_log/
# Directory of documents to classify; passed to twoMethod.py.
predictDataUrl=/soc/predict_test_log/
# Path prefix prepended to the script file names (firstMethod.py, twoMethod.py).
# NOTE(review): the key is spelled "sciptsUrl" and the Java code reads it by that exact name.
sciptsUrl=/soc/
# Python 3 interpreter used as the command for both scripts.
python3Url=/root/anaconda3/bin/python3
# Path prefix under which twoMethod.py writes its prediction file (new.csv).
predictSavePath=/soc/
# File name of the saved model, appended directly to modelUrl.
modelName1=mnb_pipeline
JavaExecPythonMLModel.java
package com.eduTwo;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.text.SimpleDateFormat;
import java.util.Properties;
/**
 * Launches the Python training and prediction scripts as child processes and echoes their
 * output to standard output.
 *
 * <p>All paths and names are read once, during class initialisation, from a
 * {@code prop.properties} resource located next to this class. If that resource cannot be
 * found or read, a diagnostic is written to standard error and the configuration fields are
 * left {@code null}.
 */
public class JavaExecPythonMLModel {
    /** Path prefix for the model file ({@code modelUrl} property). */
    public static String modelUrl;
    /** Training-data location handed to both scripts ({@code trainDataUrl} property). */
    public static String trainDataUrl;
    /** Location of the data to classify ({@code predictDataUrl} property). */
    public static String predictDataUrl;
    /** Path prefix prepended to the script file names ({@code sciptsUrl} property). */
    public static String sciptsUrl;
    /** Python interpreter used as the command ({@code python3Url} property). */
    public static String python3Url;
    /** Model file name appended to {@link #modelUrl} ({@code modelName1} property). */
    public static String modelName1;
    /** Path prefix for the prediction output ({@code predictSavePath} property). */
    public static String predictSavePath;

    static {
        // try-with-resources closes the stream; a null resource is skipped by the language.
        try (InputStream in = JavaExecPythonMLModel.class.getResourceAsStream("prop.properties")) {
            if (in == null) {
                System.err.println("prop.properties not found next to JavaExecPythonMLModel;"
                        + " configuration fields are left unset");
            } else {
                Properties prop = new Properties();
                prop.load(new InputStreamReader(in, "utf-8"));
                modelUrl = prop.getProperty("modelUrl");
                trainDataUrl = prop.getProperty("trainDataUrl");
                predictDataUrl = prop.getProperty("predictDataUrl");
                sciptsUrl = prop.getProperty("sciptsUrl");
                python3Url = prop.getProperty("python3Url");
                modelName1 = prop.getProperty("modelName1");
                predictSavePath = prop.getProperty("predictSavePath");
            }
        } catch (IOException e) {
            // Best effort, as before: report the problem and continue with unset fields.
            e.printStackTrace();
        }
    }

    /**
     * Runs the given command, echoing every line the child writes (stdout and stderr merged)
     * until the stream ends or a line equal to {@code "success"} or {@code "fail"} is seen,
     * then waits for the child and prints its exit code.
     *
     * <p>Each line is read exactly once and the sentinel is compared against that same line,
     * so no output is skipped. The sentinel line itself is printed once.
     *
     * @param args the command followed by its arguments; must be non-empty
     * @throws IllegalArgumentException if {@code args} is {@code null} or empty
     * @throws IOException if the process cannot be started or its output cannot be read
     * @throws InterruptedException if the calling thread is interrupted while waiting for the
     *     child; the child is destroyed and the interrupt flag is restored before rethrowing
     */
    public static void executepy(String[] args) throws IOException, InterruptedException {
        if (args == null || args.length == 0) {
            throw new IllegalArgumentException("args must contain at least the command to run");
        }

        ProcessBuilder builder = new ProcessBuilder(args);
        // Merge stderr into stdout so Python tracebacks are visible and the child can never
        // block on an unread stderr pipe.
        builder.redirectErrorStream(true);
        Process proc = builder.start();

        try (BufferedReader output =
                new BufferedReader(new InputStreamReader(proc.getInputStream()))) {
            String line;
            while ((line = output.readLine()) != null) {
                System.out.println(line);
                // Stop echoing once the script reports its final status.
                if ("success".equals(line) || "fail".equals(line)) {
                    break;
                }
            }
        }

        int returnCode;
        try {
            // 0 means the script exited normally; any other value is script/OS specific.
            returnCode = proc.waitFor();
        } catch (InterruptedException e) {
            proc.destroy();
            Thread.currentThread().interrupt();
            throw e;
        }
        System.out.println("end.................return code:" + returnCode);
    }

    /** Returns whether every configuration field was populated from the properties file. */
    private static boolean configurationLoaded() {
        return modelUrl != null
                && trainDataUrl != null
                && predictDataUrl != null
                && sciptsUrl != null
                && python3Url != null
                && modelName1 != null
                && predictSavePath != null;
    }

    /**
     * Runs {@code firstMethod.py} (training) and then {@code twoMethod.py} (prediction),
     * printing a timestamp before and after.
     *
     * <p>If training cannot be launched or read, prediction is not attempted.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        if (!configurationLoaded()) {
            System.err.println("Configuration is incomplete; check prop.properties");
            return;
        }
        try {
            // Command layout: interpreter, script path, then the script's own arguments.
            String[] trainCommand = new String[] {
                python3Url, sciptsUrl + "firstMethod.py", trainDataUrl, modelUrl, modelName1
            };
            String[] predictCommand = new String[] {
                python3Url, sciptsUrl + "twoMethod.py", modelUrl + modelName1,
                predictDataUrl, trainDataUrl, predictSavePath
            };

            SimpleDateFormat df = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
            System.out.println(df.format(System.currentTimeMillis()));

            executepy(trainCommand);
            executepy(predictCommand);

            System.out.println(df.format(System.currentTimeMillis()));
        } catch (IOException | InterruptedException e) {
            e.printStackTrace();
        }
    }
}
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 7 13:52:42 2018
@author:
"""
from sklearn.datasets import load_files
from sklearn2pmml import PMMLPipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.externals import joblib
import os
import sys
def getFirstContent(dataUrl,modelUrl,modelName):
    """Train a text classifier on the documents under ``dataUrl`` and save it.

    The documents are loaded with ``load_files``, converted to term counts and
    then to TF-IDF features, and a ``PMMLPipeline`` wrapping a
    ``LogisticRegression`` classifier is fitted on them. The fitted pipeline is
    dumped with joblib to ``modelUrl + modelName`` (plain string concatenation).

    Returns "success" if the model file exists after the dump, otherwise "fail".
    """
    corpus = load_files(dataUrl, encoding="utf-8")

    # Feature extraction, step 1: raw term-frequency counts.
    term_counts = CountVectorizer().fit_transform(corpus.data)

    # Feature extraction, step 2: TF-IDF weighting of those counts.
    tfidf_features = TfidfTransformer().fit_transform(term_counts)

    # The classifier is logistic regression, despite the "mnb" in the model name.
    pipeline = PMMLPipeline([("classifier", LogisticRegression())])
    pipeline.fit(tfidf_features, corpus.target)

    model_path = modelUrl + modelName
    joblib.dump(pipeline, model_path)

    return "success" if os.path.exists(model_path) else "fail"
if __name__ == '__main__':
    # Positional arguments after the script name: data location, model path
    # prefix, model name. The result string is printed for the caller to read.
    cli_args = [str(value) for value in sys.argv[1:]]
    print(getFirstContent(cli_args[0], cli_args[1], cli_args[2]))
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 7 09:17:46 2018
@author:
"""
from sklearn.datasets import load_files
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.externals import joblib
import sys
import os
import numpy
def getTwoContent(mnb_pipeline,file,dataUrl,predictSavePath):
    """Load a saved model, classify the documents under ``file`` and save the result.

    The count vectorizer and TF-IDF transformer are re-fitted on the training
    documents under ``dataUrl`` and then applied to the documents under
    ``file``; only the classifier itself is loaded from disk.

    Parameters:
        mnb_pipeline: path of the joblib-dumped model to load.
        file: directory of documents to classify (read with ``load_files``).
        dataUrl: directory of the training documents used to re-fit the
            feature extractors.
        predictSavePath: path prefix; predictions are written to
            ``predictSavePath + 'new.csv'`` (plain string concatenation).

    Returns "success" if the output file exists after writing, otherwise "fail".
    """
    # NOTE(security): joblib.load unpickles the file and can execute arbitrary
    # code; only load model files from a trusted location.
    model = joblib.load(mnb_pipeline)

    test_corpus = load_files(file, encoding="utf-8")
    train_corpus = load_files(dataUrl, encoding="utf-8")

    count_vect = CountVectorizer()
    train_counts = count_vect.fit_transform(train_corpus.data)
    test_counts = count_vect.transform(test_corpus.data)

    # Only the fitted state is needed here, so fit() replaces the original
    # fit_transform() whose transformed training matrix was discarded.
    tfidf_transformer = TfidfTransformer()
    tfidf_transformer.fit(train_counts)
    test_tfidf = tfidf_transformer.transform(test_counts)

    predicted = model.predict(test_tfidf)

    output_path = predictSavePath + 'new.csv'
    numpy.savetxt(output_path, predicted, delimiter=',')

    return "success" if os.path.exists(output_path) else "fail"
if __name__ == '__main__':
    # Positional arguments after the script name: model path, data to classify,
    # training data, output path prefix. The result string is printed for the
    # caller to read.
    cli_args = [str(value) for value in sys.argv[1:]]
    print(getTwoContent(cli_args[0], cli_args[1], cli_args[2], cli_args[3]))
部署到 Linux 上执行,执行结果如下: