笔记(1)中利用tensorflow.js完成了机器学习中曲线拟合的任务,这篇笔记将实现一个经典的机器学习问题——CNN识别手写数字集MNIST。参考官方示例Training on Images: Recognizing Handwritten Digits with a Convolutional Neural Network,修改部分代码并用echarts改写vega。
1、定义mnist数据类
import * as tf from '@tensorflow/tfjs';

// Number of values per image (28 * 28 pixels, flattened).
const IMAGE_SIZE = 784;
// Number of label classes.
const NUM_CLASSES = 10;
// Total number of samples in the sprite.
const NUM_DATASET_ELEMENTS = 65000;
// Number of samples used for training.
const NUM_TRAIN_ELEMENTS = 55000;
// Number of samples used for testing (the remainder).
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
// Sprite image containing all MNIST samples.
const MNIST_IMAGES_SPRITE_PATH = './src/mnist_images.png';
// Binary file with the labels that belong to the sprite's samples.
const MNIST_LABELS_PATH = './src/mnist_labels_uint8';

/**
 * Loads the MNIST sprite and label file and hands out shuffled
 * training / test batches as tensors.
 */
export class MnistData {
  constructor() {
    // Cursors into the shuffled index arrays created by load().
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  /**
   * Downloads the sprite image and the label file, decodes the pixels into
   * a Float32Array scaled to [0, 1], and splits both into train/test parts.
   *
   * @returns {Promise<void>} Resolves once all data is available.
   * @throws {Error} If the sprite image cannot be loaded or the label
   *     request does not succeed.
   */
  async load() {
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');

    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      // Without an error handler a missing/blocked sprite would leave this
      // promise pending forever and load() would never settle.
      img.onerror = () => {
        reject(new Error(`Failed to load image: ${MNIST_IMAGES_SPRITE_PATH}`));
      };
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        // 4 bytes per float32 value.
        const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        // The sprite is drawn to the canvas in horizontal strips of
        // `chunkSize` rows so the canvas stays reasonably small.
        // NOTE(review): assumes each sprite row holds exactly one sample
        // (img.width === IMAGE_SIZE) - confirm against the asset.
        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
              datasetBytesBuffer,
              i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
          ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize,
              0, 0, img.width, chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
          // imageData.data is RGBA; only the first (red) channel is read and
          // scaled from 0..255 to 0..1.
          for (let j = 0; j < imageData.data.length / 4; j++) {
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);
        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    // The image promise resolves with no value; only the label response is used.
    const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);
    if (!labelsResponse.ok) {
      throw new Error(
          `Failed to load labels: ${MNIST_LABELS_PATH} ` +
          `(HTTP ${labelsResponse.status})`);
    }
    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Shuffled lookup tables so batches are not drawn in file order.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // The first NUM_TRAIN_ELEMENTS samples are for training, the rest for testing.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    // Labels occupy NUM_CLASSES bytes per sample.
    // NOTE(review): presumably one-hot encoded - confirm against the asset.
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  /**
   * @param {number} batchSize Number of training samples to return.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}} See nextBatch().
   */
  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  /**
   * @param {number} batchSize Number of test samples to return.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}} See nextBatch().
   */
  nextTestBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.testImages, this.testLabels], () => {
          this.shuffledTestIndex =
              (this.shuffledTestIndex + 1) % this.testIndices.length;
          return this.testIndices[this.shuffledTestIndex];
        });
  }

  /**
   * Assembles one batch by repeatedly asking `index` for a sample position.
   *
   * @param {number} batchSize Number of samples in the batch.
   * @param {[Float32Array, Uint8Array]} data Flat image data and flat label data.
   * @param {function(): number} index Returns the next sample index to use.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}} `xs` has shape
   *     [batchSize, IMAGE_SIZE], `labels` has shape [batchSize, NUM_CLASSES].
   *     The caller owns both tensors and must dispose them.
   */
  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      // subarray() is a view; set() copies from it, so no temporary copy
      // of each sample is allocated.
      const image =
          data[0].subarray(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
          data[1].subarray(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);
    return {xs, labels};
  }
}
mnist数据类在构造器内声明两个index,分别是训练过程的洗牌index和测试过程的洗牌index。引入洗牌index是为了防止模型训练受到传入图像顺序的影响。假设不洗牌,先将所有1的图像传入模型进行训练,那么此时训练的模型将学会预测1的手写体;之后传入所有2的图像,则模型将切换到仅预测2(这样会最小化损失函数);这样模型永远无法对全部数据集做出完整的预测。
之后定义方法load(),该方法将图片mnist_images.png进行切割,并从mnist_labels_uint8中找到对应的label。之后的nextTrainBatch和nextTestBatch将分别返回训练样本和测试样本。
2、CNN的构建
import * as tf from '@tensorflow/tfjs';
import {MnistData} from './data';
import * as ui from './ui';

// Settings shared by both convolution layers; only the filter count differs.
function convConfig(filters) {
  return {
    kernelSize: 5,
    filters: filters,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  };
}

// Settings shared by both max-pooling layers.
function poolConfig() {
  return {poolSize: [2, 2], strides: [2, 2]};
}

const model = tf.sequential();

// conv -> pool -> conv -> pool -> flatten -> dense(softmax), in this order.
// Only the first layer declares the input shape (28 x 28, one channel).
const layerStack = [
  tf.layers.conv2d(Object.assign({inputShape: [28, 28, 1]}, convConfig(8))),
  tf.layers.maxPooling2d(poolConfig()),
  tf.layers.conv2d(convConfig(16)),
  tf.layers.maxPooling2d(poolConfig()),
  tf.layers.flatten(),
  tf.layers.dense({
    units: 10,
    kernelInitializer: 'varianceScaling',
    activation: 'softmax'
  })
];
for (const layer of layerStack) {
  model.add(layer);
}

const LEARNING_RATE = 0.15;
const optimizer = tf.train.sgd(LEARNING_RATE);

model.compile({
  optimizer: optimizer,
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
首先定义模型为tf.sequential(),该模型中张量将连续地从一层传递到下一层。之后分别加入卷积层、池化层、卷积层、池化层、flatten层(将输入展为向量)和dense层(完全连接层)。
其中卷积层是2维卷积;inputShape是传入数据的维数(第二个卷积层可以不指定inputShape,tf将从前一层的输出推断该值),三个值分别为行、列、深度,mnist图像格式为28*28像素点,深度为1,因为只有一个颜色通道;kernelSize是应用到输入数据上的滑动滤波器窗口的大小,5代表5*5矩形卷积窗口;filters指滤波器窗口的数量,8代表有8个滤波器;strides指滑动窗口的步长,1代表每次将以1像素为单位滑动滤波器;activation为激活函数,relu代表线性整流函数,函数形状如下所示;kernelInitializer用于随机初始化模型权重,这里使用VarianceScaling方法初始化模型。
池化层使用的是二维最大池化,通过计算该层每个滑动窗口的最大值来降低维度。其中poolSize是滑动窗口的大小,[2, 2]代表2*2的矩形窗口;strides代表滑动窗口移动的步长,[2, 2]代表窗口将在水平和垂直方向上以2像素为单位进行移动。
flatten层将上一层的输出平铺到一个矢量上。dense层(完全连接层)将执行最终的分类任务。其中units是输出的激活数,10代表将有10种不同的输出,满足mnist的10种分类(数字0-9);kernelInitializer设为VarianceScaling初始化方法;分类任务的最后一层激活函数activation通常设为softmax,该函数将10维输出向量归一化为概率分布,以便我们知道该样本属于10个类中每个类的概率。
定义学习率为0.15,优化器为随机梯度下降法(SGD);损失函数为categoricalCrossentropy,即分类任务的交叉熵;评价指标为准确率accuracy,即所有预测中正确预测的百分比。之后编译模型。
3、模型训练
// Number of samples per training step.
const BATCH_SIZE = 64;
// Number of training steps (one model.fit call each).
const TRAIN_BATCHES = 150;
// Number of test samples passed as validation data.
const TEST_BATCH_SIZE = 1000;
// Validation data is supplied (and accuracy recorded) every N-th step.
const TEST_ITERATION_FREQUENCY = 5;

/**
 * Trains `model` for TRAIN_BATCHES steps, collecting the loss of every step
 * and the accuracy of every TEST_ITERATION_FREQUENCY-th step, then plots both.
 *
 * @returns {Promise<void>}
 */
async function train() {
  ui.isTraining();

  const lossValues = [];
  const accuracyValues = [];

  for (let i = 0; i < TRAIN_BATCHES; i++) {
    const batch = data.nextTrainBatch(BATCH_SIZE);
    // reshape() returns a new tensor that must be disposed separately from
    // the tensor it was derived from, so keep a handle to it.
    const trainXs = batch.xs.reshape([BATCH_SIZE, 28, 28, 1]);

    let testBatch;
    let testXs;
    let validationData;
    if (i % TEST_ITERATION_FREQUENCY === 0) {
      testBatch = data.nextTestBatch(TEST_BATCH_SIZE);
      testXs = testBatch.xs.reshape([TEST_BATCH_SIZE, 28, 28, 1]);
      validationData = [testXs, testBatch.labels];
    }

    try {
      const history = await model.fit(
          trainXs,
          batch.labels,
          {batchSize: BATCH_SIZE, validationData, epochs: 1});

      const loss = history.history.loss[0];
      const accuracy = history.history.acc[0];

      lossValues.push([i, loss]);
      if (testBatch != null) {
        accuracyValues.push([i, accuracy]);
      }
    } finally {
      // Release every tensor created in this iteration, even if fit() throws.
      trainXs.dispose();
      batch.xs.dispose();
      batch.labels.dispose();
      if (testBatch != null) {
        testXs.dispose();
        testBatch.xs.dispose();
        testBatch.labels.dispose();
      }
    }

    // Yield to the browser so the page stays responsive between steps.
    await tf.nextFrame();
  }

  ui.plot(lossValues, accuracyValues);
}

/**
 * Runs the model on 100 test samples and hands the predictions together
 * with the true labels to the UI.
 *
 * @returns {Promise<void>}
 */
async function showPredictions() {
  const testExamples = 100;
  const batch = data.nextTestBatch(testExamples);

  try {
    // Tensors created inside the callback are cleaned up by tf.tidy().
    tf.tidy(() => {
      const output = model.predict(batch.xs.reshape([-1, 28, 28, 1]));

      const axis = 1;
      const labels = Array.from(batch.labels.argMax(axis).dataSync());
      const predictions = Array.from(output.argMax(axis).dataSync());

      ui.showTestResults(batch, predictions, labels);
    });
  } finally {
    // The batch was created outside tidy(), so it has to be freed by hand.
    batch.xs.dispose();
    batch.labels.dispose();
  }
}

let data;

/**
 * Creates the MnistData instance and loads its data.
 *
 * @returns {Promise<void>}
 */
async function load() {
  data = new MnistData();
  await data.load();
}

/**
 * Entry point: load the data, train the model, show sample predictions.
 *
 * @returns {Promise<void>}
 */
async function mnist() {
  await load();
  await train();
  await showPredictions();
}

// Surface failures (e.g. assets that fail to load) instead of leaving the
// returned promise unhandled.
mnist().catch((err) => {
  console.error(err);
});
训练时首先调用ui.isTraining()将Training...输出到文档,之后进行TRAIN_BATCHES = 150轮迭代,每轮记录损失函数的值,并每隔TEST_ITERATION_FREQUENCY = 5轮传入TEST_BATCH_SIZE = 1000个测试样本作为验证数据并记录一次准确率,在训练结束后进行可视化展示。之后调用showPredictions方法将测试样本的预测值与实际像素图输出到文档。
4、可视化部分
const echarts = require('echarts');

const statusElement = document.getElementById('status');
const imagesElement = document.getElementById('images');

/** Shows the "training" status text. */
export function isTraining() {
  statusElement.innerText = 'Training...';
}

/**
 * Renders every sample of `batch` as a small canvas with its predicted
 * class, marked as correct or incorrect.
 *
 * @param {{xs: tf.Tensor2D}} batch Batch whose `xs` holds one flattened
 *     image per row.
 * @param {number[]} predictions Predicted class per sample.
 * @param {number[]} labels True class per sample.
 */
export function showTestResults(batch, predictions, labels) {
  statusElement.innerText = 'Testing...';

  const testExamples = batch.xs.shape[0];
  for (let i = 0; i < testExamples; i++) {
    // Row i of the batch, then flattened for draw(). Both tensors are
    // disposed right after drawing so this function does not depend on the
    // caller wrapping it in tf.tidy().
    const image = batch.xs.slice([i, 0], [1, batch.xs.shape[1]]);
    const flatImage = image.flatten();

    const div = document.createElement('div');
    div.className = 'pred-container';

    const canvas = document.createElement('canvas');
    canvas.className = 'prediction-canvas';
    draw(flatImage, canvas);
    flatImage.dispose();
    image.dispose();

    const pred = document.createElement('div');

    const prediction = predictions[i];
    const label = labels[i];
    const correct = prediction === label;

    pred.className = `pred ${(correct ? 'pred-correct' : 'pred-incorrect')}`;
    pred.innerText = `pred: ${prediction}`;

    div.appendChild(pred);
    div.appendChild(canvas);

    imagesElement.appendChild(div);
  }
}

const lossChart = echarts.init(document.getElementById('lossChart'));
const accuracyChart = echarts.init(document.getElementById('accuracyChart'));

// Builds the option object for a single-series line chart. The axes are
// always included alongside the series.
function lineChartOption(titleText, seriesName, values) {
  return {
    title: {text: titleText},
    xAxis: {type: 'value'},
    yAxis: {type: 'value'},
    series: [{name: seriesName, type: 'line', data: values}]
  };
}

/**
 * Draws the loss and accuracy curves.
 *
 * @param {Array<[number, number]>} lossValues [step, loss] pairs.
 * @param {Array<[number, number]>} accuracyValues [step, accuracy] pairs.
 */
export function plot(lossValues, accuracyValues) {
  lossChart.setOption(lineChartOption('Loss Values', 'loss', lossValues));
  accuracyChart.setOption(
      lineChartOption('Accuracy Values', 'accuracy', accuracyValues));
}

/**
 * Paints a 28 x 28 grayscale image onto `canvas`.
 *
 * @param {tf.Tensor} image Tensor with at least 784 values; each value is
 *     multiplied by 255 to obtain the gray level.
 * @param {HTMLCanvasElement} canvas Target canvas; resized to 28 x 28.
 */
export function draw(image, canvas) {
  const [width, height] = [28, 28];
  canvas.width = width;
  canvas.height = height;
  const ctx = canvas.getContext('2d');
  const imageData = new ImageData(width, height);
  const data = image.dataSync();
  for (let i = 0; i < height * width; ++i) {
    const j = i * 4;
    // Same value for R, G and B gives gray; alpha is fully opaque.
    imageData.data[j + 0] = data[i] * 255;
    imageData.data[j + 1] = data[i] * 255;
    imageData.data[j + 2] = data[i] * 255;
    imageData.data[j + 3] = 255;
  }
  ctx.putImageData(imageData, 0, 0);
}
BUG1:
// Failing call for BUG1: the option object contains only `series`, with no xAxis / yAxis.
lossChart.setOption({ series: [{ name: 'loss', type: 'line', data: lossValues }] });
当绘图时仅传入series参数时,Echarts报错:
Uncaught (in promise) TypeError: Cannot read property 'get' of undefined
此时将xAxis、yAxis等添加进options中即可。
// Working call for BUG1: title, xAxis and yAxis are passed together with `series`.
lossChart.setOption({ title: { text: 'Loss Values' }, xAxis: { type: 'value' }, yAxis: { type: 'value' }, series: [{ name: 'loss', type: 'line', data: lossValues }] });
BUG2:
报错 Uncaught Error: Unsupported core optimizer type: t
原因是相关依赖未正确安装,cd到package.json同目录,运行yarn命令安装相关包。
完整程序见我的github,具体步骤为:
step1 新建文件夹,cmd输入git clone git@github.com:orangecsy/tfjs-exercise.git,cd 2进入文件夹2;
step2 cmd输入webpack,打包;
step3 cd dist进入dist文件夹,cmd中输入http-server(需先npm install http-server)或使用webpack配置开发服务器;
step4 浏览器中输入http://127.0.0.1:8080/,即为结果。