Text Similarity Computation: BM25


BM25, like classic TF-IDF, is built from term frequency, inverse document frequency, and a document weight; the difference is the scoring formula itself, which has a few more tunable parameters.
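
For reference, the scoring function that the `sim` method in the code below implements is the standard BM25 form:

$$\mathrm{score}(Q,d)=\sum_{q\in Q}\mathrm{IDF}(q)\cdot\frac{f(q,d)\,(k_1+1)}{f(q,d)+k_1\left(1-b+b\cdot\frac{|d|}{\mathrm{avgdl}}\right)},\qquad \mathrm{IDF}(q)=\ln\frac{N-n(q)+0.5}{n(q)+0.5}$$

where $f(q,d)$ is the count of query term $q$ in document $d$, $|d|$ is the document length, $\mathrm{avgdl}$ the average document length, $N$ the number of documents, and $n(q)$ the number of documents containing $q$; $k_1$ and $b$ are the tuning parameters (1.5 and 0.75 in the code).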

Below is the ** version of the code from an earlier competition; the results were fairly good. If you have spare time, combining BM25 with TextRank gives even better results (I saw this in a paper; those interested can search CNKI for it).

#!/usr/bin/python
# -*- coding: utf-8 -*-

import math
import jieba
import jieba.posseg as pseg                 # jieba word segmentation with POS tagging
import jieba.analyse
import numpy as np
import pandas as pd
import csv
from gensim import corpora,models,similarities      # gensim text-similarity utilities
from gensim.corpora import Dictionary
from gensim.models import TfidfModel,LdaModel
from pandas import DataFrame
from collections import defaultdict
import time

#============================================= Tokenizing the training set ============================
# Read the files, mainly to build the vocabulary

class BM25(object):
    def __init__(self, docs):
        self.D = len(docs)
        self.avgdl = sum([len(doc)+0.0 for doc in docs]) / self.D
        self.docs = docs
        self.f = []   # one dict per document, mapping each word to its count in that document
        self.df = {}  # number of documents that contain each word
        self.idf = {} # IDF value of each word
        self.k1 = 1.5        # commonly set to 1.2; k1 and b are the tuning parameters
        self.b = 0.75
        self.init()

    def init(self):
        for doc in self.docs:
            tmp = {}
            for word in doc:
                tmp[word] = tmp.get(word, 0) + 1  # per-document word counts
            self.f.append(tmp)
            for k in tmp.keys():
                self.df[k] = self.df.get(k, 0) + 1
        for k, v in self.df.items():
            # idf = ln((D - df + 0.5) / (df + 0.5)); goes negative for words in over half the documents
            self.idf[k] = math.log(self.D-v+0.5)-math.log(v+0.5)

    def sim(self, doc, index):
        # BM25 score of the query terms `doc` against the document at position `index`
        score = 0
        for word in doc:
            if word not in self.f[index]:
                continue
            d = len(self.docs[index])  # length of the candidate document
            score += (self.idf[word]*self.f[index][word]*(self.k1+1)
                      / (self.f[index][word]+self.k1*(1-self.b+self.b*d
                                                      / self.avgdl)))
        return score

    def simall(self, doc):
        # Score `doc` against every document in the corpus
        scores = []
        for index in range(self.D):
            score = self.sim(doc, index)
            scores.append(score)
        return scores

def loadPoorEnt(path2 = './stopwords_suxue.tab'):
    # Load the stopword list (one word per line)
    poor_ent=set([])
    with open(path2, 'r') as ropen:
        lines=ropen.readlines()
        for line in lines:
            line=line.replace('\r','').replace('\n','')
            poor_ent.add(line)
    return poor_ent
stop_words=loadPoorEnt()

# Read the data
def extract_data(path,top):
    trainFile = pd.read_csv(path)  # the input CSV (id and title columns)
    trainFile=trainFile[:top]
    id=trainFile[['id']]
    title=trainFile[['title']]
    return id.values.tolist(),title.values.tolist()
    # mblogs = []  # keep the processed news items
    # for card in trainFile:
    #     mblog = card['mblog']
    #     blog = {'mid': mblog['id'],  # news id
    #             'text': clean_text(mblog['text']),  # text
    #             }
    #     mblogs.append(blog)
    # return mblogs

# Deduplicate news items by id ('mid')
def remove_duplication(mblogs):

    mid_set = {mblogs[0]['mid']}
    new_blogs = []
    for blog in mblogs[1:]:
        if blog['mid'] not in mid_set:
            new_blogs.append(blog)
            mid_set.add(blog['mid'])
    return new_blogs


# Word segmentation
def cut(data):
    result=[]
    for line in data:
        line=line[0]
        res=pseg.cut(line.strip())
        words=[]
        for item in res:
            # Keep every word that is not a stopword; an alternative is to also filter
            # by part of speech (item.flag), e.g. keeping only nouns, verbs and numerals.
            if item.word.encode('utf8') not in stop_words:
                words.append(item.word)
        result.append(words)
    return result

# Remove words that appear only once
def delete1time(data_cut):
    frequency = defaultdict(int)
    for text in data_cut:
        for token in text:
            frequency[token] += 1
    texts = [[token for token in text if frequency[token] > 1] for text in data_cut]
    return texts

def cal_time(time):
    if time<60:
        return str(time) + 'secs '
    elif time<60*60:
        return str(time/(60.0))+' mins '
    else:
        return str(time/(60*60.0))+' hours '

if __name__ == "__main__":

    start = time.clock()

    # Write the header of the result file
    with open("sx_result510_bm25.txt", "a") as fr:
        fr.write('source_id')
        fr.write('\t')
        fr.write('target_id')

    news_length = 485687
    trainDataId,docs=extract_data('./train_dataSu.csv',news_length)
    #data_cut=cut(docs)

    print "Training set loaded!"
    map_id={}
    for i,idd in enumerate(trainDataId):
        map_id[idd[0]]=i
    testDataId,test_docs=extract_data('./test_dataSu.csv',50)
    #test_cut=cut(test_docs)
    print "Test set loaded!"

    # Segment the training and test sets
    data_cut= cut(docs)
    trainDataSplit =delete1time(data_cut)
    print "Words that appear only once removed"
    print('Training-set segmentation done')

    testDataSplit = cut(test_docs)
    print('Test-set segmentation done')

    m = len(testDataSplit)
    n = len(data_cut)
    simScores =[]
    s = BM25(trainDataSplit)  # build the BM25 statistics once over the whole training corpus
    for i in range(m):
        print('Scoring test document %d' % i)
        sim_scores = s.simall(testDataSplit[i])
        #simScores.append(sim_scores)
        simNumList = sorted(enumerate(sim_scores), key=lambda item: -item[1])

        with open("sx_result510_bm25.txt", "a") as fr:
            # Write the top-ranked training documents, skipping a candidate whose id equals the test id
            for j in range(21):
                if str(int(testDataId[i][0])) == str(int(simNumList[j][0] + 1)):
                    continue
                fr.write('\n')
                fr.write(str(int(testDataId[i][0])))
                fr.write('\t')
                fr.write(str(int(simNumList[j][0] + 1)))


    elapsed = (time.clock() - start)
    print('Time used: ' + cal_time(elapsed))
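
For quick experimentation independent of the competition files, the BM25 class above can also be driven directly with a handful of pre-tokenized toy documents. A minimal sketch (the documents and query below are made up, and English tokens stand in for jieba output):

# Standalone sketch: assumes the BM25 class defined above is available in the same module or session.
docs = [
    ["text", "similarity", "with", "bm25"],
    ["tfidf", "is", "another", "term", "weighting", "scheme"],
    ["jieba", "is", "used", "for", "chinese", "word", "segmentation"],
]
bm25 = BM25(docs)                      # builds f, df and idf once over the corpus
query = ["bm25", "similarity"]         # an already-tokenized query
scores = bm25.simall(query)            # one BM25 score per document in docs
ranking = sorted(enumerate(scores), key=lambda item: -item[1])
print(ranking)                         # (index, score) pairs, sorted by descending score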



Reposted from blog.csdn.net/weixin_40411446/article/details/80505595