基于PyTorch的改良Transformer模型用于多维时间序列分类

Gated-Transformer-on-MTS: 基于PyTorch的改良Transformer模型用于多维时间序列分类

项目概述

在这里插入图片描述

该项目提出了一种改良的Transformer架构(Gated Transformer)用于处理多维时间序列(MTS)分类任务。通过双塔结构和门控机制,模型能够同时捕捉时间步(step-wise)和通道(channel-wise)之间的关系,并在多个基准数据集上取得了优于传统CNN基线(FCN和ResNet)的性能。

核心创新点

  1. 双塔结构:同时计算step-wise和channel-wise的注意力机制

    • Step-wise塔:保留传统Transformer的位置编码和mask机制
    • Channel-wise塔:去除位置编码和mask,专注于通道间关系
  2. 门控机制:自适应融合双塔输出

    h = W · Concat(C, S) + b \\
    g1, g2 = Softmax(h) \\
    y = Concat(C · g1, S · g2)
    

在这里插入图片描述

  1. 仅使用Encoder:针对分类任务简化模型结构

实验结果

在13个多元时间序列数据集上的分类准确率对比:

数据集 FCN ResNet Gated Transformer
ArabicDigits 99.4 99.6 98.8
AUSLAN 97.5 97.4 97.5
CharacterTrajectories 99.0 99.0 97.0
CMUsubject16 100 99.7 100
ECG 87.2 86.7 91.0
JapaneseVowels 99.3 99.2 98.7
Libras 96.4 95.4 88.9
UWave 93.4 92.6 91.0
KickvsPunch 54.0 51.0 90.0
NetFlow 89.1 62.7 100
PEMS - - 93.6
Wafer 98.2 98.9 99.1
WalkvsRun 100 100 100

技术实现细节

数据预处理

  • 处理不等长时间序列:使用零填充至最大时间步长
  • 创建PyTorch Dataset和DataLoader对象
  • 特殊处理NetFlow数据集的标签(1和13→0和1)
    在这里插入图片描述

模型架构

  • 输入处理:线性层将原始输入映射到d_model维
  • 双塔注意力
    • Step-wise塔:带位置编码和mask的多头注意力
    • Channel-wise塔:无位置编码的多头注意力
  • 门控融合:学习自适应权重融合双塔特征
  • 分类头:全连接层输出分类结果
    在这里插入图片描述

超参数配置

{
    
    
    "d_model": 512,          # 模型维度
    "d_hidden": 2048,        # FFN隐藏层维度
    "q": 64,                 # Query维度
    "v": 64,                 # Value维度
    "h": 8,                  # 注意力头数
    "N": 6,                  # Encoder层数
    "dropout": 0.1,          # Dropout率
    "EPOCH": 100,            # 训练轮数
    "BATCH_SIZE": 32,        # 批大小
    "LR": 1e-4,              # 学习率
    "optimizer": "Adam"      # 优化器
}

项目结构

Gated-Transformer-on-MTS/
├── dataset_process.py       # 数据集处理
├── module/                  # 模型模块
├── utils/                   # 工具类
│   ├── random_seed.py       # 随机种子设置
│   ├── heatMap.py           # 热力图绘制
│   ├── visualization.py     # 训练曲线可视化
│   └── TSNE.py              # 降维聚类可视化
├── run.py                   # 训练脚本
├── run_with_saved_model.py  # 测试脚本
├── saved_model/             # 模型保存目录
└── result_figure/           # 结果图目录

使用说明

  1. 环境配置

    • Python 3.7
    • PyTorch ≥1.6
    • 支持CPU/GPU
  2. 数据集准备

    • 从百度云下载.mat格式数据集
    • 路径: https://pan.baidu.com/s/1u2HN6tfygcQvzuEK5XBa2A (提取码: dxq6)
  3. 训练模型

    python run.py --dataset ECG --d_model 512 --h 8 --N 6
    
  4. 测试模型

    python run_with_saved_model.py --dataset ECG --model_path saved_model/ECG_best_model.pkl
    

可视化工具

项目提供了多种可视化工具:

扫描二维码关注公众号,回复: 17617920 查看本文章
  • 注意力权重热力图(对比DTW和欧氏距离)
  • 训练过程中的loss/accuracy曲线
  • t-SNE降维聚类图
  • 自定义折线图绘制

注意事项

  1. 模型保存使用PyTorch 1.6+格式,加载时需兼容版本
  2. 数据集路径和结果保存路径不建议修改
  3. 可视化工具中的颜色映射可能需要根据具体数据集调整

该项目为多维时间序列分类任务提供了一种有效的Transformer改良方案,通过创新的双塔结构和门控机制,在多个数据集上展现了优越性能。

猜你喜欢

转载自blog.csdn.net/QQ_1309399183/article/details/146563811
今日推荐