Alink Java 版 FTRL 示例

官方只提供了 Python 版的示例，这里先把 Java 版的代码贴出来，后续再对该示例进行分析。


package com.ziroom.ml2;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.dataproc.JsonValueStreamOp;
import com.alibaba.alink.operator.stream.dataproc.SplitStreamOp;
import com.alibaba.alink.operator.stream.evaluation.EvalBinaryClassStreamOp;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlPredictStreamOp;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp;
import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.dataproc.StandardScaler;
import com.alibaba.alink.pipeline.feature.FeatureHasher;

/**
 * Java port of the Alink FTRL online-learning example.
 *
 * <p>The job fits a feature pipeline ({@link StandardScaler} followed by {@link FeatureHasher}) on a
 * small batch CSV, trains an initial logistic-regression model on the transformed batch, and then
 * feeds a larger CSV stream through FTRL training, prediction and binary-classification evaluation.
 */
public class Ftrl {

    /** Column schema shared by the batch sample file and the streaming file. */
    private static final String SCHEMA =
            "id string, click string, dt string, C1 string, banner_pos int, "
                    + "site_id string,  site_domain string, site_category string, app_id string, app_domain string,   "
                    + "app_category string, device_id string, device_ip string, device_model string,  "
                    + " device_type string, device_conn_type string, C14 int, C15 int, C16 int, C17 int,  "
                    + " C18 int, C19 int, C20 int, C21 int";

    /** Small CSV used to fit the feature pipeline and the initial model. */
    private static final String BATCH_SAMPLE_URL =
            "http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-small.csv";

    /** Larger CSV that is read as a stream for online training and evaluation. */
    private static final String STREAM_DATA_URL =
            "http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-ctr-train-8M.csv";

    /** Columns passed to the feature hasher. */
    private static final String[] SELECTED_COLS = {
        "C1", "banner_pos", "site_category", "app_domain",
        "app_category", "device_type", "device_conn_type",
        "C14", "C15", "C16", "C17", "C18", "C19", "C20", "C21",
        "site_id", "site_domain", "device_id", "device_model"
    };

    /** Subset of {@link #SELECTED_COLS} that the feature hasher treats as categorical. */
    private static final String[] CATEGORICAL_COLS = {
        "C1", "banner_pos", "site_category", "app_domain",
        "app_category", "device_type", "device_conn_type",
        "site_id", "site_domain", "device_id", "device_model"
    };

    /** Columns standardized before hashing. */
    private static final String[] NUMERICAL_COLS = {
        "C14", "C15", "C16", "C17", "C18", "C19", "C20", "C21"
    };

    private static final String LABEL_COL = "click";
    private static final String VECTOR_COL = "vec";
    private static final String PREDICTION_COL = "pred";
    private static final String PREDICTION_DETAIL_COL = "details";

    /**
     * Size of the hashed feature vector. Kept small so the job also runs on a single machine;
     * the official example uses a much larger value.
     */
    private static final int NUM_HASH_FEATURES = 30;

    /** Time interval handed to both the FTRL trainer and the stream evaluator. */
    private static final int TIME_INTERVAL = 10;

    /**
     * Builds the whole batch + stream operator graph and starts the streaming job.
     *
     * @param args command-line arguments; not used
     * @throws Exception propagated from the Alink operators or from job execution
     */
    public static void main(String[] args) throws Exception {
        // Batch sample: used to fit the feature pipeline and to train the initial model.
        BatchOperator<?> sampleBatch = new CsvSourceBatchOp()
                .setFilePath(BATCH_SAMPLE_URL)
                .setSchemaStr(SCHEMA);

        // Feature engineering: standardize the numerical columns, then hash everything into one vector.
        StandardScaler scaler = new StandardScaler()
                .setSelectedCols(NUMERICAL_COLS);
        FeatureHasher hasher = new FeatureHasher()
                .setSelectedCols(SELECTED_COLS)
                .setCategoricalCols(CATEGORICAL_COLS)
                .setOutputCol(VECTOR_COL)
                .setNumFeatures(NUM_HASH_FEATURES);
        PipelineModel featureModel = new Pipeline()
                .add(scaler)
                .add(hasher)
                .fit(sampleBatch);

        // Streaming input; the first line of the file is skipped.
        StreamOperator<?> fullStream = new CsvSourceStreamOp()
                .setFilePath(STREAM_DATA_URL)
                .setSchemaStr(SCHEMA)
                .setIgnoreFirstLine(true);

        // Split with fraction 0.5: the main output feeds training, side output 0 feeds evaluation.
        StreamOperator<?> splitter = new SplitStreamOp()
                .setFraction(0.5)
                .linkFrom(fullStream);
        StreamOperator<?> trainStream = splitter;
        StreamOperator<?> testStream = splitter.getSideOutput(0);

        // Initial model: logistic regression (a single iteration) on the transformed batch sample.
        LogisticRegressionTrainBatchOp initModel = new LogisticRegressionTrainBatchOp()
                .setVectorCol(VECTOR_COL)
                .setLabelCol(LABEL_COL)
                .setWithIntercept(true)
                .setMaxIter(1)
                .linkFrom(featureModel.transform(sampleBatch));

        // Online training, seeded with the initial model, on the transformed training stream.
        FtrlTrainStreamOp onlineModel = new FtrlTrainStreamOp(initModel)
                .setVectorCol(VECTOR_COL)
                .setLabelCol(LABEL_COL)
                .setWithIntercept(true)
                .setAlpha(0.1)
                .setBeta(0.1)
                .setL1(0.01)
                .setL2(0.01)
                .setTimeInterval(TIME_INTERVAL)
                .setVectorSize(NUM_HASH_FEATURES)
                .linkFrom(featureModel.transform(trainStream));

        // Prediction on the transformed test stream, linked to the online model stream.
        FtrlPredictStreamOp predictions = new FtrlPredictStreamOp(initModel)
                .setVectorCol(VECTOR_COL)
                .setPredictionCol(PREDICTION_COL)
                .setReservedCols(new String[] {LABEL_COL})
                .setPredictionDetailCol(PREDICTION_DETAIL_COL)
                .linkFrom(onlineModel, featureModel.transform(testStream));

        predictions.print(30, 20);

        // Evaluate the predictions, then pull selected metrics out of the JSON "Data" column.
        EvalBinaryClassStreamOp evaluation = new EvalBinaryClassStreamOp()
                .setLabelCol(LABEL_COL)
                .setPredictionCol(PREDICTION_COL)
                .setPredictionDetailCol(PREDICTION_DETAIL_COL)
                .setTimeInterval(TIME_INTERVAL)
                .linkFrom(predictions);
        JsonValueStreamOp metricExtractor = new JsonValueStreamOp()
                .setSelectedCol("Data")
                .setReservedCols(new String[] {"Statistics"})
                .setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"})
                .setJsonPath(new String[] {"$.Accuracy", "$.AUC", "$.ConfusionMatrix"});
        evaluation.link(metricExtractor).print(30, 20);

        // Nothing runs until the stream job is explicitly executed.
        StreamOperator.execute();
    }
}

注意：运行结果可能会打印到其他日志文件中。

猜你喜欢

转载自blog.51cto.com/12597095/2463716