学习sklearn朴素贝叶斯

不同的贝叶斯假设数据的分布不同。
暂时全部使用默认参数

高斯朴素贝叶斯

"""
多项式朴素贝叶斯分类器适用于具有离散特征的分类(例如,用于文本分类的字数)。
多项分布通常需要整数特征计数。然而,在实践中,诸如tf-idf的分数计数也可以起作用。
"""
from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import GaussianNB
gnb=GaussianNB()
gnb.fit(iris.data,iris.target)
y_pred = gnb.predict(iris.data)
print('number of mislabeled points out of a total %d points : %d'%(iris.data.shape[0],(iris.target != y_pred).sum()))
gnb.score(iris.data,iris.target)
gnb.get_params()
number of mislabeled points out of a total 150 points : 6
{'priors': None}

多项式朴素贝叶斯

"""
多项式朴素贝叶斯分类器适用于具有离散特征的分类(例如,用于文本分类的字数)。多项分布通常需要整数特征计数。然而,在实践中,诸如tf-idf的分数计数也可以起作用。
"""
from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import MultinomialNB
gnb=MultinomialNB()
gnb.fit(iris.data,iris.target)
y_pred = gnb.predict(iris.data)
print('number of mislabeled points out of a total %d points : %d'%(iris.data.shape[0],(iris.target != y_pred).sum()))
gnb.score(iris.data,iris.target)
gnb.get_params()
number of mislabeled points out of a total 150 points : 7
{'alpha': 1.0, 'class_prior': None, 'fit_prior': True}

补充朴素贝叶斯

"""
#ComplementNB实现补充朴素贝叶斯(CNB)算法。CNB是标准多项式朴素贝叶斯(MNB)算法的改编,其特别适用于不平衡数据集。具体而言,CNB使用来自每个类的补集的统计来计算模型的权重。CNB的发明人凭经验证明,
#CNB的参数估计比MNB的参数估计更稳定。此外,CNB在文本分类任务上的表现通常优于MNB(通常相当大)

不知道为什么,这个算法运行不了。从sklearn库里删除了?
"""
from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import ComplementNB
gnb=ComplementNB()
gnb.fit(iris.data,iris.target)
y_pred = gnb.predict(iris.data)
print('number of mislabeled points out of a total %d points :%d'%(iris.data.shape[0],(iris.target != y_pred).sum()))
get.score(iris.data,iris.target)

伯努利朴素贝叶斯

"""
用于多变量伯努利模型的朴素贝叶斯分类器。
与MultinomialNB一样,该分类器适用于离散数据。不同之处在于,虽然MultinomialNB可以处理出现次数,但BernoulliNB设计用于二进制/布尔特征。
"""
from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import BernoulliNB
gnb=BernoulliNB()
gnb.fit(iris.data,iris.target)
y_pred = gnb.predict(iris.data)
print('number of mislabeled points out of a total %d points : %d'%(iris.data.shape[0],(iris.target != y_pred).sum()))
gnb.score(iris.data,iris.target)
#gnb.get_params
number of mislabeled points out of a total 150 points : 100
0.3333333333333333

猜你喜欢

转载自blog.csdn.net/qq_41205464/article/details/84445067
今日推荐