贝叶斯推断及其互联网应用: 已知推断未知概率

已知推断未知概率, 也叫贝叶斯分类

先上问题吧,我们统计了14天的气象数据(指标包括outlook,temperature,humidity,windy),并已知这些天气是否打球(play)。如果给出新一天的气象指标数据:sunny,cool,high,TRUE,判断一下会不会去打球。

table 1

outlook temperature humidity windy play
sunny hot high FALSE no
sunny hot high TRUE no
overcast hot high FALSE yes
rainy mild high FALSE yes
rainy cool normal FALSE yes
rainy cool normal TRUE no
overcast cool normal TRUE yes
sunny mild high FALSE no
sunny cool normal FALSE yes
rainy mild normal FALSE yes
sunny mild normal TRUE yes
overcast mild high TRUE yes
overcast hot normal FALSE yes
rainy mild high TRUE no

这个问题可以用决策树的方法来求解,当然我们今天讲的是朴素贝叶斯法。这个一”打球“还是“不打球”是个两类分类问题,实际上朴素贝叶斯可以没有任何改变地解决多类分类问题。决策树也一样,它们都是有导师的分类方法。

朴素贝叶斯模型有两个假设:所有变量对分类均是有用的,即输出依赖于所有的属性;这些变量是相互独立的,即不相关的。之所以称为“朴素”,就是因为这些假设从未被证实过。

注意上面每项属性(或称指标)的取值都是离散的,称为“标称变量”。

step1.对每项指标分别统计:在不同的取值下打球和不打球的次数。

table 2

outlook temperature humidity windy play
yes no yes no yes no yes no yes no
sunny 2 3 hot 2 2 high 3 4 FALSE 6 2 9 5
overcast 4 0 mild 4 2 normal 6 1 TRUR 3 3
rainy 3 2 cool 3 1

step2.分别计算在给定“证据”下打球和不打球的概率。

这里我们的“证据”就是sunny,cool,high,TRUE,记为E,E1=sunny,E2=cool,E3=high,E4=TRUE。

A、B相互独立时,由:


得贝叶斯定理:


得:




又因为4个指标是相互独立的,所以


我们只需要比较P(yes|E)和P(no|E)的大小,就可以决定打不打球了。所以分母P(E)实际上是不需要计算的。

P(yes|E)*P(E)=2/9×3/9×3/9×3/9×9/14=0.0053

P(no|E)*P(E)=3/5×1/5×4/5×3/5×5/14=0.0206

所以不打球的概率更大。

零频问题

注意table 2中有一个数据为0,这意味着在outlook为overcast的情况下,不打球和概率为0,即只要为overcast就一定打球,这违背了朴素贝叶斯的基本假设:输出依赖于所有的属性。

数据平滑的方法很多,最简单最古老的是拉普拉斯估计(Laplace estimator)--即为table2中的每个计数都加1。它的一种演变是每个计数都u(0<u<1)。

Good-Turing是平滑算法中的佼佼者,有兴趣的可以了解下。我在作基于隐马尔可夫的词性标注时发现Good-Turing的效果非常不错。
对于任何发生r次的事件,都假设它发生了r*次:



nr是历史数据中发生了r次的事件的个数。

数值属性

当属性的取值为连续的变量时,称这种属性为“数值属性“。通常我们假设数值属性的取值服从正态分布。

outlook temperature humidity windy play
yes no yes no yes no yes no yes no
sunny 2 3 83 85 86 85 FALSE 6 2 9 5
overcast 4 0 70 80 96 90 TRUR 3 3
rainy 3 2 68 65 80 70
64 72 65 95
69 71 70 91
75 80
75 70
72 90
81 75
sunny 2/9 3/5 mean value 73 74.6 mean value 79.1 86.2 FALSE 6/9 2/5 9/15 5/14
overcast 4/9 0/5 deviation 6.2 7.9 deviation 10.2 9.7 TRUR 3/9 3/5

正态分布的概率密度函数为:


现在已知天气为:outlook=overcast,temperature=66,humidity=90,windy=TRUE。问是否打球?

f(温度=66|yes)=0.0340

f(湿度=90|yes)=0.0221

yes的似然=2/9×0.0340×0.0221×3/9×9/14=0.000036

no的似然=3/5×0.0291×0.0380×3/5×9/14=0.000136

不打球的概率更大一些。

用于文本分类

朴素贝叶斯分类是一种基于概率的有导师分类器。

词条集合W,文档集合D,类别集合C。

根据(1)式(去掉分母)得文档d属于类别cj的概率为:


p(cj)表示类别j出现的概率,让属于类别j的文档数量除以总文档数量即可。

而已知类别cj的情况下词条wt出现的后验概率为:类别cj中包含wt的文档数目  除以 类别cj中包含的文档总数目 。

结束语

实践已多次证明,朴素贝叶斯在许多数据集上不逊于甚至优于一些更复杂的分类方法。这里的原则是:优先尝试简单的方法。

机器学习的研究者尝试用更复杂的学习模型来得到良好的结果,许多年后发现简单的方法仍可取得同样甚至更好的结果。

实现代码:
Classifier.java
/**
 * 
 * 描述: 算法接口.
 * @author 
 *
 */
public interface Classifier {

    /**
     * 处理模型数据.
     * @param lable 标签名称.
     * @param value 标签值.
     * @param cnt 数量(该条数据的数量)
     * @param target 目标名称.
     * @param targetValue 目标值.
     */
    void train(String[] lable, String[] value, int cnt, String target, String targetValue);

    /**
     * 先验概率计算出其后验概率.
     * @param features 属性值.
     * @return 后验概率较大的数值.
     */
    String predict(String[] features);
}


NaiveBayes.java

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * 
 * 描述: 朴树贝叶斯算法.
 * @author 
 *
 */
public class NaiveBayes implements Classifier {
    private static final Logger LOG = LoggerFactory.getLogger(NaiveBayes.class);
    // 小数点后6位
    private static final int AFTER_POINT = 6;
    // 所有标签名称
    private String[] labelName;
    // 目标名称
    private String targetName;
    // 目标标签
    private String[] targetLabelName;
    // 标签列表
    private List<LabelBo> lstLabelBo;
    // 目标值列表
    private Set<String> setTargetVal;
    
    /**
     * 构造函数.
     */
    public NaiveBayes() {
        this.lstLabelBo = new ArrayList<LabelBo>();
        this.setTargetVal = new TreeSet<String>();
    }

    public void setLabelName(String[] labelName) {
        this.labelName = labelName;
    }
    
    public void setTargetLabelName(String[] targetLabelName) {
        this.targetLabelName = targetLabelName;
    }

    public void setTargetName(String targetName) {
        this.targetName = targetName;
    }
    
    public String[] getLabelName() {
        return labelName;
    }
    
    public List<LabelBo> getLstLabelBo() {
        return lstLabelBo;
    }
    
    public String[] getTargetLabelName() {
        return targetLabelName;
    }
    
    public String getTargetName() {
        return targetName;
    }
    
    public Set<String> getSetTargetVal() {
        return setTargetVal;
    }
    
    /**
     * 读文件.
     * @param path 路径.
     */
    public void readFile(String path) {
        if (null == this.labelName || null == this.targetName) {
            return;
        }
        
        try {
            BufferedReader reader = new BufferedReader(new FileReader(path));
            String line;
            boolean isTrue = false;
            while ((line = reader.readLine()) != null) {
                if ("@data".equals(line)) {
                    isTrue = true;
                    continue;
                }
                
                if (!isTrue) {
                    continue;
                }
                
                String[] atts = line.split(",");
                this.train(this.labelName, atts, 1,  this.targetName, atts[atts.length - 1]);
            }
            reader.close();
        } catch (FileNotFoundException ex) {
            LOG.error("Read naivebayes mode data failed, not found file, " + ex.getMessage());
        } catch (IOException ex) {
            LOG.error("Read naivebayes mode data failed, IO exception, " + ex.getMessage());
        }
    }
    
    @Override
    public void train(String[] lable, String[] value, int cnt, String target, String targetValue) {
        for (int i = 0; i < lable.length; i++) {
            LabelBo labelBo = null;
            for (LabelBo lb : this.lstLabelBo) {
                if (lable[i].equals(lb.getLableName()) && value[i].equals(lb.getItemName())) {
                    labelBo = lb;
                    break;
                }
            }

            if (null == labelBo) {
                labelBo = new LabelBo();
                labelBo.setLableName(lable[i]);
                labelBo.setItemName(value[i]);
                this.lstLabelBo.add(labelBo);
            }
            int index = labelBo.addName(targetValue);
            labelBo.addCount(index, cnt);
        }
    }

    /**
     * 计算比例.
     */
    public void rate() {
        Map<String, Integer> mapTotal = new HashMap<String, Integer>();
        for (LabelBo lb : this.lstLabelBo) {
            for (int i = 0; i < lb.getLstCount().size(); i++) {
                String tmp = lb.getLableName() + "." + lb.getLstName().get(i);
                if (mapTotal.containsKey(tmp)) {
                    mapTotal.put(tmp, mapTotal.get(tmp) + lb.getLstCount().get(i));
                } else {
                    mapTotal.put(tmp, lb.getLstCount().get(i));
                }
            }
        }
        for (LabelBo lb : this.lstLabelBo) {
            List<Integer> lst = lb.getLstTotal();
            for (int i = 0; i < lb.getLstName().size(); i++) {
                String tmp = lb.getLableName() + "." + lb.getLstName().get(i);
                lst.add(mapTotal.get(tmp));
            }
        }

        // 目标计算
        List<LabelBo> lstTmpLabelBo = new ArrayList<LabelBo>();
        for (LabelBo lb : this.lstLabelBo) {
            if (this.targetName.equalsIgnoreCase(lb.getLableName())) {
                lstTmpLabelBo.add(lb);
            }
        }

        int total = 0;
        for (LabelBo labelBo : lstTmpLabelBo) {
            if (null != labelBo) {
                for (int i = 0; i < labelBo.getLstCount().size(); i++) {
                    total += labelBo.getLstCount().get(i);
                    this.setTargetVal.add(labelBo.getLstName().get(i));
                }
            }
        }

        for (LabelBo labelBo : lstTmpLabelBo) {
            for (int i = 0; i < labelBo.getLstName().size(); i++) {
                labelBo.getLstTotal().set(i, total);
            }
        }
    }
    
    @Override
    public String predict(String[] features) {
        String score = "";
        double rate = 0;
        Set<String> lstTv = this.getSetTargetVal();
//        double total = 0;
        for (String v : lstTv) {
            String result = this.doPredict(this.targetLabelName, features, this.targetName, v);
            if (rate < Double.valueOf(result)) {
                rate = Double.valueOf(result);
                score =  v;
//                total += Double.valueOf(result);
            }
//            System.out.println(result + ":" +  v);
        }
        return score + ":" + rate;
    }

    /**
     * 计算后验概率.
     * @param lable 标签名称
     * @param features 标签值
     * @param target 目标名称
     * @param targetValue 目标值
     * @return 结果.
     */
    private String doPredict(String[] lable, String[] features, String target, String targetValue) {
        int pre = 1;
        int dev = 1;
        for (int i = 0; i < lable.length; i++) {
            LabelBo labelBo = null;
            for (LabelBo lb : this.lstLabelBo) {
                if (lable[i].equalsIgnoreCase(lb.getLableName()) && features[i].equalsIgnoreCase(lb.getItemName())) {
                    labelBo = lb;
                    break;
                }
            }
            
            if (null == labelBo) {
                continue;
            }

            List<String> lstName = labelBo.getLstName();
            for (String str : lstName) {
                if (targetValue.equals(str)) {
                    pre *= labelBo.getLstCount().get(lstName.indexOf(str));
                    dev *= labelBo.getLstTotal().get(lstName.indexOf(str));
                }
            }
        }

        LabelBo labelBo = null;
        for (LabelBo lb : this.lstLabelBo) {
            if (target.equalsIgnoreCase(lb.getLableName()) && targetValue.equalsIgnoreCase(lb.getItemName())) {
                labelBo = lb;
                break;
            }
        }

        if (null != labelBo) {
            List<String> lstName = labelBo.getLstName();
            for (String str : lstName) {
                if (targetValue.equals(str)) {
                    pre *= labelBo.getLstCount().get(lstName.indexOf(str));
                    dev *= labelBo.getLstTotal().get(lstName.indexOf(str));
                }
            }
        }

        BigDecimal result = new BigDecimal(pre).divide(new BigDecimal(dev), AFTER_POINT, BigDecimal.ROUND_HALF_UP);
        return result.toString();
    }
    
    /**
     * 重置.
     */
    public void reset() {
        this.lstLabelBo.clear();
        this.setTargetVal.clear();
    }
    
    /**
     * 打印数据.
     */
    public void print() {
        for (LabelBo key : this.lstLabelBo) {
            System.out.println(key.getLableName() + "=======>" + key.getItemName());
            List<String> lstName = key.getLstName();
            List<Integer> lstCount = key.getLstCount();
            List<Integer> lstTotal = key.getLstTotal();
            for (int i = 0; i < lstName.size(); i++) {
                System.out.println(lstName.get(i) + ":" + lstCount.get(i) + "/" + lstTotal.get(i));
            }
        }
    }
}


LabelBo.java

import java.util.ArrayList;
import java.util.List;


/**
 * 
 * 描述: 标签对象.
 * @author 
 *
 */
public class LabelBo {
    private String lableName;
    private String itemName;
    // 目标项对应的值.
    private List<String> lstName;
    private List<Integer> lstCount;
    private List<Integer> lstTotal;
    
    /**
     * 构造方法.
     */
    public LabelBo() {
        this.lstCount = new ArrayList<Integer>();
        this.lstName = new ArrayList<String>();
        this.lstTotal = new ArrayList<Integer>();
    }
    
    public void setLableName(String lableName) {
        this.lableName = lableName;
    }
    
    public String getLableName() {
        return lableName;
    }
    
    public void setItemName(String itemName) {
        this.itemName = itemName;
    }
    
    public String getItemName() {
        return itemName;
    }
    
    public List<String> getLstName() {
        return lstName;
    }
    
    public List<Integer> getLstCount() {
        return lstCount;
    }
    
    public List<Integer> getLstTotal() {
        return lstTotal;
    }
    
    /**
     * 添加标签对应的种类名称.
     * @param name 名称.
     * @return 下标.
     */
    public int addName(String name) {
        if (!this.lstName.contains(name)) {
            this.lstName.add(name);
        }
        return this.lstName.indexOf(name);
    }

    /**
     * 添加标签对应的种类名称的数量.
     * @param index 下标.
     * @param count 数量.
     */
    public void addCount(int index, Integer count) {
        if (this.lstCount.size() - 1 < index) {
            this.lstCount.add(count);
            return;
        }
        int temp = this.lstCount.get(index) + count;
        this.lstCount.set(index, temp);
    }

    public void setLstRate(List<Integer> lstTotal) {
        this.lstTotal = lstTotal;
    }
}

猜你喜欢

转载自horsemen.iteye.com/blog/2243699