AI - TensorFlow - Example 01: Basic Classification

Basic Classification

Basic classification: https://www.tensorflow.org/tutorials/keras/basic_classification

Fashion MNIST dataset

tf.keras

tf.keras is a high-level API for building and training models in TensorFlow: https://www.tensorflow.org/api_docs/python/tf/keras/

Example

Script

xxx

Output

xxx

Troubleshooting

Problem 1: fashion_mnist.load_data() fails

Error message
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
......
Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz: None -- [WinError 10060] A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond

Solution 1

Choose a download link, then manually download the following four files and store them in the fashion-mnist directory under "~/.keras/datasets":

  • train-labels-idx1-ubyte.gz
  • train-images-idx3-ubyte.gz
  • t10k-labels-idx1-ubyte.gz
  • t10k-images-idx3-ubyte.gz

Example directory listing (MINGW64 shell on Windows):
guowli@5CG450158J MINGW64 ~/.keras/datasets
$ pwd
/c/Users/guowli/.keras/datasets

guowli@5CG450158J MINGW64 ~/.keras/datasets
$ ls -l
total 0
drwxr-xr-x 1 guowli 1049089 0 Mar 27 14:44 fashion-mnist/

guowli@5CG450158J MINGW64 ~/.keras/datasets
$ ls -l fashion-mnist/
total 30164
-rw-r--r-- 1 guowli 1049089  4422102 Mar 27 15:47 t10k-images-idx3-ubyte.gz
-rw-r--r-- 1 guowli 1049089     5148 Mar 27 15:47 t10k-labels-idx1-ubyte.gz
-rw-r--r-- 1 guowli 1049089 26421880 Mar 27 15:47 train-images-idx3-ubyte.gz
-rw-r--r-- 1 guowli 1049089    29515 Mar 27 15:47 train-labels-idx1-ubyte.gz
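Before calling fashion_mnist.load_data() again, it can help to confirm that the manually downloaded files are actually in place. A minimal sketch using only the Python standard library, assuming the default cache location ~/.keras/datasets/fashion-mnist shown above; the helper name find_missing_files() is hypothetical:

```python
import pathlib

# The four Fashion-MNIST archives listed above
FILES = [
    "train-labels-idx1-ubyte.gz",
    "train-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
]


def find_missing_files(directory=None):
    """Hypothetical helper: return the expected archives that are absent or empty.

    Defaults to ~/.keras/datasets/fashion-mnist (assumed Keras cache location).
    """
    if directory is None:
        directory = pathlib.Path.home() / ".keras" / "datasets" / "fashion-mnist"
    directory = pathlib.Path(directory)
    missing = []
    for name in FILES:
        path = directory / name
        # A missing file, or a zero-byte file left by an interrupted download
        if not path.is_file() or path.stat().st_size == 0:
            missing.append(name)
    return missing
```

Calling find_missing_files() with no argument checks the default location; an empty list means all four archives are present and non-empty (it does not check that their contents are valid).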

Solution 2

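Download the files manually and store them in a chosen directory (the code below uses the current working directory).
Then rewrite the load_data() function defined in "tensorflow\python\keras\datasets\fashion_mnist.py" so that it reads the local files.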

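import gzip
import pathlib

import numpy as np
from tensorflow.keras.utils import get_file  # public path; tensorflow.python.keras is internal


def load_data():
    """Replacement for load_data() in tensorflow/python/keras/datasets/fashion_mnist.py.

    Reads the four manually downloaded .gz files from the current working
    directory instead of fetching them from storage.googleapis.com.
    """
    # file:// URI of the current directory; as_uri() handles Windows and POSIX paths
    base = pathlib.Path.cwd().as_uri() + "/"

    files = [
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
    ]

    # get_file() copies each local file into the Keras cache (~/.keras/datasets)
    # and returns the cached path; a file already in the cache is reused
    paths = [get_file(fname, origin=base + fname) for fname in files]

    # Label files: skip the 8-byte IDX header, then one uint8 label per item
    with gzip.open(paths[0], 'rb') as lbpath:
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    # Image files: skip the 16-byte IDX header, then 28x28 uint8 pixels per image
    with gzip.open(paths[1], 'rb') as imgpath:
        x_train = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)

    with gzip.open(paths[2], 'rb') as lbpath:
        y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(paths[3], 'rb') as imgpath:
        x_test = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)

    return (x_train, y_train), (x_test, y_test)


(train_images, train_labels), (test_images, test_labels) = load_data()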
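The offset=8 and offset=16 arguments in load_data() skip the IDX file headers. In the IDX format used by MNIST and Fashion-MNIST, a label file starts with two big-endian 32-bit integers (magic number 2049 and item count), and an image file starts with four (magic number 2051, image count, rows, columns). A minimal standard-library sketch that reads and checks these headers; the function names read_idx_header() and read_idx_header_from_gz() are hypothetical:

```python
import gzip
import struct

LABEL_MAGIC = 2049  # 0x00000801: unsigned-byte data, 1 dimension
IMAGE_MAGIC = 2051  # 0x00000803: unsigned-byte data, 3 dimensions


def read_idx_header(data):
    """Hypothetical helper: parse the header of decompressed IDX bytes.

    Returns (magic, dims, header_size); header_size is the offset that
    load_data() passes to np.frombuffer().
    """
    (magic,) = struct.unpack(">I", data[:4])
    if magic == LABEL_MAGIC:
        ndim = 1
    elif magic == IMAGE_MAGIC:
        ndim = 3
    else:
        raise ValueError(f"unexpected IDX magic number: {magic}")
    dims = struct.unpack(f">{ndim}I", data[4:4 + 4 * ndim])
    header_size = 4 + 4 * ndim  # 8 for labels, 16 for images
    # The payload must hold exactly one byte per element described by dims
    expected = 1
    for d in dims:
        expected *= d
    if len(data) - header_size != expected:
        raise ValueError("payload size does not match header; file may be truncated")
    return magic, dims, header_size


def read_idx_header_from_gz(path):
    """Decompress a .gz IDX file and parse its header."""
    with gzip.open(path, "rb") as f:
        return read_idx_header(f.read())
```

For the Fashion-MNIST training images, read_idx_header_from_gz("train-images-idx3-ubyte.gz") should report dims of (60000, 28, 28) and a header size of 16.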

Problem 2: gzip.open() fails to open a .gz file

Error message

“OSError: Not a gzipped file (b'\n\n')”

Solution

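gzip.open() cannot read a .gz file that is corrupted or incomplete. The bytes in parentheses are what gzip found where the two-byte gzip magic number b'\x1f\x8b' was expected, so b'\n\n' means the file does not contain gzip data at all. Check that each .gz file is complete and undamaged, and download it again if it is not. If the files were fetched through get_file(), also delete the damaged copy from ~/.keras/datasets, because get_file() reuses a cached file instead of downloading it again.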
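A .gz file can be checked without TensorFlow: a valid gzip stream starts with the magic bytes b'\x1f\x8b', and a truncated stream raises EOFError in Python's gzip module when decompressed to the end. A minimal standard-library sketch; the function names gzip_problem() and check_gz_file() are hypothetical:

```python
import gzip
import zlib

GZIP_MAGIC = b"\x1f\x8b"  # every gzip stream starts with these two bytes


def gzip_problem(data):
    """Hypothetical helper: describe what is wrong with gzip data,
    or return None if it decompresses cleanly."""
    if data[:2] != GZIP_MAGIC:
        # The situation behind "Not a gzipped file (b'\n\n')"
        return f"not a gzip file (starts with {data[:2]!r})"
    try:
        gzip.decompress(data)
    except EOFError:
        # Stream ended before the end-of-stream marker: incomplete download
        return "truncated: stream ended early"
    except (OSError, zlib.error) as exc:
        # Bad CRC or invalid deflate data
        return f"corrupted: {exc}"
    return None


def check_gz_file(path):
    """Read a file from disk and check it with gzip_problem()."""
    with open(path, "rb") as f:
        return gzip_problem(f.read())
```

Running check_gz_file() on each of the four Fashion-MNIST archives shows which one, if any, needs to be downloaded again.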

References

https://github.com/tensorflow/tensorflow/issues/170


Reposted from www.cnblogs.com/anliven/p/10612178.html