KNN算法实战



"""
@Time    : 2021/1/4 20:37
@File    : KNN.py
"""

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier


def knn_iris():
    """
    用knn对鸢尾花进行分类
    :return:
    """
    # 1、获取数据
    iris = load_iris()

    # 2、划分数据集
    x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=6)

    # 3、特征工程:标准化
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)

    # 4、knn算法的预估器
    estmator = KNeighborsClassifier(n_neighbors=2)
    estmator.fit(x_train, y_train)

    # 5、模型的评估
    # 方法一:直接对比真实值和预测值
    y_predict = estmator.predict(x_test)
    print("y_predict:\n", y_predict)
    print("直接对比预测值和真实值:\n", y_predict == y_test)
    # 方法二:计算准确率
    score = estmator.score(x_test, y_test)
    print("准确率为:\n", score)

    return None


if __name__ == "__main__":
    knn_iris()

猜你喜欢

转载自blog.csdn.net/weixin_44010756/article/details/112203271