画图之前首先要设置画布(figure)对象,使得后面的图形输出在这块规定了大小的画布上,其中参数figsize设置画布大小。
# 得到画布对象
plt.figure(figsize=(width, height)) # unit is inch(英寸)
# 绘制子图,其中index是从1开始计算
plt.subplot(nrows, ncols, index, **kwargs) # 将画布分为nrows*ncols个子区域, index表示第N个子区域
设置坐标轴的起始和终止值
plt.xlim(0, 30) # x in [0, 30]
plt.ylim(0, 100) # y in [0, 100]
# 显示单张黑白图片
def show_single_image(img_arr):
plt.imshow(img_arr, cmap="binary")
plt.show()
show_single_image(x_train[0])
# 显示多张黑白图片
def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
assert len(x_data) == len(y_data)
assert n_rows * n_cols <= len(x_data)
plt.figure(figsize = (n_cols * 1.4, n_rows * 1.6))
for row in range(n_rows):
for col in range(n_cols):
index = n_cols * row + col
plt.subplot(n_rows, n_cols, index+1)
plt.imshow(x_data[index], cmap="binary",
interpolation = 'nearest')
plt.axis('off')
plt.title(class_names[y_data[index]])
plt.show()
class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress',
'Coat', 'Sandal', 'Shirt', 'Sneaker',
'Bag', 'Ankle boot']
show_imgs(3, 5, x_train, y_train, class_names)