Deeplearning4j Training and Inference Example (2023): Handwritten Digit Recognition

1. The MNIST dataset

Download link: 60,000 training images, 10,000 test images
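After unzipping, the PNG version of the dataset is assumed to be laid out as below; the folder names double as labels, which the training code in section 3 relies on:

mnist_png/
├── training/   (60,000 images, one subfolder per digit 0-9)
└── testing/    (10,000 images, one subfolder per digit 0-9)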

2. Dependencies

Mainly deeplearning4j and javacv packages; the jar built from this example is about 1.3 GB. The pom is adapted from the dl4j-examples module of the deeplearning4j-examples project on GitHub.

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.7.9</version>
        <relativePath/>
    </parent>
    <groupId>com.example</groupId>
    <artifactId>demo</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>demo</name>
    <description>demo</description>
    <properties>
        <dl4j-master.version>1.0.0-M2.1</dl4j-master.version>
        <nd4j.backend>nd4j-native</nd4j.backend>
        <java.version>17</java.version>
        <maven-compiler-plugin.version>3.8.1</maven-compiler-plugin.version>
        <maven.minimum.version>3.3.1</maven.minimum.version>
        <exec-maven-plugin.version>1.4.0</exec-maven-plugin.version>
        <maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
        <jcommon.version>1.0.23</jcommon.version>
        <jfreechart.version>1.0.13</jfreechart.version>
        <logback.version>1.1.7</logback.version>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <junit.version>5.8.0-M1</junit.version>
        <javacv.version>1.5.9</javacv.version>
    </properties>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.bytedeco</groupId>
                <artifactId>javacv-platform</artifactId>
                <version>${javacv.version}</version>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>


        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-local</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-datasets</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>resources</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <!-- ParallelWrapper & ParallelInference live here -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <!-- Used in the feedforward/classification/MLP* and feedforward/regression/RegressionMathFunctions example -->
        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>${jfreechart.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
            <version>${jcommon.version}</version>
        </dependency>
        <!-- Used for downloading data in some of the examples -->
        <dependency>
            <groupId>org.apache.httpcomponents</groupId>
            <artifactId>httpclient</artifactId>
            <version>4.3.5</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>

        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv-platform</artifactId>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>17</source>
                    <target>17</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

</project>
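The nd4j.backend property above selects the CPU backend (nd4j-native). If an NVIDIA GPU is available, ND4J lets you switch backends by changing that one property; the exact CUDA artifact name below is an assumption, so verify what is actually published on Maven Central for 1.0.0-M2.1:

<!-- Assumption: swap in a CUDA backend; check Maven Central for the exact
     nd4j-cuda artifact/version available for 1.0.0-M2.1 -->
<nd4j.backend>nd4j-cuda-11.6</nd4j.backend>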

3. Handwritten digit training and inference

One epoch takes about 100 s to train and reaches 97% accuracy; see the code comments for details. The framework's API turns out to be quite pleasant to use.

package ai;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.common.io.Assert;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

import java.io.File;
import java.util.Random;

@Slf4j
public class LeNetMNISTReLu {
    private static final String DATASET_PATH_BASE = "D:\\";

    public static void main(String[] args) throws Exception {
        int height = 28;
        int width = 28;
        // Grayscale images have only one channel
        int channels = 1;
        // Ten classes: digits 0-9
        int outputNum = 10;
        int batchSize = 64;
        // One epoch takes about 100 s here; 3 epochs reach 99% accuracy
        int nEpochs = 1;


        Assert.isTrue(new File(DATASET_PATH_BASE + "/mnist_png").exists(), "Please download the archive and unzip it to " + DATASET_PATH_BASE);
        // This label generator uses each file's parent directory name as its label; the names must be numeric, and the MNIST images are conveniently stored in folders 0-9
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // Scale pixel values into [0, 1]
        DataNormalization normalization = new ImagePreProcessingScaler();
        Random random = new Random(12345);
        log.info("训练集6W张...");
        File trainData = new File(DATASET_PATH_BASE + "/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader trainRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        trainRecordReader.initialize(trainSplit);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, outputNum);
        normalization.fit(trainIter);
        trainIter.setPreProcessor(normalization); // scale pixels before feeding the network

        log.info("验证集1W张...");
        File validateData = new File(DATASET_PATH_BASE + "/mnist_png/testing");
        FileSplit validateSplit = new FileSplit(validateData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader validateRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        validateRecordReader.initialize(validateSplit);
        DataSetIterator validateIter = new RecordReaderDataSetIterator(validateRecordReader, batchSize, 1, outputNum);
        validateIter.setPreProcessor(normalization);

        // 60,000 training samples at batchSize = 64 means roughly 1,000 iterations per epoch
        // Learning-rate schedule: update the learning rate (step size) every few hundred iterations, starting larger; ScheduleType.EPOCH would update once per epoch instead
        MapSchedule mapSchedule = new MapSchedule.Builder(ScheduleType.ITERATION)
                .add(0, 0.06)
                .add(200, 0.05)
                .add(600, 0.028)
                .add(800, 0.006)
                .add(1000, 0.001)
                .build();

        // Hyperparameters
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(1)
                .l2(0.0005)
                .updater(new Nesterovs(mapSchedule))
                //.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) // this optimizer kept the network from fitting for a long time
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .nIn(channels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
                .build();

        // Build the network
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // Training monitor: print the loss every 10 iterations
        net.setListeners(new ScoreIterationListener(10));
        // Web UI for monitoring the training process
        //UIServer uiServer = UIServer.getInstance();
        //FileStatsStorage statsStorage = new FileStatsStorage(new File("D:\\ai-webui.dat"));
        //uiServer.attach(statsStorage);
        //net.setListeners(new StatsListener(statsStorage));
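        // Once a StatsListener is attached, the training dashboard is served by
        // UIServer, by default at http://localhost:9000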
        log.info("网络参数个数{}", net.numParams());
        long startTime = System.currentTimeMillis();
        // Train for nEpochs epochs
        for (int i = 0; i < nEpochs; i++) {
            log.info("Epoch=" + i);
            net.fit(trainIter);
            Evaluation eval = net.evaluate(validateIter);
            log.info(eval.stats());
            trainIter.reset();
            validateIter.reset();
        }
        log.info("训练耗时{}毫秒", System.currentTimeMillis() - startTime);
        // Save the model
        File ministModelPath = new File(DATASET_PATH_BASE + "/ministModel.zip");
        ModelSerializer.writeModel(net, ministModelPath, true);
        // Inference: load the saved model -> load test images -> predict
        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(ministModelPath);
        FileUtils.listFiles(new File(DATASET_PATH_BASE + "/mnist_png/testing"), null, true)
                .parallelStream().forEach(file -> {
                    try {
                        // NativeImageLoader is not thread-safe, so create one per file
                        NativeImageLoader imageLoader = new NativeImageLoader(height, width, channels);
                        INDArray matrix = imageLoader.asMatrix(file);
                        // Apply the same 0-1 pixel scaling that training used
                        normalization.transform(matrix);
                        INDArray output = network.output(matrix);
                        // Take the most likely prediction
                        int predictedValue = Nd4j.argMax(output, 1).getInt(0);
                        // Images live in folders named by digit, so the parent folder name is the ground truth
                        String realValue = file.getParentFile().getName();
                        log.info("Actual: {}, predicted: {}", realValue, predictedValue);
                        Assert.isTrue(predictedValue == Integer.parseInt(realValue), file.getAbsolutePath() + " was misclassified");
                    } catch (Exception e) {
                        log.warn(e.getMessage(), e);
                    }
                });
    }
}
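To use the saved model outside the training run, here is a minimal standalone inference sketch (the class name and image path are placeholders for illustration); note that, just as in training, the pixels must be scaled to [0, 1] before calling output():

package ai;

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;

import java.io.File;

public class MnistPredictOne {
    public static void main(String[] args) throws Exception {
        // Load the model saved by the training run above
        MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(new File("D:\\ministModel.zip"));
        // 28x28 grayscale, matching the training configuration
        NativeImageLoader loader = new NativeImageLoader(28, 28, 1);
        // Placeholder path: any digit image (the loader resizes to 28x28)
        INDArray image = loader.asMatrix(new File("D:\\some-digit.png"));
        // Same 0-1 pixel scaling that the training iterator applied
        new ImagePreProcessingScaler().transform(image);
        INDArray output = net.output(image);
        System.out.println("Predicted digit: " + Nd4j.argMax(output, 1).getInt(0));
    }
}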

4. Further reading: the bundled example project deeplearning4j-examples

Following the deeplearning4j-examples README, here are a few examples worth running. First copy the four utility classes from the util package of dl4j-examples (DownloaderUtility, used below, is one of them).
① The classic iris classification example, which runs as-is:
IrisClassifier.java

/* *****************************************************************************
 *
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package ai;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;

/**
 * @author Adam Gibson
 */
@SuppressWarnings("DuplicatedCode")
public class IrisClassifier {

    private static Logger log = LoggerFactory.getLogger(IrisClassifier.class);

    public static void main(String[] args) throws Exception {
        //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
        int numLinesToSkip = 0;
        char delimiter = ',';
        RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
        recordReader.initialize(new FileSplit(new File(DownloaderUtility.IRISDATA.Download(),"iris.txt")));

        //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
        int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
        int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
        int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
        DataSet allData = iterator.next();
        allData.shuffle();
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
        normalizer.transform(trainingData);     //Apply normalization to the training data
        normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set


        final int numInputs = 4;
        int outputNum = 3;
        long seed = 6;


        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(0.1))
                .l2(1e-4)
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
                        .build())
                .layer(new DenseLayer.Builder().nIn(3).nOut(3)
                        .build())
                .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
                        .nIn(3).nOut(outputNum).build())
                .build();

        //run the model
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        //record score once every 100 iterations
        model.setListeners(new ScoreIterationListener(100));

        for (int i = 0; i < 1000; i++) {
            model.fit(trainingData);
        }

        //evaluate the model on the test set
        Evaluation eval = new Evaluation(3);
        INDArray output = model.output(testData.getFeatures());
        eval.eval(testData.getLabels(), output);
        log.info(eval.stats());

    }

}

② A simpler MNIST classifier
MNISTSingleLayer.java — reduce batchSize and the number of epochs, or it will take quite a while on a low-end machine

/* *****************************************************************************
 *
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package ai;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**A Simple Multi Layered Perceptron (MLP) applied to digit classification for
 * the MNIST Dataset (http://yann.lecun.com/exdb/mnist/).
 *
 * This file builds one input layer and one hidden layer.
 *
 * The input layer has input dimension of numRows*numColumns where these variables indicate the
 * number of vertical and horizontal pixels in the image. This layer uses a rectified linear unit
 * (relu) activation function. The weights for this layer are initialized by using Xavier initialization
 * (https://prateekvjoshi.com/2016/03/29/understanding-xavier-initialization-in-deep-neural-networks/)
 * to avoid having a steep learning curve. This layer will have 1000 output signals to the hidden layer.
 *
 * The hidden layer has input dimensions of 1000. These are fed from the input layer. The weights
 * for this layer is also initialized using Xavier initialization. The activation function for this
 * layer is a softmax, which normalizes all the 10 outputs such that the normalized sums
 * add up to 1. The highest of these normalized values is picked as the predicted class.
 *
 */
public class MNISTSingleLayer {

    private static Logger log = LoggerFactory.getLogger(MNISTSingleLayer.class);

    public static void main(String[] args) throws Exception {
        //number of rows and columns in the input pictures
        final int numRows = 28;
        final int numColumns = 28;
        int outputNum = 10; // number of output classes
        int batchSize = 64; // batch size for each epoch
        int rngSeed = 123; // random number seed for reproducibility
        int numEpochs = 1; // number of epochs to perform

        //Get the DataSetIterators:
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);


        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngSeed) //include a random seed for reproducibility
                // Nesterovs = stochastic gradient descent with Nesterov momentum
                .updater(new Nesterovs(0.006, 0.9))
                .l2(1e-4)
                .list()
                .layer(new DenseLayer.Builder() //create the first, input layer with xavier initialization
                        .nIn(numRows * numColumns)
                        .nOut(1000)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //create the output layer
                        .nIn(1000)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        //print the score once every iteration
        model.setListeners(new ScoreIterationListener(1));

        log.info("Train model....");
        model.fit(mnistTrain, numEpochs);


        log.info("Evaluate model....");
        Evaluation eval = model.evaluate(mnistTest);
        log.info(eval.stats());
        log.info("****************Example finished********************");

    }

}

③ Sentiment classification on the IMDB dataset with RNN/CNN models

Grab all three Java files on GitHub under deeplearning4j-examples/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/textclassification/pretrainedword2vec.

  • Download an IMDB dataset; it is simplest to download it manually and place it in the DATA_PATH folder referenced by ImdbReviewClassificationRNN.java
  • Download the word-vector file manually and assign its path to the wordVectorsPath variable; loading the word vectors ran out of memory even on a 16 GB machine (see the sketch after this list)
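As mentioned in the last bullet, the word vectors are the memory-hungry part. Below is a minimal sketch of loading them read-only via WordVectorSerializer.loadStaticModel (the API these examples use), assuming the deeplearning4j-nlp module is on the classpath; the class name and file path are placeholders:

package ai;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;

import java.io.File;
import java.util.Arrays;

public class WordVectorsSmokeTest {
    public static void main(String[] args) {
        // Placeholder: point this at the manually downloaded word-vector file
        String wordVectorsPath = "D:\\GoogleNews-vectors-negative300.bin.gz";
        // loadStaticModel loads the vectors read-only (no training state), which keeps memory lower
        WordVectors vectors = WordVectorSerializer.loadStaticModel(new File(wordVectorsPath));
        // Print the embedding for one word as a smoke test
        System.out.println(Arrays.toString(vectors.getWordVector("day")));
    }
}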


Reposted from blog.csdn.net/qq_39506978/article/details/134024627