【tensorflow.js学习笔记(1)】tf.js环境搭建及曲线拟合例子

月初TensorFlow开发者大会上,谷歌正式发布了TensorFlow的JS版本tensorflow.js,并演示了几个很有意思的demo,展现了浏览器环境下也能进行深度学习任务的能力。tensorflowjs利用WebGl加速,在浏览器环境下训练、部署机器学习模型。下面我尝试引入tensorflow.js并运行一个曲线拟合的例子。

1、文件形式引入

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]"></script>
<script>
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 1, inputShape: [1]}));
  model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

  const x = tf.tensor2d([1, 2, 3, 4], [4, 1]);
  const y = tf.tensor2d([1, 3, 5, 7], [4, 1]);

  model.fit(x, y).then(() => {
    model.predict(tf.tensor2d([5], [1, 1])).print();
  });
</script>

首先调用tf.sequential()构建模型,损失函数为均方差,优化器为sgd(梯度下降)。待拟合的点序列为(1,1),(2,3),(3,5),(4,7),训练模型,输入x=5。

打开浏览器,输出为:

Tensor
     [[8.1529675],]

2、使用webpack

npm install @tensorflow/tfjs

首先利用npm安装tensorflow.js(也可用yarn),新建index.js文件,内容如下。

import * as tf from '@tensorflow/tfjs';

const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);

model.fit(xs, ys).then(() => {
  model.predict(tf.tensor2d([5], [1, 1])).print();
});

利用webpack可使用import语法引入tf.js。配置webpack.config.js文件如下。

const path = require('path');

module.exports={
    //入口文件的配置项
    entry:{
      entry: './index.js'
    },
    //出口文件的配置项
    output:{
      path: path.resolve(__dirname, 'dist'),
      filename: 'bundle.js'
    }
}

运行webpack命令,将在目录下生成dist文件夹。cd进入该文件夹,用node运行bundle.js文件,输出结果。

3、曲线拟合

参考Fitting a Curve to Synthetic Data,这是TensorFlow官方关于曲线拟合的例子,其中使用Vega进行可视化展示。下面我将用Echarts替换Vega进行可视化展示并重写部分程序,代码结构如下所示。


其中index.js为入口文件,dist文件夹下为分发文件,webpack.config.js的内容如下。

const path = require('path');

module.exports={
  mode: 'development',
  //入口文件的配置项
  entry:{
    entry: './src/index.js'
  },
  //出口文件的配置项
  output:{
    path: path.resolve(__dirname, 'dist'),
    filename: 'bundle.js'
  },
  //控制台报错信息
  devtool: 'inline-source-map'
}

index.html内容如下。

<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Document</title>
  <style>
    #chart {
      width: 800px;
      height: 800px;
    }
  </style>
</head>
<body>
  <div id="chart"></div>
  <script src="bundle.js"></script>
</body>
</html>

入口文件index.js内容如下。

import * as tf from '@tensorflow/tfjs';
var echarts = require('echarts');

const a = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));
const c = tf.variable(tf.scalar(Math.random()));
const d = tf.variable(tf.scalar(Math.random()));

function predict(x) {
  return tf.tidy(() => {
    return a.mul(x.pow(tf.scalar(3, 'int32'))) 
      .add(b.mul(x.square()))
      .add(c.mul(x))
      .add(d);
  });
}

function loss(prediction, labels) {
  const error = prediction.sub(labels).square().mean();
  return error;
}

const numIterations = 75;
const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);

async function train(xs, ys, numIterations) {
  for (let iter = 0; iter < numIterations; iter++) {
    optimizer.minimize(() => {
      const pred = predict(xs);
      return loss(pred, ys);
    });
    await tf.nextFrame();
  }
}

function generateData(numPoints, coeff, sigma = 0.04) {
  return tf.tidy(() => {
    const [a, b, c, d] = [
      tf.scalar(coeff.a),
      tf.scalar(coeff.b),
      tf.scalar(coeff.c),
      tf.scalar(coeff.d)
    ];
  
    const xs = tf.randomUniform([numPoints], -1, 1);
    const ys = a.mul(xs.pow(tf.scalar(3, 'int32')))
      .add(b.mul(xs.square()))
      .add(c.mul(xs))
      .add(d)
      .add(tf.randomNormal([numPoints], 0, sigma));

    const ymin = ys.min();
    const ymax = ys.max();
    const yrange = ymax.sub(ymin);
    const ysNormalized = ys.sub(ymin).div(yrange);

    return {
      xs,
      ys: ysNormalized
    };
  })
}

async function plotData(xs, ys, preds) {
  const xvals = await xs.data();
  const yvals = await ys.data();
  const predVals = await preds.data();
  
  const valuesBefore = Array.from(xvals).map((x, i) => {
    return [xvals[i], yvals[i]];
  });
  const valuesAfter= Array.from(xvals).map((x, i) => {
    return [xvals[i], predVals[i]];
  });
  // 二维数组排序
  valuesAfter.sort(function(x, y) {
    return x[0] - y[0];
  });
  curveChart.setOption({
    xAxis: {
      min: -1,
      max: 1
    },
    yAxis: {
      min: 0,
      max: 1
    },
    series: [{
      symbolSize: 12,
      data: valuesBefore,
      type: 'scatter'
    },{
      data: valuesAfter,
      encode: {
        x: 0,
        y: 1
      },
      type: 'line'
    }]
  });
}

async function learnCoefficients() {
  const trueCoefficients = {a: -0.8, b: -0.2, c: 0.9, d: 0.5};
  // 生成有误差的训练数据
  const trainingData = generateData(100, trueCoefficients);
  // 训练模型
  await train(trainingData.xs, trainingData.ys, numIterations);
  // 预测数据
  const predictionsAfter = predict(trainingData.xs);
  // 绘制散点图及拟合曲线
  await plotData(trainingData.xs, trainingData.ys, predictionsAfter);
  predictionsAfter.dispose();
}


const curveChart = echarts.init(document.getElementById('chart'));
learnCoefficients();

首先引入tensorflow.js及echarts,之后定义4个参数a、b、c、d,分别是待拟合曲线y=a*x^3+b*x^2+c*x+d的四个参数,初始设为随机值。

定义函数predict,传入x,返回拟合后的估计值y。函数loss为损失函数,这里定义loss为均方差。

定义优化器optimizer,其中学习率为0.5,学习率过小会导致训练速度慢,学习率过高会造成拟合参数在最优解附近“左右摇摆”。

定义async函数(Generator 函数的语法糖)train,train函数内根据迭代步数及学习率调用优化器并计算损失函数loss。

函数generateData随机生成[-1, 1]范围内的点,并根据传入的a、b、c、d加上一定的随机扰动生成数据点xs,ys,其中ys进行归一化处理。

函数plotData将随机生成的样本点映射为散点图,将根据训练后的参数拟合出的点映射为曲线。

函数renderCoefficients将a、b、c、d的值输出到document内。

函数learnCoefficients是index.js的main函数,函数内先设定预定义的a、b、c、d,再生成有误差的训练数据,利用训练数据训练a、b、c、d参数并参数输出到文档,之后利用训练好的参数拟合x数据,将结果绘制为散点图及曲线,最后通知GC清理。

此时拟合出的曲线图会有bug,如下所示。


原因分析:

传入echarts的点对是按生成顺序排序的,是无序数组,但绘制曲线时是按传入数组的顺序连接各点,因此在传入前需对二维数据进行排序。在curveChart.setOption前加入如下代码。

// 二维数组排序
valuesAfter.sort(function(x, y) {
  return x[0] - y[0];
});

结果如下。


完整程序见我的github,具体步骤为:

step1 新建文件夹,cmd输入git clone [email protected]:orangecsy/tfjs-exercise.git,cd 1进入文件夹1;

step2 cmd输入webpack,打包;

step3 cd dist进入dist文件夹,cmd中输入http-server(需先npm install http-server)或使用webpack配置开发服务器;

step4 浏览器中输入http://127.0.0.1:8080/,即为结果。

猜你喜欢

转载自blog.csdn.net/orangecsy/article/details/80110663