机器学习必备,用matplotlib画2D和3D散点图参数介绍及实例分析

一、简介

在机器学习中,经常需要通过散点图查看原始数据的分布情况,从而对特征和算法的选择进行初步判断。

散点图可以形象展示直角坐标系中两个变量之间的关系。在散点图中 ,每个数据点的位置实际上就是两个变量的值。变量间的任何关系都可以拿散点图来表示。

matplotlib绘图功能模仿MATLAB,非常方便和强大。下面,本文将详细介绍如何使用matplotlib画出好看实用的散点图。

如果你对matplotlib完全不熟悉,可以先花10分钟去我的另一篇博客学习一下基本操作:
10分钟带你从零上手matplotlib数据可视化

需要进一步深入了解的朋友可以查看 matplotlib.pyplot.scatter 官方文档

二、2D散点图参数及实例

1. 常用参数详解

import matplotlib.pyplot as plt
plt.scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, 
			vmin=None, vmax=None, alpha=None, linewidths=None, 
			verts=None, edgecolors=None, *, data=None, **kwargs)
  • x,y传入数组,形如shape(n,) ,表示每个点的横、纵坐标
  • s传入标量或数组,形如shape(n,),表示每个点标记的大小,可选,默认为[‘lines.markersize’] ** 2
  • c传入颜色,或颜色序列,表示点标记的颜色,b=蓝色,g=绿色,r=红色,y=黄色,k=黑色,w=白色,c=蓝绿色,m=洋红
  • marker,传入字符串,表示点标记的样式,默认为’o’,常用的有'^', '*', 'o','+','x',
    更多样式可以查看官方文档
  • linewidths传入标量或数组,形如shape(n,),表示标记的边框线宽,默认为None
  • edgecolors传入颜色或颜色序列,表示标记边框的颜色,默认为’face’,传入‘face’表示与标记颜色相同,传入’none’表示无边框。

2. 最简单的2D散点图实例

import matplotlib.pyplot as plt

x = [1, 2, 3, 4]
y = [1, 2, 3, 4]
plt.scatter(x, y, s=[10, 20, 50, 100], c=['r', 'y', 'g', 'b'])
plt.show()

在这里插入图片描述
从图中可以看出来,的确是s控制了每个点的大小,c控制了每个点的颜色。

3. 机器学习中的2D散点图实例

下面我们先用sklearn中经典的iris分类数据画一个二维散点图

from sklearn import datasets
import matplotlib.pyplot as plt

#从sklearn中获取经典的iris数据
iris = datasets.load_iris() #iris.data为150x4矩阵
x1 = iris.data[:, 1]        #获取第二列特征值
x2 = iris.data[:, 2]        #获取第三列特征值
y = iris.target             #y是分类值:0,1,2

plt.scatter(x1, x2, c=y)    #将y作为参数传给c能够很方便的区分不同类别的颜色
plt.title('Iris Classification')
plt.xlabel('Petal length')
plt.ylabel('Petal width')

plt.show()

在这里插入图片描述

三、3D散点图参数及实例

Matplotlib 绘制 3D 图像主要通过 mpl_toolkits.mplot3d 模块实现,但由于三维图像实际上是在二维画布上展示,因此同样需要载入 pyplot 模块。

备注:mpl_toolkits.mplot3d这个模块不需要另外安装,matplotlib中已自带。

1. 常用参数详解

Axes3D.scatter(xs, ys, zs=0, zdir='z', s=20, c=None, 
			   depthshade=True, *args, **kwargs)
  • xs, ys控制x轴和y轴坐标
  • zs控制z轴坐标,默认为0,如果传入1个标量,那么就是所有点在同一高度,传入数组就是与xs, ys一一对应的高度。
  • s控制点的大小
  • c控制点的颜色

2. 3D散点图实例

还是用上面的数据,我们取iris.data中前三个特征画散点图

from sklearn import datasets
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

#从sklearn中获取经典的iris数据
iris = datasets.load_iris() #iris.data为150x4矩阵
x1 = iris.data[:, 0]        #获取第1列特征值
x2 = iris.data[:, 1]        #获取第2列特征值
x3 = iris.data[:, 2]        #获取第3列特征值
y = iris.target             #y是分类值:0,1,2

fig = plt.figure()
ax = Axes3D(fig)
ax.scatter(x1, x2, x3, c=y)
plt.show()

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_43756456/article/details/86074190
今日推荐