# sklearn计算混淆矩阵 (computing a confusion matrix / classification report with sklearn)

# coding: utf-8
import sklearn.metrics as sm
import pandas as pd
def error_analysis(src_file, pred_file, tgt):
    '''Dump every sentence whose predicted labels differ from the gold labels.

    Both inputs are UTF-8 text files with one token per line: each non-blank
    line is split on whitespace, field 0 is the token and field 1 its label.
    A line with no fields (blank, whitespace-only, or only a BOM) ends a
    sentence; the final sentence is processed even when the file does not
    end with such a line. The two files are matched line by line, so
    ``pred_file`` must have the same layout as ``src_file``; only field 1 of
    each prediction line is read.

    For every sentence with at least one mismatching label, its tokens are
    written to ``tgt`` as tab-separated ``token  gold  predicted`` rows,
    followed by a blank line. The number of fully correct (non-empty)
    sentences is printed at the end.

    Args:
        src_file: path of the gold-label file.
        pred_file: path of the prediction file.
        tgt: path of the output file; it is overwritten.

    Raises:
        IndexError: if a token line has fewer than two fields, or if
            ``pred_file`` is shorter than ``src_file`` or lacks a label on a
            line where ``src_file`` has a token.
    '''
    # Read both inputs up front; ``with`` closes the handles even if
    # anything below raises.
    with open(src_file, 'r', encoding='utf-8') as f:
        src_lines = f.readlines()
    with open(pred_file, 'r', encoding='utf-8') as g:
        pred_lines = g.readlines()

    def sentences():
        # Yield (words, gold labels, predicted labels) per non-empty sentence.
        # Empty groups produced by consecutive separator lines are skipped so
        # they are not counted as "correct" sentences.
        words, labels, preds = [], [], []
        for i, line in enumerate(src_lines):
            fields = line.strip('\ufeff\n').split()
            if fields:
                p_fields = pred_lines[i].strip('\ufeff\n').split()
                words.append(fields[0])
                labels.append(fields[1])
                preds.append(p_fields[1])
            elif words:
                yield words, labels, preds
                words, labels, preds = [], [], []
        if words:
            # Last sentence when the file has no trailing separator line.
            yield words, labels, preds

    count = 0
    with open(tgt, 'w', encoding='utf-8') as h:
        for word_list, label_list, p_label_list in sentences():
            if label_list == p_label_list:
                count += 1
                continue
            for word, gold, guess in zip(word_list, label_list, p_label_list):
                # Single-character gold labels (such as 'O') get extra tabs
                # so the predicted column lines up with longer labels.
                if len(gold) == 1:
                    h.write(word + '\t\t' + gold + '\t\t\t' + guess + '\n')
                else:
                    h.write(word + '\t\t' + gold + '\t' + guess + '\n')
            h.write('\n')
    print('正确的有%d句话.' % count)

def get_list(real, pred):
    '''Read gold and predicted labels into two parallel flat lists.

    Both inputs are UTF-8 text files with one token per line; each non-blank
    line is split on whitespace and field 1 is taken as the label. Lines with
    no fields (blank, whitespace-only, or only a BOM) are sentence separators
    and are skipped; their number is printed as ``count=N``. The files are
    matched line by line, so ``pred`` must have the same layout as ``real``.

    Args:
        real: path of the gold-label file.
        pred: path of the prediction file.

    Returns:
        (real_list, pred_list): label strings in file order, one entry per
        token line of ``real``.

    Raises:
        IndexError: if a token line has fewer than two fields, or if ``pred``
            is shorter than ``real`` or lacks a label on a token line.
    '''
    # ``with`` guarantees both handles are closed (the original leaked them).
    with open(real, 'r', encoding='utf-8') as f:
        src_lines = f.readlines()
    with open(pred, 'r', encoding='utf-8') as g:
        pred_lines = g.readlines()

    count = 0  # number of separator lines seen in ``real``
    real_list = []
    pred_list = []
    for i, line in enumerate(src_lines):
        fields = line.strip('\ufeff\n').split()
        if not fields:
            count += 1
            continue
        real_list.append(fields[1])
        p_fields = pred_lines[i].strip('\ufeff\n').split()
        pred_list.append(p_fields[1])
    print('count=%d' % count)

    return real_list, pred_list



if __name__ == '__main__':
    # Gold labels, model predictions, and the destination of the
    # error-analysis dump.
    src_file, pred_file, tgt_file = 'real_test.txt', 'pred_test.txt', 'error_analysis.txt'

    # Label order for the report rows (and for the optional confusion matrix
    # below): 'O' followed by B-/I-/E-/S- prefixed tags grouped by entity
    # type. Not every type uses every prefix.
    label_list = [
        'O',
        'B-PER.NAM', 'I-PER.NAM', 'E-PER.NAM', 'S-PER.NAM',
        'B-PER.NOM', 'I-PER.NOM', 'E-PER.NOM', 'S-PER.NOM',
        'B-GPE.NAM', 'I-GPE.NAM', 'E-GPE.NAM', 'S-GPE.NAM',
        'B-GPE.NOM', 'E-GPE.NOM',
        'B-LOC.NAM', 'I-LOC.NAM', 'E-LOC.NAM',
        'B-LOC.NOM', 'I-LOC.NOM', 'E-LOC.NOM', 'S-LOC.NOM',
        'B-ORG.NAM', 'I-ORG.NAM', 'E-ORG.NAM',
        'B-ORG.NOM', 'I-ORG.NOM', 'E-ORG.NOM',
    ]
    # Matrix row/column index -> label name, derived from label_list so the
    # two can never drift apart.
    label_dict = dict(enumerate(label_list))

    # Optional: write sentences containing wrong predictions to tgt_file.
    # error_analysis(src_file, pred_file, tgt_file)

    real_list, pred_list = get_list(src_file, pred_file)

    # Optional: build the confusion matrix and save it as an Excel sheet.
    # confusion_matrix = sm.confusion_matrix(real_list, pred_list, labels=label_list)
    # df = pd.DataFrame(confusion_matrix)
    # df.rename(index=label_dict, columns=label_dict, inplace=True)
    # df.to_excel('confusion_matrix.xlsx', index=True)

    cls_report = sm.classification_report(real_list, pred_list, labels=label_list)
    print(cls_report)




# 猜你喜欢
#
# 转载自 (reposted from): https://blog.csdn.net/tailonh/article/details/112252474