KNN Implementation (CIFAR-10 dataset)

1. Loading the dataset

import pickle

with open('data_batch_2', 'rb') as f:
    # x = pickle.load(f, encoding='bytes')   # also works; dict keys become bytes objects
    x = pickle.load(f, encoding='latin1')
    print(x['data'].shape)

# prints (10000, 3072): 10000 images, each 32*32*3 = 3072 values
  • The CIFAR-10 dataset is serialized with pickle, and the way it is read differs between Python 2 and Python 3; Python 3 is used here. encoding can be either bytes or latin1: the batch files were pickled under Python 2, so Python 3 needs an encoding to decode the old byte strings. With bytes the dictionary keys come back as bytes objects (e.g. b'data'), while latin1 decodes them to ordinary strings.
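
A quick way to see the difference between the two encodings (a minimal sketch; it assumes data_batch_2 sits in the working directory, as above):

import pickle

# encoding='bytes': dictionary keys come back as bytes objects, e.g. b'data'
with open('data_batch_2', 'rb') as f:
    batch_bytes = pickle.load(f, encoding='bytes')
print(batch_bytes[b'data'].shape)   # (10000, 3072)

# encoding='latin1': keys and other text fields are decoded to ordinary strings
with open('data_batch_2', 'rb') as f:
    batch_latin = pickle.load(f, encoding='latin1')
print(batch_latin['data'].shape)    # (10000, 3072)
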
def cifarLoad():
    # uses the unpickle() helper and numpy (np) from the full listing below
    file = 'data_batch_'
    train_data = []
    train_label = []
    val_data = []
    val_label = []
    for i in range(1, 6):
        filename = file + str(i)
        data_batch = unpickle(filename)
        # first 9000 images of each batch go to the training set,
        # the remaining 1000 to the validation set
        train_data.extend(data_batch['data'][0:9000])
        train_label.extend(data_batch['labels'][0:9000])
        val_data.extend(data_batch['data'][9000:, :])
        val_label.extend(data_batch['labels'][9000:])

    return np.array(train_data), np.array(train_label), np.array(val_data), np.array(val_label)
  • Split the data into a training set and a validation set (no cross-validation this time; a version with cross-validation will be tried later, see the sketch at the end of this post). A quick check of the split sizes is shown below.
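
Each of the 5 batches contributes its first 9000 images to the training set and the remaining 1000 to the validation set, so the split should come out to 45000 / 5000 images:

train_data, train_label, val_data, val_label = cifarLoad()
print(train_data.shape, train_label.shape)   # (45000, 3072) (45000,)
print(val_data.shape, val_label.shape)       # (5000, 3072) (5000,)
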
class NearestNeighbor(object):
    def __init__(self):
        self.xtr = None    # training images, set by train()
        self.ytr = None    # training labels, set by train()
        self.dist = None   # distances from the current test sample to all training samples

    def train(self, x, y):
        # KNN has no real training step: it simply memorizes the training data
        self.xtr = x
        self.ytr = y

    def predict(self, test_X, k, distance):
        num_test = test_X.shape[0]
        pre = []
        for i in range(num_test):
            # distances from the i-th test sample to every training sample
            if distance == 'L1':
                self.dist = self.L1Distance(test_X[i])
            if distance == 'L2':
                self.dist = self.L2Distance(test_X[i])
            distArgSort = np.argsort(self.dist)[0:k]   # indices of the k nearest training samples
            classSort = self.ytr[distArgSort]          # their labels
            classCount = np.bincount(classSort)        # votes per class
            label = np.argmax(classCount)              # majority vote
            pre.append(label)
        return np.array(pre)

    def L1Distance(self, x):
        # L1 (Manhattan) distance to every training row
        dist = np.sum(np.abs(self.xtr - x), axis=1)
        return dist

    def L2Distance(self, x):
        # L2 (Euclidean) distance to every training row
        dist = np.sqrt(np.sum(np.square(self.xtr - x), axis=1))
        return dist
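
The voting step inside predict can be traced on a tiny made-up example (the distances and labels below are hypothetical, just to show how argsort, bincount and argmax combine into a majority vote):

import numpy as np

dist = np.array([4.0, 0.5, 2.0, 0.7, 3.0, 0.9])   # distances to 6 training samples (made up)
ytr = np.array([2, 1, 1, 1, 0, 2])                 # their labels (made up)

nearest = np.argsort(dist)[0:3]      # k = 3 -> indices [1, 3, 5]
votes = np.bincount(ytr[nearest])    # labels [1, 1, 2] -> counts [0, 2, 1]
print(np.argmax(votes))              # 1: class 1 wins the vote
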
  • Full KNN code
import pickle
import numpy as np

def unpickle(filename):
    # load one pickled CIFAR-10 batch file into a dict with 'data' and 'labels'
    with open(filename, 'rb') as f:
        cifar = pickle.load(f, encoding='latin1')
    return cifar

def cifarLoad():
    file = 'data_batch_'
    train_data = []
    train_label = []
    val_data = []
    val_label = []
    for i in range(1, 6):
        filename = file + str(i)
        data_batch = unpickle(filename)
        # first 9000 images of each batch go to the training set,
        # the remaining 1000 to the validation set
        train_data.extend(data_batch['data'][0:9000])
        train_label.extend(data_batch['labels'][0:9000])
        val_data.extend(data_batch['data'][9000:, :])
        val_label.extend(data_batch['labels'][9000:])

    return np.array(train_data), np.array(train_label), np.array(val_data), np.array(val_label)

class NearestNeighbor(object):
    def __init__(self):
        self.xtr = None    # training images, set by train()
        self.ytr = None    # training labels, set by train()
        self.dist = None   # distances from the current test sample to all training samples

    def train(self, x, y):
        # KNN has no real training step: it simply memorizes the training data
        self.xtr = x
        self.ytr = y

    def predict(self, test_X, k, distance):
        num_test = test_X.shape[0]
        pre = []
        for i in range(num_test):
            # distances from the i-th test sample to every training sample
            if distance == 'L1':
                self.dist = self.L1Distance(test_X[i])
            if distance == 'L2':
                self.dist = self.L2Distance(test_X[i])
            distArgSort = np.argsort(self.dist)[0:k]   # indices of the k nearest training samples
            classSort = self.ytr[distArgSort]          # their labels
            classCount = np.bincount(classSort)        # votes per class
            label = np.argmax(classCount)              # majority vote
            pre.append(label)
        return np.array(pre)

    def L1Distance(self, x):
        # L1 (Manhattan) distance to every training row
        dist = np.sum(np.abs(self.xtr - x), axis=1)
        return dist

    def L2Distance(self, x):
        # L2 (Euclidean) distance to every training row
        dist = np.sqrt(np.sum(np.square(self.xtr - x), axis=1))
        return dist

def CrossValidation():
    # cross-validation for choosing k is left for later (see the sketch below)
    pass



if __name__ == '__main__':
    train_data, train_label, val_data, val_label = cifarLoad()
    clf = NearestNeighbor()
    clf.train(train_data, train_label)   # train() only stores the data and returns nothing
    pre = clf.predict(val_data, k=20, distance='L2')
    acc = np.mean(pre == val_label)      # fraction of correct predictions on the validation set
    print(acc)
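
CrossValidation() is still just a stub. A minimal sketch of what a k-fold search over the hyperparameter k could look like, reusing the cifarLoad and NearestNeighbor defined above (the fold count and the candidate k values are placeholders chosen for illustration, not from the original post):

def CrossValidation(num_folds=5, k_choices=(1, 3, 5, 10, 20, 50)):
    # merge the earlier train/val split back together, then cut it into folds
    train_data, train_label, val_data, val_label = cifarLoad()
    data = np.concatenate([train_data, val_data])
    labels = np.concatenate([train_label, val_label])
    data_folds = np.array_split(data, num_folds)
    label_folds = np.array_split(labels, num_folds)

    for k in k_choices:
        accs = []
        for i in range(num_folds):
            # fold i is the validation fold, the remaining folds form the training set
            x_val, y_val = data_folds[i], label_folds[i]
            x_tr = np.concatenate(data_folds[:i] + data_folds[i + 1:])
            y_tr = np.concatenate(label_folds[:i] + label_folds[i + 1:])
            clf = NearestNeighbor()
            clf.train(x_tr, y_tr)
            pre = clf.predict(x_val, k=k, distance='L2')
            accs.append(np.mean(pre == y_val))
        print('k = %d, mean accuracy = %f' % (k, np.mean(accs)))

Brute-force KNN over all 50000 CIFAR-10 images is slow, so in practice this sketch would be run on a subsample of the data.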

Reposted from blog.csdn.net/ncc1995/article/details/84836450