Tensorflow实现K近邻分类器

Tensorflow实现K近邻分类器

1、K近邻分类模型基本原理

首先,存在一个样本数据集合,也称作训练样本集,井且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输人没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。 一般来说,我们只选择样本数据集中前k个最相似的数据,这就是K近邻算法中k的出处,通常来说,k不大于20.最后选择k个最相似数据中出现次数最多的分类,作为新数据的分类。

K近邻分类模型的三个基本要素:

(1)距离度量

(2)K值的选择

(3)分类决策的规则

2、使用tensorflow实现k近邻分类器

import numpy as np
import os
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# 导入MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("mnist_data/", one_hot=True)

# 我们对MNIST数据集做一个数量限制,
Xtrain, Ytrain = mnist.train.next_batch(5000) #5000 用于训练(nn candidates)
Xtest, Ytest = mnist.test.next_batch(200) #200 用于测试
print('Xtrain.shape: ', Xtrain.shape, ', Xtest.shape: ',Xtest.shape)
print('Ytrain.shape: ', Ytrain.shape, ', Ytest.shape: ',Ytest.shape)

# 计算图输入占位符
xtrain = tf.placeholder("float", [None, 784])
xtest = tf.placeholder("float", [784])

# 使用L1距离进行最近邻计算
# 计算L1距离
distance = tf.reduce_sum(tf.abs(tf.add(xtrain, tf.negative(xtest))), axis=1)
# 预测: 获得最小距离的索引 (根据最近邻的类标签进行判断)
pred = tf.argmin(distance, 0)
#评估:判断给定的一条测试样本是否预测正确

# 初始化节点
init = tf.global_variables_initializer()

#最近邻分类器的准确率
accuracy = 0.

# 启动会话
with tf.Session() as sess:
    sess.run(init)
    Ntest = len(Xtest)  #测试样本的数量
    # 在测试集上进行循环
    for i in range(Ntest):
        # 获取当前测试样本的最近邻
        nn_index = sess.run(pred, feed_dict={xtrain: Xtrain, xtest: Xtest[i, :]})
        # 获得最近邻预测标签,然后与真实的类标签比较
        pred_class_label = np.argmax(Ytrain[nn_index])
        j = Ytrain[nn_index]
        true_class_label = np.argmax(Ytest[i])
        print("Test", i, "Predicted Class Label:", pred_class_label,
              "True Class Label:", true_class_label)
        # 计算准确率
        if pred_class_label == true_class_label:
            accuracy += 1
    print("Done!")
    accuracy /= Ntest
    print("Accuracy:", accuracy)

有一点要说明,因为不可抗拒的因素,mnist数据集可能会下载失败,所以,需要提前下载下来放到‘mnist_data’文件夹下。

代码中有详细的注释,这里就不再赘述,运行结果如下:
这里写图片描述

猜你喜欢

转载自blog.csdn.net/haoronge9921/article/details/81037669