cifar10 数据的下载和使用

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u012193416/article/details/87267668

CIFAR-10 是图像分类中常用的基准评测数据集，共 10 个类别，包含 50000 张训练图像和 10000 张测试图像，每张都是 32×32 的 RGB 彩色图像。

import cv2
import numpy as np

from keras.datasets import cifar10
from keras.utils import np_utils

# Subset sizes used by load_cifar10_data(): only the first N images of each
# split are kept, to keep the resized arrays small.
nb_train_samples: int = 3000  # training images kept
nb_valid_samples: int = 100  # validation images kept (taken from the test split)
num_classes: int = 10  # CIFAR-10 label count; width of the one-hot label vectors


def load_cifar10_data(img_rows, img_cols):
    """Load a resized subset of CIFAR-10 through Keras.

    Downloads (or reuses the cached copy of) CIFAR-10 via
    ``keras.datasets.cifar10.load_data()``. Keeps the first
    ``nb_train_samples`` training images and the first ``nb_valid_samples``
    test images, resizes each one to ``img_rows`` x ``img_cols`` with OpenCV,
    and one-hot encodes the labels into ``num_classes`` columns.

    Args:
        img_rows: target image height in pixels.
        img_cols: target image width in pixels.

    Returns:
        ``(x_train, y_train, x_valid, y_valid)``. The image arrays have shape
        ``(n, img_rows, img_cols, 3)`` and the label arrays have shape
        ``(n, num_classes)``.
    """
    (x_train, y_train), (x_valid, y_valid) = cifar10.load_data()
    print(x_train.shape, y_train.shape, x_valid.shape, y_valid.shape)
    # (50000, 32, 32, 3) (50000, 1) (10000, 32, 32, 3) (10000, 1)

    # cv2.resize expects dsize as (width, height), i.e. (cols, rows). Passing
    # (rows, cols) silently transposes the output shape whenever rows != cols.
    dsize = (img_cols, img_rows)
    x_train = np.array([cv2.resize(img, dsize) for img in x_train[:nb_train_samples]])
    x_valid = np.array([cv2.resize(img, dsize) for img in x_valid[:nb_valid_samples]])
    print(x_train.shape, x_valid.shape)
    # e.g. for (224, 224): (3000, 224, 224, 3) (100, 224, 224, 3)

    y_train = np_utils.to_categorical(y_train[:nb_train_samples], num_classes)
    y_valid = np_utils.to_categorical(y_valid[:nb_valid_samples], num_classes)
    print(y_train.shape, y_valid.shape)
    # (3000, 10) (100, 10)

    return x_train, y_train, x_valid, y_valid


if __name__ == '__main__':
    # Smoke run: fetch CIFAR-10 and resize the kept subset to 224x224.
    load_cifar10_data(img_rows=224, img_cols=224)

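上述代码通过 Keras 下载 CIFAR-10 数据集：`cifar10.load_data()` 首次调用时会自动下载数据，并缓存到默认目录（通常是 `~/.keras/datasets/`），以后再调用会直接读取缓存。加载后，代码截取前 3000 张训练图像和前 100 张验证图像，缩放到指定尺寸（示例中为 224×224），并把标签转换为 one-hot 编码。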

import pickle
import numpy as np
import os


def load_file(filename):
    """Load one pickled CIFAR-10 batch file as channels-last arrays.

    Args:
        filename: path to a single batch file such as ``data_batch_1`` or
            ``test_batch``.

    Returns:
        ``(x, y)``. ``x`` is a float64 array of shape ``(n, 32, 32, 3)``
        (channels last, the layout TensorFlow models take as input) and
        ``y`` is a label array of shape ``(n,)``, where ``n`` is the number
        of rows stored in the batch.
    """
    # NOTE(security): pickle.load can execute arbitrary code. Only call this
    # on batch files taken from the official CIFAR-10 archive.
    with open(filename, 'rb') as f:
        # encoding='latin1' lets Python 3 unpickle batches written by Python 2.
        data = pickle.load(f, encoding='latin1')
    x = data['data']
    y = data['labels']
    # Each row holds 3 colour planes of 32x32 pixels, channels first.
    # Reshape to (n, 3, 32, 32), then move the channel axis last.
    # -1 (instead of a hard-coded 10000) keeps this working for any batch size.
    x = x.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype('float')
    y = np.array(y)
    return x, y


def load_cifar10(ROOT):
    """Load the full CIFAR-10 dataset from a directory of pickled batches.

    Reads ``data_batch_1`` .. ``data_batch_5`` (in that order) and joins them
    into one training set, then reads ``test_batch`` as the test set.

    Args:
        ROOT: directory containing the batch files.

    Returns:
        ``(x_train, y_train, x_test, y_test)`` as produced by ``load_file``,
        with the five training batches concatenated along the first axis.
    """
    batch_paths = [os.path.join(ROOT, 'data_batch_%d' % (idx,)) for idx in range(1, 6)]
    batches = [load_file(path) for path in batch_paths]
    x_train = np.concatenate([images for images, _ in batches])
    y_train = np.concatenate([labels for _, labels in batches])
    x_test, y_test = load_file(os.path.join(ROOT, 'test_batch'))
    return x_train, y_train, x_test, y_test


if __name__ == '__main__':
    # The loaders above read the *python* release of CIFAR-10
    # (cifar-10-python.tar.gz, which extracts to cifar-10-batches-py/).
    # The "-bin" release stores raw binary records that pickle cannot read,
    # and load_file() expects a single batch file, not a directory.
    # Adjust the path to wherever the archive was extracted.
    x_train, y_train, x_test, y_test = load_cifar10('../cifar-10-batches-py')
    print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
    # (50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,)

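上述代码读取 Python 版本（pickle 格式，解压后为 `cifar-10-batches-py` 目录）的 CIFAR-10 数据集。它把 5 个训练批次文件（`data_batch_1` ~ `data_batch_5`）合并成一个训练集，并把 `test_batch` 作为测试集。注意：二进制版本（`cifar-10-batches-bin`）的文件不是 pickle 格式，不能用这段代码读取。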

猜你喜欢

转载自blog.csdn.net/u012193416/article/details/87267668