【终极版】万能java执行python方法生成机器学习模型并应用机器学习模型

大家都知道,机器学习算法与python的契合度是最高的,但很多公司的在售产品是用java写的。想集成机器学习算法,又不能把已有的东西全部推倒重写,因此在java中调用python方法迫在眉睫。本方案就是通过java调用python程序训练并导出模型,之后可以随时通过java导入模型,应用到具体场景中。我是小順,请大家关注我,我会给大家发更多的工具。

prop.properties

modelUrl=/soc/
trainDataUrl=/soc/data_log/
predictDataUrl=/soc/predict_test_log/
sciptsUrl=/soc/
python3Url=/root/anaconda3/bin/python3
predictSavePath=/soc/
modelName1=mnb_pipeline

JavaExecPythonMLModel.java

package com.eduTwo;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.text.SimpleDateFormat;
import java.util.Properties;

public class JavaExecPythonMLModel {

    /**
     * Settings read once from prop.properties when the class is loaded.
     * Each field stays {@code null} if the file or the matching key is missing.
     */
    public static String modelUrl;
    public static String trainDataUrl;
    public static String predictDataUrl;
    public static String sciptsUrl;
    public static String python3Url;
    public static String modelName1;
    public static String predictSavePath;

    static {
        // The resource name is relative, so it is resolved against this class's
        // package on the classpath. try-with-resources closes the stream.
        try (InputStream in = JavaExecPythonMLModel.class.getResourceAsStream("prop.properties")) {
            if (in == null) {
                throw new IOException("prop.properties not found on the classpath next to JavaExecPythonMLModel");
            }
            Properties prop = new Properties();
            prop.load(new InputStreamReader(in, "utf-8"));

            modelUrl = prop.getProperty("modelUrl");
            trainDataUrl = prop.getProperty("trainDataUrl");
            predictDataUrl = prop.getProperty("predictDataUrl");
            sciptsUrl = prop.getProperty("sciptsUrl");
            python3Url = prop.getProperty("python3Url");
            modelName1 = prop.getProperty("modelName1");
            predictSavePath = prop.getProperty("predictSavePath");
        } catch (Exception e) {
            // Keep class loading alive; main() reports the missing settings.
            e.printStackTrace();
        }
    }

    /**
     * Runs an external command, echoes its output line by line, and prints its exit code.
     *
     * Reading stops early once a line equal to "success" or "fail" is seen,
     * which is the final line the Python scripts print.
     *
     * @param args the command followed by its arguments, e.g. interpreter, script path, script arguments
     * @throws IOException          if the process cannot be started or its output cannot be read
     * @throws InterruptedException if the calling thread is interrupted while waiting for the process
     */
    public static void executepy(String[] args) throws IOException, InterruptedException {
        ProcessBuilder builder = new ProcessBuilder(args);
        // Merge stderr into stdout so tracebacks are visible and an unread
        // stderr pipe can never fill up and stall the child process.
        builder.redirectErrorStream(true);
        Process proc = builder.start();

        try (BufferedReader reader = new BufferedReader(new InputStreamReader(proc.getInputStream()))) {
            String line;
            while ((line = reader.readLine()) != null) {
                System.out.println(line);
                // Compare the line that was just read. Calling readLine() again
                // here would consume and discard the following lines.
                if ("success".equals(line) || "fail".equals(line)) {
                    break;
                }
            }
        }

        // Exit code 0 means the script finished normally; anything else is a failure.
        int r = proc.waitFor();
        System.out.println("end.................return code:" + r);
    }

    /** Returns true when every setting needed to build the two command lines was loaded. */
    private static boolean isConfigComplete() {
        return modelUrl != null && trainDataUrl != null && predictDataUrl != null
                && sciptsUrl != null && python3Url != null && modelName1 != null
                && predictSavePath != null;
    }

    /**
     * Trains the model with firstMethod.py, then applies it with twoMethod.py,
     * printing a timestamp before and after.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        if (!isConfigComplete()) {
            System.err.println("prop.properties is missing or incomplete; nothing was executed.");
            return;
        }
        try {
            // Command layout: interpreter, script path, then the script's own arguments.
            String[] args1 = new String[] { python3Url, sciptsUrl + "firstMethod.py", trainDataUrl, modelUrl, modelName1 };
            String[] args2 = new String[] { python3Url, sciptsUrl + "twoMethod.py", modelUrl + modelName1, predictDataUrl, trainDataUrl, predictSavePath };

            SimpleDateFormat df = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
            String date1 = df.format(System.currentTimeMillis());
            System.out.println(date1);

            executepy(args1);
            executepy(args2);

            String date2 = df.format(System.currentTimeMillis());
            System.out.println(date2);

        } catch (IOException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the JVM/caller can still observe it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
    }
}

firstMethod.py

# -*- coding: utf-8 -*-

"""
Created on Fri Dec  7 13:52:42 2018

@author:
"""
from sklearn.datasets import load_files
from sklearn2pmml import PMMLPipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.externals import joblib
import os
import sys


def getFirstContent(dataUrl, modelUrl, modelName):
    """Train a text classifier on the documents under dataUrl and save it.

    The documents are turned into term counts, then TF-IDF weights, and a
    LogisticRegression wrapped in a PMMLPipeline is fitted on them. Only the
    fitted pipeline is dumped; the vectorizer and TF-IDF transformer are not
    saved.

    Args:
        dataUrl: Directory passed to sklearn's load_files.
        modelUrl: Output location prefix; concatenated directly with
            modelName, so it needs its own trailing separator.
        modelName: File name of the dumped model.

    Returns:
        "success" if the model file exists after dumping, otherwise "fail".
    """
    corpus = load_files(dataUrl, encoding="utf-8")

    # Feature step 1: raw term-frequency counts.
    term_counts = CountVectorizer().fit_transform(corpus.data)

    # Feature step 2: re-weight the counts as TF-IDF.
    tfidf_features = TfidfTransformer().fit_transform(term_counts)

    # The estimator is logistic regression (not naive Bayes).
    classifier_pipeline = PMMLPipeline([("classifier", LogisticRegression())])
    classifier_pipeline.fit(tfidf_features, corpus.target)

    model_path = modelUrl + modelName
    joblib.dump(classifier_pipeline, model_path)

    return "success" if os.path.exists(model_path) else "fail"
    
    

if __name__ == '__main__':
    # Positional arguments: training-data directory, model output prefix,
    # model file name. The result string is printed for the caller to read.
    cli_args = [str(arg) for arg in sys.argv[1:]]

    print(getFirstContent(cli_args[0], cli_args[1], cli_args[2]))

twoMethod.py

# -*- coding: utf-8 -*-
"""
Created on Fri Dec  7 09:17:46 2018

@author: 
"""


from sklearn.datasets import load_files
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.externals import joblib
import sys
import os
import numpy 

def getTwoContent(mnb_pipeline, file, dataUrl, predictSavePath):
    """Load a saved model, predict labels for new documents and save them.

    The count vectorizer and TF-IDF transformer are re-fitted on the training
    documents here, then applied to the documents to be predicted.

    Args:
        mnb_pipeline: Path of the model file to load with joblib.
        file: Directory of documents to predict, passed to load_files.
        dataUrl: Directory of the training documents, passed to load_files.
        predictSavePath: Output location prefix; 'new.csv' is appended
            directly, so it needs its own trailing separator.

    Returns:
        "success" if the output file exists after writing, otherwise "fail".
    """
    model = joblib.load(mnb_pipeline)

    unseen_corpus = load_files(file, encoding="utf-8")
    train_corpus = load_files(dataUrl, encoding="utf-8")

    # Fit the vocabulary on the training documents, then apply it to the new ones.
    vectorizer = CountVectorizer()
    train_counts = vectorizer.fit_transform(train_corpus.data)
    unseen_counts = vectorizer.transform(unseen_corpus.data)

    # Fit TF-IDF on the training counts; only the fitted state is used.
    tfidf = TfidfTransformer()
    tfidf.fit_transform(train_counts)
    unseen_tfidf = tfidf.transform(unseen_counts)

    labels = model.predict(unseen_tfidf)

    output_path = predictSavePath + 'new.csv'
    numpy.savetxt(output_path, labels, delimiter=',')

    return "success" if os.path.exists(output_path) else "fail"


if __name__ == '__main__':
    # Positional arguments: model file path, directory to predict,
    # training-data directory, output prefix for the prediction file.
    cli_args = [str(arg) for arg in sys.argv[1:]]

    print(getTwoContent(cli_args[0], cli_args[1], cli_args[2], cli_args[3]))


部署到linux中去执行,执行结果如下:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/walteryonng20xx/article/details/85059360