在机器阅读理解的论文中,经常可以看到对“文章-问题”可视化的二维热力图,例如下图。在看实验结果的时候用这种图可以直观的看到attention的效果怎么样。比如下图:
于是从github中找到了一个例子,进行了简单的实验。
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.ticker as ticker
a = torch.randn(4, 2)
b = a.softmax(dim=1)
c = a.softmax(dim=0).transpose(0, 1)
print(a, '\n', b, '\n', c)
d = b.matmul(c)
print(d)
d = d.numpy()
得到numpy的4*4数据。然后用matplotlib可视化。
variables = ['A','B','C','X']
labels = ['ID_0','ID_1','ID_2','ID_3']
df = pd.DataFrame(d, columns=variables, index=labels)
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(df, interpolation='nearest', cmap='hot_r')
fig.colorbar(cax)
tick_spacing = 1
ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
ax.yaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
ax.set_xticklabels([''] + list(df.columns))
ax.set_yticklabels([''] + list(df.index))
plt.show()
得到下图: