在浏览器中进行深度学习:TensorFlow.js (二)第一个模型,线性回归

笔者在上一篇文章中介绍了TensorFlow.js中的基本概念,以及机器学习的数学基础,线性代数的基本知识。在这一遍文章里,我们来看一看如何利用TensorFlow.js来构建数学模型,以及进行学习的基本过程。

学习的过程基本如下:

  1. 准备训练数据
  2. 构建一个模型
  3. 利用训练数据和模型,进行迭代的学习
  4. 模型训练完毕,用这个模型对新的数据进行预测(这里我们先略过对模型的验证部分)

好了,我们以最简单的线性回归为例子,看看这个过程。

准备数据

如上图所示,我在二维坐标系中生成了7个点,让它们在我假想的某条直线附近。我以这几个点作为我的训练数据。

训练数据的初始化代码如下,这里tx是所有点数据的x坐标,ty是所有点数据的坐标。

const train_x = tf.tensor1d(tx);
const train_y = tf.tensor1d(ty);

模型选择

所有的模型都是错的,有的模型更好。

所谓的模型,也就是一个函数f,对应于某个输入数据,计算出某些输出数据。模型可以复杂,可以简单。简单的模型不一定不好,负责的模型也不一定好。

我们用线性模型举例,数学上就是假定 Y = wX + b

在这个模型中,有两个参数需要确定,w和b。

模型既然是个函数,那么它的代码也就很容易理解了:

扫描二维码关注公众号,回复: 17832 查看本文章
const f = x => w.mul(x).add(b);

当然你也可以这样写:

const f = function(x){
  return w.mul(x).add(b);
  }
}

迭代学习

学习的过程我们称作训练,训练通常是一个迭代的过程,这个过程中,通常需要这几样东西:

  • 一个损失函数(loss function),损失函数定义了模型是不是足够好,通常loss越小越好。
  • 一个优化器 (optimizer),优化器通过某种算法来决定如何改变参数的值,使得损失函数最小化。
  • 迭代循环, 通过循环 -> 调用优化器,得到新的参数,计算损失, 最终当损失足够小时,可以认为训练结束了。

训练代码如下:

初始化参数,这里使用随机数来作为参数的初始值。(注意,初始参数并不总是随机选择的。)

const w = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random())); 

初始化学习参数,

  • numIterations是迭代的次数,一般次数越多,模型的拟合就越好,但是就需要花费越多的计算
  • learningRate是学习率,这个值越大,学的速度就越快,但是也会更加容易错过极值点。
const numIterations = 200;
const learningRate = 1;

选择一个优化器,这里我选择了adam。TensorFlow.js提供了多种优化器,例如sgd,momentum等等,大家可以根据自己的需要来选择。

const optimizer = tf.train.adam(learningRate);

对于损失函数,我们采用的是均方差 

const loss = (pred, label) => pred.sub(label).square().mean();

或者可以写作:

function loss(predictions, labels) {
  const meanSquareError = predictions.sub(labels).square().mean();
  return meanSquareError;
}

然后就是训练的过程啦:

for (let iter = 0; iter < numIterations; iter++) {
    optimizer.minimize(() => {
      const loss_var = loss(f(train_x), train_y);
      loss_var.print();
      return loss_var;
    })
}

在训练过程中,我们调用tensor的print()方法打印出损失的值,看看训练过程是不是收敛。当选择的模型,参数,优化器不合适的时候,有可能训练过程并不收敛。

训练的结果我们就等到了w和b的值。也就是确定了直线的斜率和截距。
大家可以尝试我的演示代码

我们可以看到学习过程中是如何慢慢收敛到最后的结果的直线。

总结

本文描述了一个使用tensoflow.js来进行最简单的线性回归模型的学习的过程。希望大家可以通过这个简单的例子了解机器学习的基本思路。

参考

猜你喜欢

转载自my.oschina.net/taogang/blog/1793835