matplot画图之plt.scatter()函数

 函数原型:

def scatter(
        x, y, s=None, c=None, marker=None, cmap=None, norm=None,
        vmin=None, vmax=None, alpha=None, linewidths=None, *,
        edgecolors=None, plotnonfinite=False, data=None, **kwargs)

:
    __ret = gca().scatter(
        x, y, s=s, c=c, marker=marker, cmap=cmap, norm=norm,
        vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths,
        edgecolors=edgecolors, plotnonfinite=plotnonfinite,
        **({"data": data} if data is not None else {}), **kwargs)
    sci(__ret)
    return __ret

x,y:表示的是大小为(x,y)的数组,绘制散点图的数据点

s:是一个实数或者是一个数组大小为(n,),这个是一个可选的参数。

c:表示的是颜色,默认是蓝色'b',表示的是标记的颜色

marker:表示的是绘制标记的样式,默认的是'o'圆圈,改成'x'则变成字符X。

cmap:Colormap实体或者colormap的名字,cmap当c是一个浮点数数组的时候才使用。如果没有申明就是image.cmap

norm:Normalize实体来将数据亮度转化到0-1之间,只有c是一个浮点数的数组的时候才使用。如果没有申明,就是默认为colors.Normalize。

vmin,vmax:实数,当norm存在的时候忽略。用来进行亮度数据的归一化。

alpha:实数,0-1之间。

linewidths:也就是标记点的长度。

下图是根据csv文件导入的数据所绘制的图像。

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# For data preprocess
import numpy as np
import csv
import os

# For plotting
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
#下面三个包是新增的
from sklearn.model_selection import train_test_split
import pandas as pd
import pprint as pp
pd.set_option('display.max_rows', 200) # 200行
pd.set_option('display.max_columns', 200) # 200列
myseed = 42069  # set a random seed for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)
tr_path = 'covid.train.csv'  # path to training data
tt_path = 'covid.test.csv'   # path to testing data
data_tr = pd.read_csv(tr_path) #读取训练数据
data_tt = pd.read_csv(tt_path) #读取测试数据
#print(data_tt.head(3))
#print(data_tr.head(3))
#print(data_tr.columns) #查看有多少列特征
data_tr.drop(['id'],axis = 1, inplace = True) #由于id列用不到,删除id列
data_tt.drop(['id'],axis = 1, inplace = True)
cols = list(data_tr.columns)  #拿到特征列名称
#pp.pprint(data_tr.columns)
#pp.pprint(data_tr.info()) #看每列数据类型和大小
WI_index = cols.index('WI')  # WI列是states one-hot编码最后一列,取值为0或1,后面特征分析时需要把states特征删掉
WI_index #wi列索引
#one-hot编码。one-hot编码的定义是用N位状态寄存器来对N个状态进行编码。
# 比如[0,0.3],(0.3,0.6],(0.6,1],有3个分类值,因此N为3,对应的one-hot编码可以表示为100,010,001。
#使用步骤:比如用LR算法做模型,在数据处理过程中,可以先对连续变量进行离散化处理,
# 然后对离散化后数据进行one-hot编码,最后放入LR模型中。
# 这样可以增强模型的非线性能力。

#print(data_tr.iloc[:, 40:].describe()) #从上面可以看出wi 列后面是cli, 所以列索引从40开始, 并查看这些数据分布
#print(data_tt.iloc[:, 40:].describe()) #查看测试集数据分布,并和训练集数据分布对比,两者特征之间数据分布差异不是很大

plt.scatter(data_tr.loc[:, 'cli'], data_tr.loc[:, 'tested_positive.2']) #肉眼分析cli特征与目标之间相关性
plt.show()

猜你喜欢

转载自blog.csdn.net/qq_41722524/article/details/122630831
今日推荐