深度学习 之 DeepLearning4j 预测股市走向

上一篇文章预测花的类型时没有用到任何中间件。但实际场景中的数据量往往非常大,那种做法并不实用。这次使用 DeepLearning4j 来预测股市走向,后续会再加上 Spark。代码如下:

/**
 * Mutable holder for one trading day's bar: open, close, high, low and two trade totals.
 *
 * <p>Naming caveat: {@code turnover} holds the traded quantity (share volume) and
 * {@code volume} holds the traded value (amount), which is the reverse of the usual English
 * meaning of those words. The accessor names are kept unchanged for compatibility.
 */
public class DailyData {

    private double openPrice;   // opening price
    private double closeprice;  // closing price
    private double maxPrice;    // highest price
    private double minPrice;    // lowest price
    private double turnover;    // traded quantity (share volume), despite the field name
    private double volume;      // traded value (amount), despite the field name

    /** Creates an instance whose values are all 0. */
    public DailyData() {
    }

    public double getOpenPrice() { return openPrice; }

    public void setOpenPrice(double openPrice) { this.openPrice = openPrice; }

    public double getCloseprice() { return closeprice; }

    public void setCloseprice(double closeprice) { this.closeprice = closeprice; }

    public double getMaxPrice() { return maxPrice; }

    public void setMaxPrice(double maxPrice) { this.maxPrice = maxPrice; }

    public double getMinPrice() { return minPrice; }

    public void setMinPrice(double minPrice) { this.minPrice = minPrice; }

    /** @return the traded quantity (share volume) */
    public double getTurnover() { return turnover; }

    public void setTurnover(double turnover) { this.turnover = turnover; }

    /** @return the traded value (amount) */
    public double getVolume() { return volume; }

    public void setVolume(double volume) { this.volume = volume; }

    /** Renders all six values as labelled pairs separated by ", ". */
    @Override
    public String toString() {
        return "开盘价=" + this.openPrice + ", "
                + "收盘价=" + this.closeprice + ", "
                + "最高价=" + this.maxPrice + ", "
                + "最低价=" + this.minPrice + ", "
                + "成交量=" + this.turnover + ", "
                + "成交额=" + this.volume;
    }
}
/**
 * DL4J {@code DataSetIterator} over a daily price series loaded from a comma-separated
 * classpath resource.
 *
 * <p>The series is cut into consecutive, non-overlapping windows of {@code exampleLength}
 * days. At every time step the features are the current day's six values, each divided by
 * that column's maximum over the whole file. The label is the next day's normalized close.
 * Time step 0 of a window is its oldest day. A final window shorter than
 * {@code exampleLength} is zero-padded, and the padded steps are masked out.
 */
public class StockDataIterator implements DataSetIterator {

    private static final long serialVersionUID = 1L;

    /** Number of features per time step. */
    private static final int VECTOR_SIZE = 6;

    /** Zero-based index of the first numeric column read from each row. */
    private static final int FIRST_VALUE_COLUMN = 2;

    /** A row must contain every column up to FIRST_VALUE_COLUMN + VECTOR_SIZE - 1. */
    private static final int MIN_FIELDS = FIRST_VALUE_COLUMN + VECTOR_SIZE;

    // number of windows per mini-batch
    private int batchNum;

    // number of days per window
    private int exampleLength;

    // whole series, reversed from file order (the sample file lists newest day first)
    private List<DailyData> dataList;

    // start indices (into dataList) of the windows not yet returned in this pass
    private List<Integer> dataRecord;

    // per-column maximum, used as the normalization divisor
    private double[] maxNum;

    // optional hook applied to every DataSet returned by next(int)
    private DataSetPreProcessor preProcessor;

    /** Creates an empty iterator; call {@link #loadData} before iterating. */
    public StockDataIterator() {
        dataRecord = new ArrayList<>();
    }

    /**
     * Loads the series from a classpath resource and prepares the first pass.
     *
     * @param fileName      resource name resolved via {@code Class.getResourceAsStream}
     * @param batchNum      windows per mini-batch; must be positive
     * @param exampleLength days per window; must be positive
     * @return {@code true} on success, {@code false} if reading or parsing failed
     * @throws IllegalArgumentException if batchNum or exampleLength is not positive
     */
    public boolean loadData(String fileName, int batchNum, int exampleLength) {
        // Zero would divide by zero / never advance the window start.
        if (batchNum <= 0 || exampleLength <= 0) {
            throw new IllegalArgumentException("batchNum and exampleLength must be positive: batchNum="
                    + batchNum + ", exampleLength=" + exampleLength);
        }
        this.batchNum = batchNum;
        this.exampleLength = exampleLength;
        try {
            readDataFromFile(fileName);
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
        resetDataRecord();
        return true;
    }

    /**
     * Re-queues the start index of every window.
     *
     * <p>A window starting at {@code start} needs at least one following day for its label, so
     * only starts below {@code size - 1} are queued. This never yields an empty or
     * out-of-range window.
     */
    private void resetDataRecord() {
        dataRecord.clear();
        if (dataList == null) {
            return;
        }
        for (int start = 0; start < dataList.size() - 1; start += exampleLength) {
            dataRecord.add(start);
        }
    }

    /**
     * Reads the series from a classpath resource and records each column's maximum.
     *
     * <p>Each line is split on ','. Lines with fewer than {@value #MIN_FIELDS} fields are
     * skipped. Fields 2..7 are parsed as open, close, "max", "min", turnover and volume, in that
     * order. NOTE(review): in the sample data distributed with this code, field 4 is never
     * above open/close and field 5 never below, i.e. low then high. The values are stored via
     * setMaxPrice/setMinPrice respectively. The model treats all features alike, so only the
     * names are affected; confirm against the real file before relying on the getters.
     *
     * @param fileName classpath resource name
     * @return the loaded series, with the file's row order reversed
     * @throws java.io.FileNotFoundException if the resource does not exist
     * @throws IOException if reading fails
     * @throws NumberFormatException if a numeric field cannot be parsed
     */
    public List<DailyData> readDataFromFile(String fileName) throws IOException {
        java.io.InputStream stream = StockDataIterator.class.getResourceAsStream(fileName);
        if (stream == null) {
            throw new java.io.FileNotFoundException("classpath resource not found: " + fileName);
        }
        List<DailyData> rows = new ArrayList<>();
        double[] columnMax = new double[VECTOR_SIZE];
        System.out.println("读取数据..");
        try (BufferedReader in = new BufferedReader(new InputStreamReader(stream, "UTF-8"))) {
            String line;
            while ((line = in.readLine()) != null) {
                String[] fields = line.split(",");
                if (fields.length < MIN_FIELDS) {
                    continue;
                }
                double[] nums = new double[VECTOR_SIZE];
                for (int j = 0; j < VECTOR_SIZE; j++) {
                    nums[j] = Double.parseDouble(fields[j + FIRST_VALUE_COLUMN].trim());
                    if (nums[j] > columnMax[j]) {
                        columnMax[j] = nums[j];
                    }
                }
                DailyData data = new DailyData();
                data.setOpenPrice(nums[0]);
                data.setCloseprice(nums[1]);
                data.setMaxPrice(nums[2]);
                data.setMinPrice(nums[3]);
                data.setTurnover(nums[4]);
                data.setVolume(nums[5]);
                rows.add(data);
            }
        }
        System.out.println("反转list...");
        Collections.reverse(rows);
        // Publish state only after a complete, successful read.
        dataList = rows;
        maxNum = columnMax;
        return dataList;
    }

    /** @return the per-column maxima used for normalization (the live internal array) */
    public double[] getMaxArr() {
        return this.maxNum;
    }

    /** Restarts iteration from the first window. */
    public void reset() {
        resetDataRecord();
    }

    public boolean hasNext() {
        return !dataRecord.isEmpty();
    }

    /** Returns the next mini-batch of the configured size. */
    public DataSet next() {
        return next(batchNum);
    }

    /**
     * Returns the next mini-batch of up to {@code num} windows.
     *
     * <p>Features have shape [batch, 6, time] and labels [batch, 1, time]. Both are 'f'-ordered.
     * Masks of shape [batch, time] mark the real time steps, so zero padding after a shorter
     * final window does not contribute to training.
     *
     * @param num maximum number of windows; must be positive
     * @throws NoSuchElementException if no windows remain
     * @throws IllegalArgumentException if num is not positive
     */
    public DataSet next(int num) {
        if (dataRecord.isEmpty()) {
            throw new NoSuchElementException();
        }
        if (num <= 0) {
            throw new IllegalArgumentException("num must be positive: " + num);
        }
        int actualBatchSize = Math.min(num, dataRecord.size());
        // Starts are queued in ascending order, so the first window of the batch is the longest.
        int actualLength = Math.min(exampleLength, dataList.size() - dataRecord.get(0) - 1);
        INDArray input = Nd4j.create(new int[]{actualBatchSize, VECTOR_SIZE, actualLength}, 'f');
        INDArray label = Nd4j.create(new int[]{actualBatchSize, 1, actualLength}, 'f');
        INDArray mask = Nd4j.create(new int[]{actualBatchSize, actualLength}, 'f');
        for (int i = 0; i < actualBatchSize; i++) {
            int start = dataRecord.remove(0);
            int end = Math.min(start + exampleLength, dataList.size() - 1);
            for (int j = start; j < end; j++) {
                DailyData cur = dataList.get(j);
                DailyData following = dataList.get(j + 1);
                // Chronological placement: the oldest day of the window goes to step 0, so the
                // network never sees a day's close before being asked to predict it.
                int t = j - start;
                input.putScalar(new int[]{i, 0, t}, cur.getOpenPrice() / maxNum[0]);
                input.putScalar(new int[]{i, 1, t}, cur.getCloseprice() / maxNum[1]);
                input.putScalar(new int[]{i, 2, t}, cur.getMaxPrice() / maxNum[2]);
                input.putScalar(new int[]{i, 3, t}, cur.getMinPrice() / maxNum[3]);
                input.putScalar(new int[]{i, 4, t}, cur.getTurnover() / maxNum[4]);
                input.putScalar(new int[]{i, 5, t}, cur.getVolume() / maxNum[5]);
                label.putScalar(new int[]{i, 0, t}, following.getCloseprice() / maxNum[1]);
                mask.putScalar(new int[]{i, t}, 1.0);
            }
        }
        DataSet dataSet = new DataSet(input, label, mask, mask.dup());
        if (preProcessor != null) {
            preProcessor.preProcess(dataSet);
        }
        return dataSet;
    }

    public int batch() {
        return batchNum;
    }

    /** @return number of windows already returned in the current pass */
    public int cursor() {
        return totalExamples() - dataRecord.size();
    }

    public int numExamples() {
        return totalExamples();
    }

    /** Installs a preprocessor applied to every returned DataSet; {@code null} disables it. */
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /** @return number of windows per pass; this matches what {@code resetDataRecord} queues */
    public int totalExamples() {
        if (dataList == null || dataList.size() < 2) {
            return 0;
        }
        return (dataList.size() - 2) / exampleLength + 1;
    }

    /** @return number of features per time step */
    public int inputColumns() {
        return VECTOR_SIZE;
    }

    public int totalOutcomes() {
        return 1;
    }

    @Override
    public List<String> getLabels() {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }

    /** {@link #reset()} fully rewinds the iterator, so reset is supported. */
    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }
}
/**
 * Demo driver: trains a two-layer LSTM on a daily price series and prints a few outputs
 * for a hard-coded starting day.
 */
public class Dtest {
    private static final int IN_NUM = 6;
    private static final int OUT_NUM = 1;
    private static final int EPOCHS = 1;
    private static final int PREDICTION_STEPS = 20;
    private static final String DEFAULT_INPUT_FILE = "sz399905.csv";

    private static final int LSTM_LAYER1_SIZE = 50;
    private static final int LSTM_LAYER2_SIZE = 100;

    /**
     * Builds and initializes the network: GravesLSTM(nIn -> 50) -> GravesLSTM(50 -> 100) ->
     * RnnOutputLayer(100 -> nOut) with identity activation, MSE loss, RMSProp and L2 = 0.001.
     *
     * @param nIn  features per time step
     * @param nOut outputs per time step
     * @return an initialized network that logs its score every iteration
     */
    public static MultiLayerNetwork getNetModel(int nIn, int nOut) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .seed(12345)
            .l2(0.001)
            .updater(Updater.RMSPROP)
            .list()
            .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(LSTM_LAYER1_SIZE)
                .activation(Activation.TANH).build())
            .layer(1, new GravesLSTM.Builder().nIn(LSTM_LAYER1_SIZE).nOut(LSTM_LAYER2_SIZE)
                .activation(Activation.TANH).build())
            .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
                .nIn(LSTM_LAYER2_SIZE).nOut(nOut).build())
            .pretrain(false).backprop(true)
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(1));
        return net;
    }

    /**
     * Runs {@code EPOCHS} full passes over the iterator. After each pass it prints
     * {@code PREDICTION_STEPS} de-normalized outputs.
     *
     * <p>NOTE(review): the same single-day input is fed to {@code rnnTimeStep} on every step.
     * The printed values therefore only reflect the evolving stored RNN state, not a true
     * multi-day forecast (the model outputs only the close and cannot be fed back as a full
     * six-feature input).
     */
    public static void train(MultiLayerNetwork net, StockDataIterator iterator) {
        for (int epoch = 0; epoch < EPOCHS; epoch++) {
            while (iterator.hasNext()) {
                net.fit(iterator.next());
            }
            iterator.reset();
            System.out.println();
            System.out.println("=================>完成第" + epoch + "次完整训练");
            INDArray initArray = getInitArray(iterator);

            System.out.println("预测结果:");
            for (int step = 0; step < PREDICTION_STEPS; step++) {
                INDArray output = net.rnnTimeStep(initArray);
                // Undo the close-price normalization (column 1).
                System.out.print(output.getDouble(0) * iterator.getMaxArr()[1] + " ");
            }
            System.out.println();
            net.rnnClearPreviousState();
        }
    }

    /**
     * Builds a [1, IN_NUM, 1] input from one hard-coded day, normalized with the iterator's
     * column maxima. Values are in file column order. NOTE(review): index 2 holds 3327.81 and
     * index 3 holds 3470.37, i.e. low then high, matching the file layout rather than the
     * max/min field names.
     */
    private static INDArray getInitArray(StockDataIterator iter) {
        double[] maxNums = iter.getMaxArr();
        INDArray initArray = Nd4j.zeros(1, IN_NUM, 1);
        initArray.putScalar(new int[]{0, 0, 0}, 3433.85 / maxNums[0]);
        initArray.putScalar(new int[]{0, 1, 0}, 3445.41 / maxNums[1]);
        initArray.putScalar(new int[]{0, 2, 0}, 3327.81 / maxNums[2]);
        initArray.putScalar(new int[]{0, 3, 0}, 3470.37 / maxNums[3]);
        initArray.putScalar(new int[]{0, 4, 0}, 304197903.0 / maxNums[4]);
        initArray.putScalar(new int[]{0, 5, 0}, 3.8750365e+11 / maxNums[5]);
        return initArray;
    }

    /**
     * Entry point.
     *
     * @param args optional: args[0] overrides the classpath resource to load
     *             (default {@code sz399905.csv})
     */
    public static void main(String[] args) {
        String inputFile = args.length > 0 ? args[0] : DEFAULT_INPUT_FILE;
        int batchSize = 1;
        int exampleLength = 30;
        StockDataIterator iterator = new StockDataIterator();
        // Without data the iterator and maxima are unusable; stop instead of failing later with an NPE.
        if (!iterator.loadData(inputFile, batchSize, exampleLength)) {
            System.err.println("加载数据失败: " + inputFile);
            return;
        }

        MultiLayerNetwork net = getNetModel(IN_NUM, OUT_NUM);
        train(net, iterator);
    }
}

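数据格式如下(程序按英文逗号切分每一行,下方示例中的分隔符显示为空白)。各列依次为:代码、日期、开盘价、收盘价、最低价、最高价、成交量、成交额、涨跌幅;程序只读取第 3 至第 8 列这 6 个数值。注意:从示例看,第 5 列总不高于开盘价和收盘价,第 6 列总不低于二者,即依次为最低价、最高价,而代码是按“最高价、最低价”的字段名存储它们的。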

sz399905    2015/12/11  7320.16 7290.7  7253.84 7347.36 72132287    1.12E+11    -0.008096367
sz399905    2015/12/10  7374.35 7350.21 7332.98 7437.71 78990424    1.30E+11    -0.003262696
sz399905    2015/12/9   7369.11 7374.27 7322.87 7431.04 83299991    1.32E+11    -0.004034229
sz399905    2015/12/8   7555.46 7404.14 7398.56 7555.46 94938823    1.47E+11    -0.026056828
sz399905    2015/12/7   7526.22 7602.23 7476.19 7602.77 92881296    1.47E+11    0.012055908
sz399905    2015/12/4   7533.61 7511.67 7464.28 7600.34 101362535   1.55E+11    -0.007772264
sz399905    2015/12/3   7413.22 7570.51 7412.65 7571.45 95329412    1.43E+11    0.022232394
sz399905    2015/12/2   7423.5  7405.86 7201.66 7444.22 102647475   1.50E+11    -0.005115571
sz399905    2015/12/1   7403.94 7443.94 7358.37 7519.94 113008679   1.73E+11    0.004797257
sz399905    2015/11/30  7388.28 7408.4  7035.55 7467.47 129234023   1.97E+11    0.004376285
sz399905    2015/11/27  7839.31 7376.12 7317.65 7852    152970489   2.34E+11    -0.063240404
sz399905    2015/11/26  7962.17 7874.08 7859.63 7974.73 140404615   2.29E+11    -0.006096653
sz399905    2015/11/25  7803.29 7922.38 7795.16 7925.54 124435501   2.07E+11    0.015885106
sz399905    2015/11/24  7739.09 7798.5  7635.78 7799.01 110258558   1.69E+11    0.0070143

参考文章:
https://blog.csdn.net/a398942089/article/details/52294082

转载自:blog.51cto.com/12597095/2119576