matplotlib.pyplot subplots、plot、xlabel等

一、plt.subplots(nrows, ncols, ...)

import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))

上述代码创建了一个有1行3列axes的figure,figure的大小为(12,6),figure的名字为'train'。如下图所示。此时plt指向最右边的ax(因为是最后创建的)。

 上述代码等价于:(和上面一样,此时plt指向最右边的ax)。

import matplotlib.pyplot as plt
plt.figure("train", (12, 6))
plt.subplot(1,3,1)
plt.subplot(1,3,2)
plt.subplot(1,3,3)

 二、plt当前所指的fig/ax永远是最新创建的fig/ax,在调用plt.xxx函数时,要注意操作的对象是哪一个fig的哪个ax。(但plt.show会显示所有figure)

import matplotlib.pyplot as plt
import numpy as np

np.random.seed(0)
epochs = 4
epoch_loss_values = np.random.randint(5, size=epochs)

fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
axes[0].plot(x, y)  # ax也有plot方法
axes[0].set_xlabel('aaa')  # ax有set_xlabel方法,没有xlabel方法
plt.xlabel("epoch")
plt.title("Epoch Average Loss")

结果如下:

三、一个fig中新创建的ax可能会覆盖旧的ax

import matplotlib.pyplot as plt
import numpy as np

np.random.seed(0)
epochs = 4
epoch_loss_values = np.random.randint(5, size=epochs)

fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
axes[0].plot(x, y)
axes[0].set_xlabel('aaa')
plt.subplot(1,2,2)
plt.xlabel("epoch")
plt.title("Epoch Average Loss")

结果如下:

猜你喜欢

转载自blog.csdn.net/qq_41021141/article/details/125973412