TensorFlow报错Fetch argument None has invalid type class 'NoneType'

版权声明:如有转载复制请注明出处,博主QQ715608270,欢迎沟通交流! https://blog.csdn.net/qq_41000891/article/details/84555225

写了一个TensorFlow卷积神经网络的训练程序。

基于mnist数据集进行训练和测试。

但是在程序运行的时候报出了下面的错误。

Traceback (most recent call last):
  File "nn_eg.py", line 104, in <module>
    train_loss, train_op = sess.run([loss, train_op], {input_x: batch[0], output_y: batch[1]})
  File "/home/zhonghangalex/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/home/zhonghangalex/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1137, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home/zhonghangalex/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 471, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/zhonghangalex/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 261, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/zhonghangalex/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 370, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/zhonghangalex/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 370, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/zhonghangalex/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 258, in for_fetch
    type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>

 这里我们看到错误指向的是代码的104行,我将这部分的代码贴出来:

# 训练神经网络
for i in range(20000):
	batch = mnist.train.next_batch(50) #从Train(训练)数据集里取下一个50样本
	train_loss, train_op = sess.run([loss, train_op], {input_x: batch[0], output_y: batch[1]})
	if i % 100 == 0:
		test_accuracy = sess.run(accuracy, {input_x: test_x, output_y: test_y})
		print("Step=%d, Train loss=%.4f, [Test accuracy=%.2f]" % (i, train_loss, test_accuracy))

这是对神经网络训练的过程,指定训练20000步,有一个奇怪的现象就是,循环的第一步进行得很顺畅,可是从第二步开始就报了这个错误:

这就说明了应该是变量出现了问题。

查阅资料后发现是因为:

train_op变量重新分配给结果的第二个元素sess.run()(恰好是None)。因此,在第二次迭代中,train_op是None,这导致错误。 解决的方法很简单,就是把两个变量的第二个变量改为“_”:

# 训练神经网络
for i in range(20000):
	batch = mnist.train.next_batch(50) #从Train(训练)数据集里取下一个50样本
	train_loss, _ = sess.run([loss, train_op], {input_x: batch[0], output_y: batch[1]})
	if i % 100 == 0:
		test_accuracy = sess.run(accuracy, {input_x: test_x, output_y: test_y})
		print("Step=%d, Train loss=%.4f, [Test accuracy=%.2f]" % (i, train_loss, test_accuracy))

这样便成功开始了训练:

猜你喜欢

转载自blog.csdn.net/qq_41000891/article/details/84555225