第二章 感知机代码实现

对于代码实现过程中的一点思考:

  • 如何将公式转换为代码?  直接通过公式转换。

  • 感知机的前提:数据为线性可分,即数据在高维空间中,像可以被一刀切成两部分。一要为线性可分。这种限定条件太严格了,如果数据不为线性可分的话,肯定会出现准确率问题。二是只能二分类,现实中二分类问题太少,但是可以通过多个二分类问题达到多分类问题。
  • 其次数据集采用的是Mnist手写数据集,但是也没用它来作为多分类器,只是简单写了一个二分类。
  • 思路简单,但是只是作为神经网络和支持向量机的基础来看待。

对于感知机的一些思考:

  • 手写的二分类实现过程太慢了,应该是数据量大吧,Mnist60000个样本,一个样本为784维的向量。然后进行1000次的迭代。确实计算量有点大。
  • 其次,我查看了我的lost变化情况,在迭代25次后就基本不变了,见下图。

  •  为什么我的lost更新的幅度这么小,其次我的lost在后面的迭代中基本不变,如何解决?

代码如下所示:

import matplotlib.pyplot as plt
import numpy as np
import time


def read_data(path):
    """
    读取数据
    :param path:路径
    :return: train data,train label
    """
    data_array = []
    label_array = []
    fr = open(path, 'r')
    for line in fr.readlines():
        # 以逗号分开数据
        curline = line.strip().split(',')
        # 由于感知机只为二分类,故的对数据进行处理,让其只有二种
        if int(curline[0]) > 5:
            label_array.append(1)
        else:
            label_array.append(-1)
        # 添加数据
        data_array.append([int(num) / 255 for num in curline[1:]])
    return data_array, label_array


def perceptron(x, y, n, theta):
    """
    实现感知机算法
    :param x: 训练的数据
    :param y: 训练的label
    :param n: 迭代次数
    :return: w,训练好的模型,loss loss的变化
    """
    # 这样产生的是二维数组,必须降维
    w = np.random.randn(1, len(x[0]) + 1)
    w = w[0]
    lost = []
    # 在x的最后面加一列1
    x = np.insert(arr=x, obj=len(x[0]), values=1, axis=1)
    for i in range(n):
        tmp = 0
        for i in range(len(x)):
            if y[i] * np.dot(w, x[i].T) <= 0:
                tmp += -y[i] * np.dot(w, x[i].T)
                w += theta * y[i] * x[i]
        lost.append(tmp)
    return w, lost

def acc(x,y,w):
    amount=0.0
    all=float(len(x))
    x=np.insert(arr=x,obj=len(x[0]),values=1,axis=1)
    for i in range(len(x)):
        if y[i]*np.dot(w,x[i].T)>0:
            amount+=1.0
    return amount/all

if __name__ == '__main__':
    train_x, train_y = read_data("data//Mnist//mnist_train.csv")
    test_x, test_y = read_data("data//Mnist//mnist_test.csv")
    w, lost = perceptron(train_x, train_y, n=1000, theta=0.05)
    plt.plot(list(range(1000)), lost)
    plt.show()
    print(acc(test_x,test_y,w))
View Code

猜你喜欢

转载自www.cnblogs.com/fghfghfgh666/p/10914752.html