plt.subplots中的ax = ax.flatten()

版权声明:本文为博主原创文章,转载请附上博文链接! https://blog.csdn.net/weixin_38314865/article/details/84785141

在用plt.subplots画多个子图中,ax = ax.flatten()将ax由n*m的Axes组展平成1*nm的Axes组

以下面的例子说明ax = ax.flatten()的作用:

fig, ax = plt.subplots(nrows=2,ncols=2,sharex='all',sharey='all')
ax = ax.flatten()  

for i in range(4):
    img = image[i].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')  # 区别:可以直接用ax[i]

不使用ax = ax.flatten()

fig, ax = plt.subplots(nrows=2,ncols=2,sharex='all',sharey='all') 

for i in range(4):
    img = image[i].reshape(28, 28)

    axs[0, 0].imshow(img, cmap='Greys', interpolation='nearest')  # 区别:不能直接使用ax[i]

    axs[0, 1].imshow(img, cmap='Greys', interpolation='nearest')

    axs[1, 0].imshow(img, cmap='Greys', interpolation='nearest')

    axs[1, 1].imshow(img, cmap='Greys', interpolation='nearest')

猜你喜欢

转载自blog.csdn.net/weixin_38314865/article/details/84785141