mxnet3——线性回归

只利用ndarray 和 autograd

生成训练集

X = nd.random_normal(shape=(num_examples, num_inputs))
y = true_w[0] * X[:, 0] + true_w[1] * X[:, 1] + true_b
y += .01 * nd.random_normal(shape=y.shape)

def data_iter():
    idx = list(range(num_examples))
    for i in range(0, num_examples, batch_size):
        j = nd.array(idx[i:min(i+batch_size,num_examples)])
  • datch_size 每次从训练集抽取的数目

模型

def net(X):
    return nd.dot(X, w) + b

损失函数

def square_loss(yhat, y):
    return (yhat - y.reshape(yhat.shape)) ** 2
  • 输出值和真实值的区别

优化函数

def SGD(params, lr):
    for param in params:
        param[:] = param - lr * param.grad
  • 按梯度的反方向,走lr个步长,参数改变

迭代

for e in range(epochs):    
    total_loss = 0
    print ('e',e);
    for data, label in data_iter():
        with autograd.record():
            output = net(data)
            loss = square_loss(output, label)
        loss.backward()
        SGD(params, learning_rate)
        total_loss += nd.sum(loss).asscalar()
        niter +=1
        curr_loss = nd.mean(loss).asscalar()
        moving_loss = (1 - smoothing_constant) * moving_loss + (smoothing_constant) * curr_loss

        est_loss = moving_loss/(1-(1-smoothing_constant)**niter)
  • epochs:迭代次数。每一次迭代中,按batchsize的大小取训练集中的数据,直到取完。

mxnet

训练集

X = nd.random_normal(shape=(num_examples,num_inputs))
y = true_w[0] * X[:,0] + true_w[1] * X[:,1] + true_b
y += .01 * nd.random_normal(shape=y.shape)

batch_size = 10
dataset = gluon.data.ArrayDataset(X,y)
data_iter = gluon.data.DataLoader(dataset,batch_size,shuffle=True)

网络

net  = gluon.nn.Sequential() # 建立一个空的模型
net.add(gluon.nn.Dense(1)) # 加入一个Dense,输出节点为1
net.initialize() # 初始化参数

损失函数

square_loss = gluon.loss.L2Loss()

优化函数

trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':0.1})

迭代

for e in range(epochs):
total_loss = 0
for data,label in data_iter:
    with autograd.record():
        output = net(data)
        loss = square_loss(output,label)
    loss.backward()
    trainer.step(batch_size)
    total_loss += nd.sum(loss).asscalar()

猜你喜欢

转载自blog.csdn.net/sda42342342423/article/details/78897372