Training with TensorFlow 2.0 fails with "Empty training data"
Error:
File "gru_layer2.py", line 74, in <module>
network.fit(db_train, epochs=epochs, validation_data=db_val,steps_per_epoch=x_train.shape[0]//batchsz)
File "/homes/xiaohuizou/.conda/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
use_multiprocessing=use_multiprocessing)
File "/homes/xiaohuizou/.conda/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 370, in fit
total_epochs=1)
File "/homes/xiaohuizou/.conda/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 180, in run_one_epoch
aggregator.finalize()
File "/homes/xiaohuizou/.conda/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py", line 140, in finalize
raise ValueError('Empty training data.')
ValueError: Empty training data.
Checking my code (note the later lines use the variable `batchsz`, so the assignment should match):
batchsz = 128
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(6000).batch(batchsz, drop_remainder=True).repeat()
db_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
db_val = db_val.batch(batchsz, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz, drop_remainder=True)
> x_train.shape, x_test.shape, x_val.shape
((700, 100), (100, 100), (200, 100))
The cause: the test set has only 100 samples, but `batchsz` is set to 128. With `drop_remainder=True`, a dataset with fewer samples than one full batch produces zero batches, which triggers the "Empty training data" error. The fix is to lower the batch size to 100 or below, e.g. batchsz = 64.
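The effect can be reproduced in isolation with a minimal sketch. The tensors below are hypothetical stand-ins with the same shape as `x_test` (100 samples); the point is only how `batch(..., drop_remainder=True)` behaves when the batch size exceeds the dataset size:

```python
import tensorflow as tf

# Dummy data matching x_test.shape: 100 samples of length 100
x = tf.zeros([100, 100])
y = tf.zeros([100])

# batch size 128 > 100 samples: drop_remainder=True discards the
# only (incomplete) batch, leaving an empty dataset
empty_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(128, drop_remainder=True)
print(sum(1 for _ in empty_ds))  # 0 -> "Empty training data."

# batch size 64: one full batch survives (the remaining 36 samples
# are still dropped by drop_remainder=True)
ok_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(64, drop_remainder=True)
print(sum(1 for _ in ok_ds))  # 1
```

Alternatively, keeping `drop_remainder=False` (the default) on the validation and test pipelines would also avoid the empty dataset, at the cost of one smaller final batch.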