1. MNIST dataset
Download link: 60K training images, 10K test images.
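The archive is assumed to unzip into the per-digit folder layout that the training and inference code below relies on (folder names taken from that code; your archive may differ slightly):

D:\mnist_png\
    training\    <- 60K PNGs, one subfolder per digit (0 ... 9)
    testing\     <- 10K PNGs, one subfolder per digit (0 ... 9)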
2. Dependencies
Mainly deeplearning4j and javacv packages. The jar built from this example is about 1.3 GB (see the note after the pom for one way to shrink it); the pom comes from the dl4j-examples module of the deeplearning4j-examples project on GitHub.
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.7.9</version>
<relativePath/>
</parent>
<groupId>com.example</groupId>
<artifactId>demo</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>demo</name>
<description>demo</description>
<properties>
<dl4j-master.version>1.0.0-M2.1</dl4j-master.version>
<nd4j.backend>nd4j-native</nd4j.backend>
<java.version>17</java.version>
<maven-compiler-plugin.version>3.8.1</maven-compiler-plugin.version>
<maven.minimum.version>3.3.1</maven.minimum.version>
<exec-maven-plugin.version>1.4.0</exec-maven-plugin.version>
<maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
<jcommon.version>1.0.23</jcommon.version>
<jfreechart.version>1.0.13</jfreechart.version>
<logback.version>1.1.7</logback.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<junit.version>5.8.0-M1</junit.version>
<javacv.version>1.5.9</javacv.version>
</properties>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>javacv-platform</artifactId>
<version>${javacv.version}</version>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-local</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-datasets</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>resources</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-ui</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-zoo</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<!-- ParallelWrapper & ParallelInference live here -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-parallel-wrapper</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<!-- Used in the feedforward/classification/MLP* and feedforward/regression/RegressionMathFunctions example -->
<dependency>
<groupId>jfree</groupId>
<artifactId>jfreechart</artifactId>
<version>${jfreechart.version}</version>
</dependency>
<dependency>
<groupId>org.jfree</groupId>
<artifactId>jcommon</artifactId>
<version>${jcommon.version}</version>
</dependency>
<!-- Used for downloading data in some of the examples -->
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
<version>4.3.5</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>javacv-platform</artifactId>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>17</source>
<target>17</target>
</configuration>
</plugin>
</plugins>
</build>
</project>
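The 1.3 GB jar comes mostly from javacv-platform and the other org.bytedeco -platform artifacts, which bundle native binaries for every supported OS and architecture. As a hedged sketch (assuming you only need one target platform, for example 64-bit Windows), the javacpp.platform property restricts those -platform artifacts to a single set of natives and usually shrinks the packaged jar considerably:

mvn clean package -Djavacpp.platform=windows-x86_64

How much this saves depends on which -platform artifacts actually end up on the classpath.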
3. Training and inference on handwritten digits
One epoch takes roughly 100 s to train and reaches about 97% accuracy; see the code comments for details. The framework's API is quite pleasant to use.
package ai;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.common.io.Assert;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import java.io.File;
import java.util.Random;
@Slf4j
public class LeNetMNISTReLu {
private static final String DATASET_PATH_BASE = "D:\\";
public static void main(String[] args) throws Exception {
int height = 28;
int width = 28;
// grayscale images have a single channel
int channels = 1;
// ten classes: digits 0-9
int outputNum = 10;
int batchSize = 64;
// one epoch takes about 100 s here; 3 epochs reach roughly 99% accuracy
int nEpochs = 1;
Assert.isTrue(new File(DATASET_PATH_BASE + "/mnist_png").exists(), "Please download the archive and unzip it to " + DATASET_PATH_BASE);
// This label generator uses the name of each file's parent directory as the label; the directory names must be numeric, which is exactly how the MNIST images are laid out (folders 0-9)
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
// scale pixel values into [0, 1]
DataNormalization normalization = new ImagePreProcessingScaler();
Random random = new Random(12345);
log.info("训练集6W张...");
File trainData = new File(DATASET_PATH_BASE + "/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, random);
ImageRecordReader trainRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
trainRecordReader.initialize(trainSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, outputNum);
normalization.fit(trainIter);
trainIter.setPreProcessor(normalization); // apply pixel normalization before feeding the network
log.info("验证集1W张...");
File validateData = new File(DATASET_PATH_BASE + "/mnist_png/testing");
FileSplit validateSplit = new FileSplit(validateData, NativeImageLoader.ALLOWED_FORMATS, random);
ImageRecordReader validateRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
validateRecordReader.initialize(validateSplit);
DataSetIterator validateIter = new RecordReaderDataSetIterator(validateRecordReader, batchSize, 1, outputNum);
validateIter.setPreProcessor(normalization);
// 60K training images with batchSize=64 gives roughly 1000 iterations per epoch
// Learning-rate schedule: start larger and decay every few hundred iterations (a per-epoch schedule would also work)
MapSchedule mapSchedule = new MapSchedule.Builder(ScheduleType.ITERATION)
.add(0, 0.06)
.add(200, 0.05)
.add(600, 0.028)
.add(800, 0.006)
.add(1000, 0.001)
.build();
// hyperparameters and network configuration
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(1)
.l2(0.0005)
.updater(new Nesterovs(mapSchedule))
//.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) // this optimizer failed to converge even after a long time
.weightInit(WeightInit.XAVIER)
.list()
.layer(new ConvolutionLayer.Builder(5, 5)
.nIn(channels)
.stride(1, 1)
.nOut(20)
.activation(Activation.IDENTITY)
.build())
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
.layer(new ConvolutionLayer.Builder(5, 5)
.stride(1, 1)
.nOut(50)
.activation(Activation.IDENTITY)
.build())
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
.layer(new DenseLayer.Builder().activation(Activation.RELU)
.nOut(500)
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
.build();
// build the network
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
// training monitor: print the loss every 10 iterations
net.setListeners(new ScoreIterationListener(10));
// optional: monitor training in the DL4J web UI
//UIServer uiServer = UIServer.getInstance();
//FileStatsStorage statsStorage = new FileStatsStorage(new File("D:\\ai-webui.dat"));
//uiServer.attach(statsStorage);
//net.setListeners(new StatsListener(statsStorage));
log.info("网络参数个数{}", net.numParams());
long startTime = System.currentTimeMillis();
// train for nEpochs epochs
for (int i = 0; i < nEpochs; i++) {
log.info("Epoch=" + i);
net.fit(trainIter);
Evaluation eval = net.evaluate(validateIter);
log.info(eval.stats());
trainIter.reset();
validateIter.reset();
}
log.info("训练耗时{}毫秒", System.currentTimeMillis() - startTime);
// save the trained model
File mnistModelPath = new File(DATASET_PATH_BASE + "/mnistModel.zip");
ModelSerializer.writeModel(net, mnistModelPath, true);
// Inference: load the saved model -> load test images -> predict
MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(mnistModelPath);
NativeImageLoader imageLoader = new NativeImageLoader(height, width, channels);
FileUtils.listFiles(new File(DATASET_PATH_BASE + "/mnist_png/testing"), null, true)
.parallelStream().forEach(file -> {
try {
INDArray matrix = imageLoader.asMatrix(file);
// apply the same 0-1 pixel scaling used during training, otherwise predictions are unreliable
normalization.transform(matrix);
INDArray output = network.output(matrix);
// take the most probable class as the prediction
int predictedValue = Nd4j.argMax(output, 1).getInt(0);
// images are grouped into folders named after their digit, so the parent folder name is the ground truth
String realValue = file.getParentFile().getName();
log.info("真实值:{},预测值:{}", realValue, predictedValue);
Assert.isTrue(predictedValue == Integer.parseInt(realValue), file.getAbsolutePath() + "预测错误");
} catch (Exception e) {
log.warn(e.getMessage(), e);
}
});
}
}
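The commented-out lines above are the hook for DL4J's training UI. A minimal sketch of enabling it with in-memory stats storage (the required imports are already in the class; as far as I recall the UI listens on port 9000 by default):

UIServer uiServer = UIServer.getInstance();
InMemoryStatsStorage statsStorage = new InMemoryStatsStorage();
uiServer.attach(statsStorage);
// report stats to the UI and still log the score every 10 iterations
net.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(10));
// then open http://localhost:9000 in a browser while training runs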
4. Further reading: the bundled example project deeplearning4j-examples
Refer to the README of deeplearning4j-examples and pick a few examples to run. First copy the four utility classes from the util package of the dl4j-examples module.
① The classic iris classification example; it runs as-is.
IrisClassifier.java
/* *****************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package ai;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
/**
* @author Adam Gibson
*/
@SuppressWarnings("DuplicatedCode")
public class IrisClassifier {
private static Logger log = LoggerFactory.getLogger(IrisClassifier.class);
public static void main(String[] args) throws Exception {
//First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0;
char delimiter = ',';
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
recordReader.initialize(new FileSplit(new File(DownloaderUtility.IRISDATA.Download(),"iris.txt")));
//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int labelIndex = 4; //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
int numClasses = 3; //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
int batchSize = 150; //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();
allData.shuffle();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); //Use 65% of data for training
DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();
//We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(trainingData); //Apply normalization to the training data
normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set
final int numInputs = 4;
int outputNum = 3;
long seed = 6;
log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.activation(Activation.TANH)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(0.1))
.l2(1e-4)
.list()
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
.build())
.layer(new DenseLayer.Builder().nIn(3).nOut(3)
.build())
.layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
.nIn(3).nOut(outputNum).build())
.build();
//run the model
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
//record score once every 100 iterations
model.setListeners(new ScoreIterationListener(100));
for(int i=0; i<1000; i++ ) {
model.fit(trainingData);
}
//evaluate the model on the test set
Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatures());
eval.eval(testData.getLabels(), output);
log.info(eval.stats());
}
}
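DownloaderUtility (presumably one of the util classes copied from dl4j-examples) downloads and unpacks the iris data automatically. If that download is blocked, a hedged alternative is to point the FileSplit at a manually downloaded copy instead (the D:\iris.txt path is just an assumption):

// assumption: iris.txt saved by hand to D:\iris.txt (150 rows: 4 numeric features + a class index 0-2)
recordReader.initialize(new FileSplit(new File("D:\\iris.txt")));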
② A slightly simpler MNIST classifier
MNISTSingleLayer.java
Note: reduce batchSize and the number of epochs a bit, otherwise training takes quite a while on a modest machine.
/* *****************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package ai;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**A Simple Multi Layered Perceptron (MLP) applied to digit classification for
* the MNIST Dataset (http://yann.lecun.com/exdb/mnist/).
*
 * This example builds one hidden layer and one output layer.
 *
 * The hidden layer has an input dimension of numRows*numColumns, where these variables indicate the
 * number of vertical and horizontal pixels in the image. This layer uses a rectified linear unit
 * (relu) activation function. The weights for this layer are initialized with Xavier initialization
 * (https://prateekvjoshi.com/2016/03/29/understanding-xavier-initialization-in-deep-neural-networks/)
 * to avoid having a steep learning curve. This layer has 1000 outputs feeding the output layer.
 *
 * The output layer has an input dimension of 1000, fed from the hidden layer. The weights for this
 * layer are also initialized using Xavier initialization. Its activation function is softmax, which
 * normalizes the 10 outputs so that they sum up to 1. The highest of these normalized values is
 * picked as the predicted class.
*
*/
public class MNISTSingleLayer {
private static Logger log = LoggerFactory.getLogger(MNISTSingleLayer.class);
public static void main(String[] args) throws Exception {
//number of rows and columns in the input pictures
final int numRows = 28;
final int numColumns = 28;
int outputNum = 10; // number of output classes
int batchSize = 64; // batch size for each epoch
int rngSeed = 123; // random number seed for reproducibility
int numEpochs = 1; // number of epochs to perform
//Get the DataSetIterators:
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(rngSeed) //include a random seed for reproducibility
// use stochastic gradient descent as an optimization algorithm
.updater(new Nesterovs(0.006, 0.9))
.l2(1e-4)
.list()
.layer(new DenseLayer.Builder() //create the hidden layer with xavier initialization
.nIn(numRows * numColumns)
.nOut(1000)
.activation(Activation.RELU)
.weightInit(WeightInit.XAVIER)
.build())
.layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //create the output layer
.nIn(1000)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.weightInit(WeightInit.XAVIER)
.build())
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
//print the score every iteration
model.setListeners(new ScoreIterationListener(1));
log.info("Train model....");
model.fit(mnistTrain, numEpochs);
log.info("Evaluate model....");
Evaluation eval = model.evaluate(mnistTest);
log.info(eval.stats());
log.info("****************Example finished********************");
}
}
③ Sentiment classification of the IMDB dataset with an RNN/CNN model
Grab all three Java files under deeplearning4j-examples/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/textclassification/pretrainedword2vec on GitHub.
- Download an IMDB dataset; it is best to download it manually and place it in the folder that DATA_PATH in ImdbReviewClassificationRNN.java points to.
- Download the pretrained word-vector file manually and assign its path to the wordVectorsPath variable; on my 16 GB machine, loading the word vectors ran out of memory (see the memory-settings sketch below).
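The out-of-memory failure when loading the word vectors is usually a heap/off-heap sizing problem: pretrained word2vec files are several gigabytes on their own. A hedged sketch of VM options that may help when launching ImdbReviewClassificationRNN (the sizes are assumptions and must fit within your physical RAM; ND4J's off-heap memory is governed by the javacpp properties):

-Xms4g -Xmx8g
-Dorg.bytedeco.javacpp.maxbytes=6G
-Dorg.bytedeco.javacpp.maxphysicalbytes=14G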