错误:Fetch argument array has invalid type class 'numpy.ndarray'

出错代码: 

#创建会话(运行环境)
with tf.Session() as sess:
    #初始化全局变量
    sess.run(tf.global_variables_initializer())
    #开始训练模型
    #因为训练集较小,所以采用批梯度下降优化算法,每次都使用全量数据训练
    for e in range(1, epoch+1):
        sess.run(train_op, feed_dict = {X: x_data, Y: y_data})
        if e % 10 == 0:
            loss,W = sess.run([loss, W], feed_dict = {X: x_data, Y: y_data})
            log_str = "Epoch %d \t Loss = %.4g \t Model: y = %.4gx1 + %.4gx2 +%.4g"
            print(log_str % (e, loss, w[1], w[2], w[0]))

原因:倒数第三行,新的变量名不应与就变量名一样

修改

Loss,w = sess.run([loss, W], feed_dict = {X: x_data, Y: y_data})

猜你喜欢

转载自blog.csdn.net/lincoco49/article/details/88102807