SVM算法的一点点理解 SVM算法的一点点理解

SVM算法的一点点理解

最近几天看了下SVM算法,下面是我的个人理解。 
SVM(支持向量机)是为了找一个超平面,使得几何间隔最大化,为了方便讨论,下面只考虑二维的情况,并且数据线性可分且数据的种类只有两种(即二分类问题),这样问题就是找一个直线使得这条直线能够把正负两类点分割,并且使得所有数据的最小几何间隔最大。 
这里写图片描述 
这里假设数据为(x1,y1),(x2,y2)…….(xn,yn) 
可以把数据表示为坐标轴上的点。 
我们定义函数间隔为yi*(wxi+b) 
几何间隔为yi*(wxi+b)/||w|| 
我们定义最小的函数间隔为T1 = min(所有数据的函数间隔) 
最小的几何间隔为T2 = min(所有数据的几何间隔) 
那么SVM算法求的就是 
max T2 
st yi*(wxi+b)/||w||>=T2 
可以写成 
max T1/||w|| 
st yi*(wxi+b)/||w||>=T1/||w|| 
这里我们注意到函数间隔这个约束是可以任意改变的,比如缩小或者放大N倍,这里可以把T1置为1 
则 max 1/||w|| 
st yi*(wxi+b)>=1

即 min 1/2||w||^2 
st yi*(wxi+b)>=1 
运用拉格朗日乘子法 
L(a,w,b) = 1/2||w||^2-sigma(ai(yi*(wxi+b)-1)) 
对w和b求偏导数并代入,得到对偶形式(原文此处为图片,公式如下) 
max sigma(ai)-1/2*sigma_i(sigma_j(ai*aj*yi*yj*(xi·xj))) 
st sigma(ai*yi) = 0,ai>=0

下面我们来讨论下数据线性不可分的情况。 
当数据在二维线性不可分的情况下映射到三维,或者更高维可能就线性可分了。 
下面的两种核函数 
这里写图片描述

还有一种处理线性不可分的方法-软间隔 
我们允许分类器出现误差。 
即 min 1/2||w||^2+C*sigma(mi) 
st yi*(wxi+b)>=1-mi,mi>=0 
可以理解为给分类器一个犯错的允许,但是每个误差都是要付出代价的

同样运用拉格朗日可以得到对偶形式 
这里写图片描述 
其中 
这里写图片描述

这个对偶问题是一个凸二次规划问题,可以用通用的二次规划方法求解,但当样本数量较大时,直接求解的计算和内存开销都很大

但是微软研究院的一位大神(John Platt)提出的SMO算法可以高效地迭代逼近最优解

这里简单讲讲SMO算法

我们选取a1,a2作为要更新的对象,其余的ai固定不变。那么根据限制条件sigma(ai*yi) = 0和0<=ai<=C,可以得到a1 = (s-a2*y2)*y1(其中s = -sigma(ai*yi),i≠1,2),并且得到a2的取值范围[L,H],代入原式就是关于a2的一元二次问题。幸运的是原函数是个凸函数,可以直接求极值,再把结果限制在取值范围内即可。

这里a1,a2的选取可以运用启发式搜索,这里就不说了(下面的代码为了简单采用随机选取)。

这里是代码。
#include "stdafx.h"
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <iostream>
#include <algorithm>
#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <conio.h>
#include <math.h>
#include <dirent.h> 
#include <vector>
using namespace cv;
using namespace std;
// Problem dimensions and training constants.
// X_LEN is parenthesised so that it expands safely inside larger expressions
// such as `n / X_LEN` or `n % X_LEN`.
#define  X_LEN (32*30)  // number of features stored per sample (row length of x and of w)
#define  S_LEN 64       // number of training samples
#define  TOT 64         // NOTE(review): not referenced anywhere in the visible code
#define  C 1            // upper bound applied to each multiplier a[i] in svm()
#define  EPS 1e-4       // tolerance used by eq() and check()
#define W 3             // width constant of the Gaussian kernel in kernel()
#define MAXN 80         // NOTE(review): only referenced from commented-out code
// Training set: x[s][i] is feature i of sample s (filled from image pixels in init());
// y[s] is the label of sample s (init() assigns 1 or -1). x has one row more than S_LEN.
int x[S_LEN+1][X_LEN],y[S_LEN];
// a[s]: multiplier of sample s, updated by svm(); miss[s]: cal_miss(s) as cached by svm();
// b: bias term added by cal_res().
double a[S_LEN],miss[S_LEN],b;
// w[i]: weight vector filled by cal_w(); res: not referenced in the visible code.
double w[X_LEN],res[S_LEN];
// Image file paths collected by readdir(): ve[0] is loaded with label 1, ve[1] with label -1.
vector<string> ve[2];
// Returns x squared.
double pow2(double x)
{
    const double squared = x*x;
    return squared;
}
// Approximate zero test: returns 1 when |a| is strictly below EPS, otherwise 0.
// Note that the parameter shadows the global multiplier array of the same name.
int eq(double a)
{
    return (fabs(a)<EPS) ? 1 : 0;
}
// Gaussian (RBF) kernel between training samples s1 and s2:
//     K(s1,s2) = exp(-||x[s1]-x[s2]||^2 / (2*W*W))
// s1 and s2 are row indices into the global feature matrix x; they are not range-checked.
// A linear kernel (the plain dot product of the two rows) was the author's earlier,
// disabled alternative.
double kernel(int s1,int s2)
{
    double sq_dist = 0;
    for(int i = 0;i<X_LEN;i++)
    {
        sq_dist+=pow2(x[s1][i]-x[s2][i]);
    }
    // The denominator must be grouped explicitly: `-sum/2*W*W` parses as
    // (-sum/2)*W*W, which multiplies by W^2 instead of dividing by 2*W^2.
    return exp(-sq_dist/(2.0*W*W));
}

// Recomputes the global weight vector from the current multipliers:
//     w[f] = sum over samples s of a[s]*y[s]*x[s][f]
void cal_w()
{
    for(int feat = 0;feat<X_LEN;++feat)
    {
        double acc = 0;
        int sample = 0;
        while(sample<S_LEN)
        {
            acc+=a[sample]*y[sample]*x[sample][feat];
            ++sample;
        }
        w[feat] = acc;
    }
}

// Decision value of training sample s under the current model:
//     sum over samples t of a[t]*y[t]*kernel(t,s), plus the bias b.
double cal_res(int s)
{
    double acc = 0;
    int t = 0;
    while(t<S_LEN)
    {
        acc+=a[t]*y[t]*kernel(t,s);
        ++t;
    }
    const double with_bias = acc+b;
    return with_bias;
}
// Error of sample s: its current decision value minus its label y[s].
double cal_miss(int s)
{
    const double predicted = cal_res(s);
    return predicted-y[s];
}

// Selects the first sample index of the pair to optimise.
// Currently a pseudo-random index in [0, S_LEN) is returned. The author's earlier
// selection heuristic (it appears to scan for samples that violate the optimality
// conditions) is disabled and kept below for reference only.
int choose1()
{
    /* Disabled heuristic:
    for(int i=0;i<S_LEN;i++)
    {
        if(a[i]>0&&a[i]<C&&!eq(y[i]*cal_res(i)-1)) return i;
    }
    for(int i=0;i<S_LEN;i++)
    {
        if(eq(a[i])&&y[i]*cal_res(i)-1<1-EPS) return i;
        if(eq(a[i]-C)&&y[i]*cal_res(i)-1>1-EPS) return i;
    }
    */
    const int picked = rand()%S_LEN;
    return picked;
}
// Selects the second sample index of the pair: a pseudo-random index in [0, S_LEN)
// that differs from index1 (redrawn until it does).
// The author's earlier heuristic (pick the sample whose cal_miss differs most from
// that of index1) is disabled and kept below for reference only.
int choose2(int index1)
{
    /* Disabled heuristic:
    int index2 = -1;
    double maxnum = -1;
    double m  = cal_miss(index1);
    for(int i = 0;i<S_LEN;i++)
    {
        if(i==index1) continue;
        if(fabs(m-cal_miss(i))>maxnum)
            maxnum = fabs(m-cal_miss(i)),index2 = i;
    }
    return index2;
    */
    int candidate;
    do
    {
        candidate = rand()%S_LEN;
    } while(candidate==index1);
    return candidate;
}

// Performs one pairwise update step of the SMO-style solver.
// Two sample indices are drawn with choose1()/choose2(); a[index2] is re-optimised
// inside its feasible interval [L,H], a[index1] is then derived from the equality
// constraint, and finally the bias b is recomputed. Reads and writes the globals
// a, miss and b; takes no arguments and returns nothing.
void svm()
{
    int index1 = choose1();
    int index2 = choose2(index1);


    // NOTE(review): j and k are declared but never used in this function.
    int i,j,k;
    //printf("123ssdfsfsdf1231\n");
    double sum = 0;
    double L,H;
    // sum = -(sum of a[i]*y[i] over every sample except the chosen pair).
    // The assignment to a[index1] near the end keeps
    // a[index1]*y[index1] + a[index2]*y[index2] equal to this value (for labels of +1/-1).
    for(i=0;i<S_LEN;i++)
    {
        if(i!=index1&&i!=index2)
        {
            sum+=a[i]*y[i];
        }
    }
    sum*=-1;
//  printf("%lf\n",sum);
    // Interval [L,H] allowed for the new a[index2], derived from the bound C on both
    // multipliers and the current values of the pair.
    if(y[index1]==y[index2])
    {
        // Same label: both inner branches evaluate the same expressions.
        if(y[index1]==-1)
        {
            H = min(C,(a[index1]+a[index2]));
            L = max(0,(a[index1]+a[index2])-C);
        }
        else
        {
            H = min(C,a[index1]+a[index2]);
            L = max(0,a[index1]+a[index2]-C);
        }
    }
    else
    {
        // Different labels: reorder the pair so that index2 is not the sample labelled -1.
        if(y[index2]==-1)
        {
            swap(index1,index2);
        }
        H = min(C,C+a[index2]-a[index1]);
        L = max(0,a[index2]-a[index1]);

    }
    //printf("%lf %lf\n",L,H);

    // Refresh the cached error of every sample; only the entries for index1 and
    // index2 are read below.
    for(i=0;i<S_LEN;i++)
    {
        miss[i] = cal_miss(i);
    }
    // temp = K(1,1) + K(2,2) - 2*K(1,2) for the chosen pair (commonly called eta in
    // SMO descriptions); temp_a is the candidate new value of a[index2] before any
    // restriction to [L,H].
    // NOTE(review): if temp is 0 the division presumably yields inf or NaN under IEEE
    // arithmetic; such a value fails the range test and the else branch is taken.
    double temp = kernel(index1,index1)+kernel(index2,index2)-2*kernel(index1,index2);
    double temp_a = a[index2]+y[index2]*(miss[index1]-miss[index2])/temp;
    //printf("%lf %lf %lf %d %d %lf\n",a[index2],miss[index1],miss[index2],index1,index2,temp_a);
    if(temp_a>=L&&temp_a<=H) a[index2] = temp_a;
    else
    {

        // Candidate lies outside [L,H]: evaluate the objective at both end points and
        // keep the one with the smaller value.
        // v1/v2: decision values at index1/index2 with the bias and the contributions
        // of the chosen pair removed (i.e. the part owed to all other samples).
        double v1 = cal_res(index1)-b-a[index1]*y[index1]*kernel(index1,index1)-a[index2]*y[index2]*kernel(index1,index2);
        double v2 = cal_res(index2)-b-a[index1]*y[index1]*kernel(index2,index1)-a[index2]*y[index2]*kernel(index2,index2);
        // s1: objective with a[index2] = L; (sum-L*y[index2]) stands in for
        // a[index1]*y[index1] as implied by the equality constraint.
        double s1 = 0.5*kernel(index1,index1)*pow2(sum-L*y[index2])+0.5*kernel(index2,index2)*pow2(L)
                  +y[index2]*kernel(index1,index2)*(sum-L*y[index2])*L-(sum-L*y[index2])*y[index1]-
                  L+v1*(sum-L*y[index2])+y[index2]*v2*L;
        // s2: the same expression with a[index2] = H.
        double s2 = 0.5*kernel(index1,index1)*pow2(sum-H*y[index2])+0.5*kernel(index2,index2)*pow2(H)
                  +y[index2]*kernel(index1,index2)*(sum-H*y[index2])*H-(sum-H*y[index2])*y[index1]-
                  H+v1*(sum-H*y[index2])+y[index2]*v2*H;
        if(s1<s2) a[index2] = L;
        else a[index2] = H;
    }
    // Derive a[index1] from the equality constraint using the new a[index2].
    a[index1] = (sum-y[index2]*a[index2])*y[index1];
    //printf("%d %d\n",index1,index2);
    // Recompute the bias. cal_res() already uses the updated multipliers plus the old
    // b, so "y - cal_res + b" is the bias that makes the decision value at that sample
    // equal to its label. A sample whose multiplier lies strictly inside (0,C) is
    // preferred; otherwise the two candidate biases are averaged.
    if(a[index1]>0&&a[index1]<C)
    {
        //printf("11111111111\n");
        b = y[index1]-cal_res(index1)+b;
    }
    else if(a[index2]>0&&a[index2]<C)
    {
    //  printf("222222222222\n");
        b = y[index2]-cal_res(index2)+b;
    }
    else
    {
    //  printf("33333333333333333\n");
        b  =  (y[index1]-cal_res(index1)+2*b+y[index2]-cal_res(index2))/2;
    }
  // printf("%lf\n",a[index2]);
}

// Evaluates the objective being minimised,
//     0.5 * sum_i sum_j a[i]*a[j]*y[i]*y[j]*kernel(i,j)  -  sum_i a[i],
// and compares it with the previous value passed in `pre`.
// Returns 1 when the objective dropped by more than EPS since `pre`, otherwise 0.
// In both cases `pre` is overwritten with the freshly computed objective.
int check(double &pre)
{
    double objective = 0;
    for(int row = 0;row<S_LEN;++row)
    {
        for(int col = 0;col<S_LEN;++col)
        {
            objective+=a[row]*a[col]*y[row]*y[col]*kernel(row,col);
        }
    }
    objective*=0.5;
    for(int row = 0;row<S_LEN;++row)
    {
        objective-=a[row];
    }

    // Decide before overwriting pre, since the comparison needs the old value.
    const bool improved = (pre-objective>EPS);
    pre = objective;
    return improved ? 1 : 0;
}
// Appends to `out` one path per entry of directory `dir_path`, skipping every entry
// whose name starts with '.'. Each path is built as dir_path + "//" + entry name.
// If the directory cannot be opened a message is written to stderr and `out` is left
// unchanged. The directory handle is always closed before returning.
static void collect_image_paths(const char *dir_path,std::vector<std::string> &out)
{
    DIR *dir = opendir(dir_path);
    if(dir==NULL)
    {
        std::cerr<<"readdir: cannot open directory "<<dir_path<<std::endl;
        return;
    }
    struct dirent *entry;
    while((entry = readdir(dir))!=NULL)
    {
        if(entry->d_name[0]=='.') continue;   // skips ".", ".." and hidden entries
        std::string path = dir_path;
        path+="//";
        path+=entry->d_name;
        out.push_back(path);
    }
    closedir(dir);
}

// Collects the training image paths of both classes into the global ve:
// ve[0] receives the files of the "an2i" directory, ve[1] those of "at33".
// Paths are appended, so calling this more than once duplicates the entries.
// Unlike the previous version, both directory handles are closed and a directory
// that cannot be opened is reported instead of being dereferenced.
void readdir()
{
    collect_image_paths("d://ccut//faces_4//an2i",ve[0]);
    collect_image_paths("d://ccut//faces_4//at33",ve[1]);
}
void init()
{
    //printf("%d %d\n",ve[0].size(),ve[1].size());
    //printf("1111111\n");
    readdir();
    //printf("1111111");
    b  = 0;
//  printf("%d %d\n",ve[0].size(),ve[1].size());
    for(int i = 0;i<S_LEN;i++)
    {
        a[i] = 0;
    }
    Mat mat;
    int cnt = 0;
    for(int i = 0;i<ve[0].size();i++)
    {
         y[i] = 1;
        mat=  imread(ve[0][i]);
        cnt = 0;
        for(int j = 0;j<mat.rows;j++)
        {
            for(int k = 0;k<mat.cols;k++)
            {
                x[i][cnt++] = (int)(mat.at<uchar>(j,k));
            }
        }
    }
    for(int i = 0;i<ve[1].size();i++)
    {
        cnt = 0;
        y[32+i] = -1;
        mat=  imread(ve[1][i]);
        for(int j = 0;j<mat.rows;j++)
        {
            for(int k = 0;k<mat.cols;k++)
            {
                x[32+i][cnt++] = (int)(mat.at<uchar>(j,k));
            }
        }
    }
    //printf("%d\n",cnt);
    /*mat[0] = imread("D:\\ccut\\faces_4\\an2i\\an2i_left_angry_open_4.pgm");
    mat[1] = imread("D:\\ccut\\faces_4\\an2i\\an2i_right_angry_open_4.pgm");
    mat[2] = imread("D:\\ccut\\faces_4\\an2i\\an2i_left_angry_sunglasses_4.pgm");
    y[0] = 1;
    y[1] = -1;
    y[2] = 1;
    for(int i = 0;i<S_LEN;i++)
    {
        int cnt = 0;
        for(int j = 0;j<mat[i].rows;j++)
        {
            for(int k = 0;k<mat[i].cols;k++)
            {
                x[i][cnt++] = (int)(mat[i].at<uchar>(i,j));
            }
        }
    }
    /*
    //printf("%d %d\n",mat1.rows,mat1.cols);
    x[0][0] = 1,x[0][1] = 1,y[0] = 1;
    x[1][0] = 1,x[1][1] = 2,y[1] = -1;
    x[2][0] = 1,x[2][1] = 3,y[2] = 1;
    x[3][0] = 1,x[3][1] = 4,y[3] = -1;
    */
}

// Program entry point: seeds the random generator, loads the training data, then
// repeats svm() update steps until more than 200 steps in total (not necessarily
// consecutive) failed to lower the objective by more than EPS according to check().
// The running count of such steps is printed after every step, and finally the
// decision value of each of the 64 training samples is printed.
int _tmain(int argc, _TCHAR* argv[])
{
    srand(time(0));
    init();

    double pre = 0;   // objective value seen by the previous check() call
    for(int flag = 0;flag<=200;)
    {
        svm();
        cal_w();
        if(!check(pre))
        {
            ++flag;
        }
        printf("%d\n",flag);
    }

    int sample = 0;
    while(sample<64)
    {
        printf("%lf\n",cal_res(sample));
        ++sample;
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_40909394/article/details/80188185
今日推荐