DeepChem教程24: 模型可解释性介绍

前面的几节,你已经学习了如何用DeepChem 训练模型解决不同的问题。但是我们还没有真正的学习模型的可解释性问题。

建模时我们通常要问一些问题模型工作得好不好?我们为什么要相信模型?我作为一个数据科学家的回答是,因为我们有明显的证据证实模型对于手头的测试集是切合实际的“。但通常这不足于说服领域专家。

LIME 是一个能帮助你解决这一问题的工具。它用局部的特征空间扰动来确定特征的重要性。本教程,你将学习如何使用LIMEDeepChem来解释我们的模型学习到了什么。

如果这个工具以人类可理解的方式来处理图像那它可以处理分子吗?本教程我们将学习如何用LIME解释固定长度的特征化模型。

创建模型

我们加载ECFP特征化的Tox21数据集。回顾一下特征化是如何工作的。它识别分子中的小片断,然后设置输出向为1以表示分子中某个片断的存在。

In [1]:

import deepchem as dc

n_features = 1024

tasks, datasets, transformers = dc.molnet.load_tox21(featurization='ecfp')

train_dataset, valid_dataset, test_dataset = datasets

我们现在用这个数据集来训练模型。如前面的教程,我们用MultitaskClassifier,它是多个全链接的简单堆叠。

In [2]:

n_tasks = len(tasks)

n_features = train_dataset.get_data_shape()[0]

model = dc.models.MultitaskClassifier(n_tasks, n_features)

model.fit(train_dataset, nb_epoch=50)

Out[2]:

0.1333492088317871

我们用训练集和测试集评估模型以理解它的准确度。我们用ROC-AUC作为量度。

In [3]:

import numpy as np

metric = dc.metrics.Metric(dc.metrics.roc_auc_score, np.mean)

print("Train scores")

print(model.evaluate(train_dataset, [metric], transformers))

print("Validation scores")

print(model.evaluate(valid_dataset, [metric], transformers))

Train scores

{'mean-roc_auc_score': 0.9911206354520975}

Validation scores

{'mean-roc_auc_score': 0.699686047497269}

使用LIME

模型看来对于预测那个分子有毒性很合理,但是到底如何工作?当它预没分子有毒或无毒时,是分子的哪个方面导致它那样预测的呢?这就是explainability的关键:学习为什以输入会导致一个预测。

LIME是解决这个问题的工具。它是"Local Interpretable Model-Agnostic Explanations"的缩写。它可以对固定大小输入向量解决任何问题。它为不同的特征计算概率分布以及特征间的协方差。这允许它在样本的近邻构建局部线性模型,描述输入的哪个局部对输出影响最大。即增加或去除这个片断将改变有毒还是无毒的预测。

首先我们需要安装LIME。幸运的是,LIME可以方便的从pip获得。所以你可以安装它。

In [ ]:

!pip install lime

现在我们已安装LIME,我们想要为LIME创建Explainer对象。这个对象会输入训练集并为特征命名。我们使用圆形指纹作为我们的特征。我们不想要给我们的特征取自然名称,所以我们给它们数字。另一方面,我们有自然名字作为标签。记得Tox21是用于毒性测量的,所以我们设0 为无毒1为有毒。

In [4]:

from lime import lime_tabular

feature_names = ["fp_%s"  % x for x in range(1024)]

explainer = lime_tabular.LimeTabularExplainer(train_dataset.X,

                                              feature_names=feature_names,

                                              categorical_features=feature_names,

                                              class_names=['not toxic', 'toxic'],

                                              discretize_continuous=True)

 

 

我们试图解释为什么模型为NR-AR分子预测为有毒。具体的测定细节见这里。 here.

In [5]:

# We need a function which takes a 2d numpy array (samples, features) and returns predictions (samples,)

def eval_model(my_model):

    def eval_closure(x):

        ds = dc.data.NumpyDataset(x, n_tasks=12)

        # The 0th task is NR-AR

        predictions = my_model.predict(ds)[:,0]

        return predictions

    return eval_closure

model_fn = eval_model(model)

对于具体的分子我们试图用这个评估函数。我们拿测试集中被准确地预测为有毒的分子来看一下。(即是分子是有毒而且模型也预测它有毒)。

In [6]:

from rdkit import Chem

active_id = np.where((test_dataset.y[:,0] == 1) * (model.predict(test_dataset)[:,0,1] > 0.8))[0][0]

Chem.MolFromSmiles(test_dataset.ids[active_id])

Out[6]:

现在我们有一个模型和一个分子。我们让Explainer指出为什么分子被预测为有毒。我们让它列出100个对预测最敏感觉的特征。(即,指纹中的元素,每个元素与一个或多个片断对应)。

In [7]:

exp = explainer.explain_instance(test_dataset.X[active_id], model_fn, num_features=100, top_labels=1)

返回的值是Explanation对象。你可以调用它的方法来追踪返回值到不同的形式。处理相互作用的一个便利的形式是show_in_notebook(),提供了图形化的表示。

In [8]:

exp.show_in_notebook(show_table=True, show_all=False)

这个输出需要解释。左边表示分子预测为有毒。这个我们已知道的。这就是为什么我们选它。右边它列出了100个对预测最有影响的指纹元素。对于每一个,值一列提示相应的片断是否存在(1.00表示存在,0.00表示不存在)于分子。中间表示每个索引对预测为有毒(蓝色)还是无毒(橙色)作贡献。

大部分片断都不存在。它告诉我们,如果片断存在,将会移位预测。我们对此不太感兴趣。我们想要知道分子中存在的片断对预测的贡献。让我们将些结果变为更有用的形式。

开始,指纹中的索引并不很有用。我们写一些函数来转换特征,映射片断到激活它的索引。

In [9]:

def fp_mol(mol, fp_length=1024):

    """

    returns: dict of <int:list of string>

        dictionary mapping fingerprint index

        to list of SMILES strings that activated that fingerprint

    """

    d = {}

    feat = dc.feat.CircularFingerprint(sparse=True, smiles=True, size=1024)

    retval = feat._featurize(mol)

    for k, v in retval.items():

        index = k % fp_length

        if index not in d:

            d[index] = set()

        d[index].add(v['smiles'])

    return d

 

# What fragments activated what fingerprints in our active molecule?

my_fragments = fp_mol(Chem.MolFromSmiles(test_dataset.ids[active_id]))

我们想要查询 Explanation看哪些片断对于预测有贡献。我们用as_map()方法来获得更适用于处理的信息。

In [10]:

print(exp.as_map())

{1: [(907, -0.23405879109145938), (261, -0.22799151209374974), (257, -0.2127115416006204), (411, -0.2032938542566075), (445, -0.201101199543193), (999, -0.19683277633182114), (505, -0.17598335551311955), (845, -0.16124562050855068), (306, -0.15779431345857292), (326, -0.15729134284912463), (742, -0.15426792127439848), (774, -0.1541352665863784), (648, -0.15240513095212335), (282, -0.15075378351457727), (918, -0.147036129283227), (531, -0.1458139691488669), (279, -0.14390785978173085), (269, -0.13989282701617642), (37, -0.1369273010593831), (530, -0.13566064574462358), (827, -0.1336099559901393), (28, -0.12819498508086055), (889, -0.12482816439354927), (84, 0.123345144700625), (712, -0.12260023102545663), (529, 0.12194683881762106), (513, -0.12144767300189488), (830, -0.11926958219652685), (111, -0.11793890523628446), (434, -0.11598961154307276), (247, 0.11346755135862246), (296, -0.11315257272809631), (394, -0.11054396729966792), (1022, -0.10845154388715085), (850, 0.10819488336102767), (92, -0.10725270764168865), (788, -0.10693252879326674), (565, -0.10619572780884631), (901, -0.10597769712341058), (854, -0.10261187607809283), (632, -0.10165075780263935), (381, -0.10083233541195123), (717, -0.10024949898626785), (431, 0.09886188592649868), (1003, -0.09854835359816157), (646, -0.09821601927382569), (312, 0.09718167861314402), (539, -0.09639497333637208), (693, -0.0960269720546286), (822, 0.09584637471513191), (1005, -0.09441597700147854), (584, -0.09422611177476213), (405, 0.09371804599009508), (594, -0.09361942073302025), (519, 0.09315063287813262), (613, -0.0920180464426831), (151, -0.09125548867464624), (995, 0.09122957856534511), (555, 0.09105473925802852), (619, 0.09045652379413677), (372, 0.09008810465661844), (617, 0.08854326235599133), (517, 0.0876472124639829), (409, -0.08722349514303968), (744, 0.08646480736070905), (470, -0.0861786962874964), (930, 0.08444082349628013), (493, 0.08389172822676175), (429, -0.08368146493327351), (135, 0.08346782897055312), (27, -0.08332078333604556), (923, 0.0827630767476166), (977, -0.0803740639477386), (174, -0.07985778475171695), (204, 0.07748814547746291), (459, 0.07722411480464215), (377, 0.07544148127726504), (274, -0.07528620889379731), (665, -0.07517229225403155), (321, 0.07387303741377259), (733, 0.07313092778231371), (538, -0.07260889806354165), (760, -0.07216344039899467), (751, -0.07086876393200622), (523, 0.07067337687463775), (467, 0.06911819793695931), (172, 0.06779708514374157), (131, 0.06747370559195916), (732, 0.06727167331565105), (344, 0.06123528076874165), (155, -0.06080053839396983), (384, 0.05715795555565539), (614, 0.053775300985781746), (900, -0.050104647498526), (52, 0.047430623988994364), (460, 0.045920906298128734), (800, 0.04169171427687711), (316, 0.0391090952059404), (388, 0.0334174485302755), (752, -0.01827203770682766)]}

这个图的关键是标签,我们只有一个。这个值是元组的列表,形式为(fingerprint_index, weight)。我们将它转换到字典,映射索引到权重。

In [11]:

fragment_weight = dict(exp.as_map()[1])

我们知道那个片断存在于我们感兴趣的分子中(my_fragments),我们也知道那个片断对预测有贡献。我们循环它们并输出它们。

In [12]:

for index in my_fragments:

    if index in fragment_weight:

        print(index, my_fragments[index], fragment_weight[index])

555 {'C[C@](C)(C)CCC'} 0.09105473925802852

84 {'C=CC'} 0.123345144700625

519 {'C[C@@H](C)C'} 0.09315063287813262

274 {'C[C@@H](C)C(C=C)[C@@H](C)C'} -0.07528620889379731

529 {'CCC[C@H](C)C'} 0.12194683881762106

下载全文请到www.data-vision.net,技术联系电话13712566524

 

猜你喜欢

转载自blog.csdn.net/lishaoan77/article/details/114376718