# NLP之命名实体识别：HMM（2）

#!/usr/bin/python
# -*- coding: utf-8 -*-
# @Time    : 2018/7/26 13:54
# @verion  : python3.6
# @File    : OrgRecognize.py
# @Software: PyCharm
__author__ = 'xiaohu'


class OrgRecognize:
    '''
    Organization-name recognizer built on a role-tagging HMM.

    The HMM parameters (initial vector, transition matrix, emission matrix) and
    the organization tag patterns are read from comma-separated text files in
    ``data_dir``. Probabilities are kept as the strings read from those files
    and converted with ``float()`` at the point of use.
    '''

    def __init__(self, hidden_states, data_dir='./data'):
        '''
        :param hidden_states: ordered list of hidden-state (role tag) names; the
            order is also the tie-break order when probabilities are equal
        :param data_dir: directory containing transition_probability.txt,
            initial_vector.txt, emit_probability.txt and nt.pattern.txt
        '''
        self.hidden_states = hidden_states
        self.data_dir = data_dir

    def _data_path(self, filename):
        '''Return the path of ``filename`` inside the model data directory.'''
        import os  # local import: this module has no top-level import block
        # getattr fallback keeps subclasses that skip our __init__ working.
        return os.path.join(getattr(self, 'data_dir', './data'), filename)

    def _read_rows(self, filename):
        '''
        Read a comma-separated data file.

        NOTE(review): the files are assumed to be UTF-8 text -- confirm.

        :param filename: file name inside the data directory
        :return: list of field lists, one per non-blank line (blank lines are skipped)
        '''
        with open(self._data_path(filename), mode='r', encoding='utf-8') as file:
            return [line.split(',') for line in (raw.strip() for raw in file) if line]

    def load_transition_probability(self, hidden_states=None):
        '''
        Load the state-transition probability matrix.

        Each line of transition_probability.txt is ``from_state,to_state,probability``.

        :param hidden_states: states pre-created as (possibly empty) rows;
            defaults to ``self.hidden_states``
        :return: dict ``{from_state: {to_state: probability_string}}`` -- the
            probability that ``to_state`` follows ``from_state``
        '''
        if hidden_states is None:
            hidden_states = self.hidden_states
        result = {x: {} for x in hidden_states}
        for row in self._read_rows('transition_probability.txt'):
            # setdefault: a state not listed in hidden_states no longer raises KeyError
            result.setdefault(row[0], {})[row[1]] = row[2]
        return result

    def load_initial_vector(self):
        '''
        Load the initial probability vector pi.

        Each line of initial_vector.txt has at least three fields; the first is
        the state and the third is its probability (the second field is not used
        here -- its meaning is not visible from this code).

        :return: dict ``{state: probability_string}``
        '''
        return {row[0]: row[2] for row in self._read_rows('initial_vector.txt')}

    def load_emit_probability(self, hidden_states=None):
        '''
        Load the emission (observation) probability matrix.

        Each line of emit_probability.txt is ``hidden_state,word,probability``.

        :param hidden_states: states pre-created as (possibly empty) rows;
            defaults to ``self.hidden_states``
        :return: dict ``{hidden_state: {word: probability_string}}`` -- the
            probability that ``hidden_state`` emits ``word``
        '''
        if hidden_states is None:
            hidden_states = self.hidden_states
        result = {x: {} for x in hidden_states}
        for row in self._read_rows('emit_probability.txt'):
            result.setdefault(row[0], {})[row[1]] = row[2]
        return result

    def viterbi(self, observation, transition_probability, initial_vector, emit_probability, hidden_states=None):
        '''
        Find the most probable hidden-state sequence with the Viterbi algorithm.

        All probability values may be strings or numbers; they are converted
        with ``float()``. Missing entries are treated as probability 0.

        NOTE: probabilities are multiplied directly (not in log space), so very
        long observations may underflow to 0.0.

        :param observation: coarse segmentation result (sequence of words)
        :param transition_probability: ``{from_state: {to_state: prob}}``
        :param initial_vector: ``{state: prob}``
        :param emit_probability: ``{state: {word: prob}}``
        :param hidden_states: candidate states in tie-break order; defaults to
            ``self.hidden_states``
        :return: list with one hidden state per word ([] for an empty observation)
        :raises ValueError: if there are no hidden states
        '''
        if hidden_states is None:
            hidden_states = self.hidden_states
        hidden_states = list(hidden_states)
        if not observation:
            return []
        if not hidden_states:
            raise ValueError('hidden_states must not be empty')

        def transition(prev_state, next_state):
            return float(transition_probability.get(prev_state, {}).get(next_state, 0))

        compute_record = []  # compute_record[t][s]: best probability of a path ending in s at t
        backpointers = []  # backpointers[t - 1][s]: best predecessor of s at position t

        # Initialization: delta_0(s) = pi(s) * b_s(o_0)
        first_word = observation[0]
        delta_result = {}
        for state in hidden_states:
            emit = emit_probability.get(state, {})
            if first_word in emit:
                delta_result[state] = float(initial_vector.get(state, 0)) * float(emit[first_word])
            else:
                delta_result[state] = 0
        compute_record.append(delta_result)

        # Recursion: delta_t(s) = max_x(delta_{t-1}(x) * a_xs) * b_s(o_t), remembering argmax x
        for word in observation[1:]:
            previous = compute_record[-1]
            delta_result = {}
            pointer = {}
            for state in hidden_states:
                scores = {x: previous[x] * transition(x, state) for x in hidden_states}
                best_prev = max(hidden_states, key=scores.get)
                pointer[state] = best_prev
                emit = emit_probability.get(state, {})
                if word in emit:
                    delta_result[state] = scores[best_prev] * float(emit[word])
                else:
                    delta_result[state] = 0
            compute_record.append(delta_result)
            backpointers.append(pointer)
        print(compute_record)

        # Termination: pick the best final state, then follow backpointers to the start.
        state = max(hidden_states, key=compute_record[-1].get)
        tag_sequence = [state]
        for pointer in reversed(backpointers):
            state = pointer[state]
            tag_sequence.append(state)
        tag_sequence.reverse()
        print(tag_sequence)
        return tag_sequence

    def get_organization(self, sequence, patterns, observation):
        '''
        Recognize organization names by matching tag patterns.

        Every non-overlapping occurrence of every pattern is reported. A span
        matched by more than one pattern is reported once, and empty patterns
        are ignored.

        :param sequence: most probable tag sequence (one tag per word)
        :param patterns: tag pattern strings
        :param observation: word sequence aligned with ``sequence``
        :return: list of organization names (concatenated words), in pattern order
        '''
        # Assumes every tag is a single character (true for the demo's hidden
        # states), so string offsets in the joined tags equal word offsets.
        str_sequence = ''.join(sequence)
        organization_indices = []  # [begin, end) word spans of organizations
        seen = set()
        for pattern in patterns:
            if not pattern:
                continue  # '' would "match" at offset 0 and yield an empty name
            begin_index = str_sequence.find(pattern)
            while begin_index != -1:
                end_index = begin_index + len(pattern)
                if (begin_index, end_index) not in seen:
                    seen.add((begin_index, end_index))
                    organization_indices.append([begin_index, end_index])
                begin_index = str_sequence.find(pattern, end_index)
        print(organization_indices)
        return [''.join(observation[begin:end]) for begin, end in organization_indices]

    def load_pattern(self):
        '''
        Load the organization tag patterns from nt.pattern.txt.

        :return: list of pattern strings, one per non-blank line (whitespace stripped)
        '''
        with open(self._data_path('nt.pattern.txt'), mode='r', encoding='utf-8') as file:
            # strip before testing, so blank lines do not become '' patterns
            return [pattern for pattern in (line.strip() for line in file) if pattern]


def main():
    '''Run the organization recognizer on a fixed demo sentence and print the result.'''
    # The first and last tokens ("始##始" / "末##末") appear to be sentence
    # begin/end markers expected by the model -- confirm against the data files.
    observation = ["始##始", "中海油", "集团", "在", "哪里", "末##末"]
    hidden_states = ["A", "B", "C", "D", "F", "G", "I", "J", "K", "L", "M", "P", "S", "W", "X", "Z"]
    org_recognize = OrgRecognize(hidden_states=hidden_states)
    transition_probability = org_recognize.load_transition_probability(hidden_states=hidden_states)
    initial_vector = org_recognize.load_initial_vector()
    emit_probability = org_recognize.load_emit_probability(hidden_states=hidden_states)
    sequence = org_recognize.viterbi(observation=observation, transition_probability=transition_probability,
                                     initial_vector=initial_vector, emit_probability=emit_probability,
                                     hidden_states=hidden_states)

    patterns = org_recognize.load_pattern()
    orgs = org_recognize.get_organization(sequence=sequence, patterns=patterns, observation=observation)
    if orgs:
        for org in orgs:
            print(org)
    else:
        print('未识别到机构名')
        print(sequence)


if __name__ == '__main__':
    main()

# 转载自 blog.csdn.net/qq_18617299/article/details/81236308