# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MNIST handwritten digits dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import gzip
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.util.tf_export import keras_export
def _read_idx_labels(path):
    """Read a gzip-compressed label file into a 1-D `uint8` array.

    Args:
        path: Path to a `*-labels-idx1-ubyte.gz` file.

    Returns:
        A 1-D `uint8` NumPy array with one entry per byte that follows the
        header.
    """
    with gzip.open(path, 'rb') as label_file:
        # The first 8 bytes are skipped as the file header (in the MNIST IDX
        # layout: magic number + item count); everything after is label data.
        return np.frombuffer(label_file.read(), np.uint8, offset=8)


def _read_idx_images(path, num_images):
    """Read a gzip-compressed image file into a `(num_images, 28, 28)` array.

    Args:
        path: Path to a `*-images-idx3-ubyte.gz` file.
        num_images: Number of images expected in the file; used as the
            leading dimension of the result.

    Returns:
        A `uint8` NumPy array of shape `(num_images, 28, 28)`.

    Raises:
        ValueError: If the payload size does not equal
            `num_images * 28 * 28` bytes (raised by `reshape`).
    """
    with gzip.open(path, 'rb') as image_file:
        # The first 16 bytes are skipped as the file header (in the MNIST IDX
        # layout: magic number, image count, rows, columns).
        pixels = np.frombuffer(image_file.read(), np.uint8, offset=16)
    return pixels.reshape(num_images, 28, 28)


@keras_export('keras.datasets.mnist.load_data')
def load_data(data_folder='/home/ruoke/python/ai/datasets/mnist/'):
    """Loads the MNIST dataset from local gzip-compressed IDX files.

    Nothing is downloaded: the four archives below must already exist
    inside `data_folder`:

    - `train-labels-idx1-ubyte.gz`
    - `train-images-idx3-ubyte.gz`
    - `t10k-labels-idx1-ubyte.gz`
    - `t10k-images-idx3-ubyte.gz`

    Args:
        data_folder: Directory that contains the four archives. NOTE: the
            default is a machine-specific absolute path, so callers on any
            other machine should pass this argument explicitly.

    Returns:
        Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.

        x_train, x_test: `uint8` arrays of image data with shape
            `(num_samples, 28, 28)`.
        y_train, y_test: `uint8` arrays of labels with shape
            `(num_samples,)`.

        The arrays are created with `np.frombuffer` over the decompressed
        bytes, so they are read-only; copy them before modifying in place.

    Raises:
        OSError: If one of the archives is missing or cannot be read as gzip
            data.
        ValueError: If an image file's size does not match the number of
            labels in the corresponding label file.
    """
    train_labels_path, train_images_path, test_labels_path, test_images_path = (
        os.path.join(data_folder, fname)
        for fname in (
            'train-labels-idx1-ubyte.gz',
            'train-images-idx3-ubyte.gz',
            't10k-labels-idx1-ubyte.gz',
            't10k-images-idx3-ubyte.gz',
        )
    )

    # Labels are read first for each split because their count supplies the
    # leading dimension used to reshape the flat image buffer.
    y_train = _read_idx_labels(train_labels_path)
    x_train = _read_idx_images(train_images_path, len(y_train))

    y_test = _read_idx_labels(test_labels_path)
    x_test = _read_idx_images(test_images_path, len(y_test))

    return (x_train, y_train), (x_test, y_test)