19、TensorFlow 实现最近邻分类器(K=1)

一、KNN 分类模型的三要素

1、距离度量

这里写图片描述

2、K 值的选择

这里写图片描述

3、分类决策规则

这里写图片描述


二、TF 实现最近邻分类器(K=1)

import tensorflow as tf
import numpy as np
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '6'  # 使用第七块卡
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  


# 导入MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("mnist_data/", one_hot=True)


# 我们对MNIST数据集做一个数量限制,取一小部分来做试验
Xtrain, Ytrain = mnist.train.next_batch(5000)  # 5000 训练样例(nn candidates)
Xtest, Ytest = mnist.test.next_batch(200)      # 200 用于测试
print('Xtrain.shape: ', Xtrain.shape, ', Xtest.shape: ',Xtest.shape)
print('Ytrain.shape: ', Ytrain.shape, ', Ytest.shape: ',Ytest.shape)
# ('Xtrain.shape: ', (5000, 784), ', Xtest.shape: ', (200, 784))
# ('Ytrain.shape: ', (5000, 10), ', Ytest.shape: ', (200, 10))


# 计算图输入占位符(train 使用全部样本,test 逐个样本进行测试)
xtrain = tf.placeholder(tf.float32, shape=[None, 784])
xtest = tf.placeholder(tf.float32, shape=[784])


# 使用 L1 距离进行最近邻计算, 计算 distance 时 xtest 会进行广播操作
distance = tf.reduce_sum(tf.subtract(xtrain, xtest)), axis=1)


# 预测: 获得最小距离的索引,然后根据此索引的类标和正确的类标进行比较
pred = tf.argmin(distance, axis=0)


# 需要多少 GPU 资源让它取多少(在第七块 GPU 上)
gpuConfig = tf.ConfigProto()
gpuConfig.gpu_options.allow_growth = True


# 初始化所有变量
init = tf.global_variables_initializer()

# 最近邻分类器的准确率
accuracy = 0.0

# 启动会话
with tf.Session(config=gpuConfig) as sess:
    sess.run(init)
    Ntest = len(Xtest)  #测试样本的数量

    # 在测试集上进行循环
    for i in range(Ntest):
        # 获取当前测试样本的最近邻
        nn_index = sess.run(pred, feed_dict={xtrain: Xtrain, xtest: Xtest[i, :]}) 

        # 获得最近邻预测标签,然后与真实的类标签比较,由于是 one_hot 编码,所以要用 argmax 将类标取出  
        pred_class_label = np.argmax(Ytrain[nn_index])
        true_class_label = np.argmax(Ytest[i])
        print("Test", i, "Predicted Class Label:", pred_class_label,
              "True Class Label:", true_class_label)

        # 计算准确率
        if pred_class_label == true_class_label:
            accuracy += 1
    print("Done!")
    accuracy /= Ntest
    print("Accuracy:", accuracy)

这里写图片描述


三、参考资料

1、TensorFlow实现最近邻分类器

猜你喜欢

转载自blog.csdn.net/mzpmzk/article/details/78647724
今日推荐