Gated-Transformer-on-MTS: 基于PyTorch的改良Transformer模型用于多维时间序列分类
项目概述
该项目提出了一种改良的Transformer架构(Gated Transformer)用于处理多维时间序列(MTS)分类任务。通过双塔结构和门控机制,模型能够同时捕捉时间步(step-wise)和通道(channel-wise)之间的关系,并在多个基准数据集上取得了优于传统CNN基线(FCN和ResNet)的性能。
核心创新点
-
双塔结构:同时计算step-wise和channel-wise的注意力机制
- Step-wise塔:保留传统Transformer的位置编码和mask机制
- Channel-wise塔:去除位置编码和mask,专注于通道间关系
-
门控机制:自适应融合双塔输出
h = W · Concat(C, S) + b \\ g1, g2 = Softmax(h) \\ y = Concat(C · g1, S · g2)
- 仅使用Encoder:针对分类任务简化模型结构
实验结果
在13个多元时间序列数据集上的分类准确率对比:
数据集 | FCN | ResNet | Gated Transformer |
---|---|---|---|
ArabicDigits | 99.4 | 99.6 | 98.8 |
AUSLAN | 97.5 | 97.4 | 97.5 |
CharacterTrajectories | 99.0 | 99.0 | 97.0 |
CMUsubject16 | 100 | 99.7 | 100 |
ECG | 87.2 | 86.7 | 91.0 |
JapaneseVowels | 99.3 | 99.2 | 98.7 |
Libras | 96.4 | 95.4 | 88.9 |
UWave | 93.4 | 92.6 | 91.0 |
KickvsPunch | 54.0 | 51.0 | 90.0 |
NetFlow | 89.1 | 62.7 | 100 |
PEMS | - | - | 93.6 |
Wafer | 98.2 | 98.9 | 99.1 |
WalkvsRun | 100 | 100 | 100 |
技术实现细节
数据预处理
- 处理不等长时间序列:使用零填充至最大时间步长
- 创建PyTorch Dataset和DataLoader对象
- 特殊处理NetFlow数据集的标签(1和13→0和1)
模型架构
- 输入处理:线性层将原始输入映射到d_model维
- 双塔注意力:
- Step-wise塔:带位置编码和mask的多头注意力
- Channel-wise塔:无位置编码的多头注意力
- 门控融合:学习自适应权重融合双塔特征
- 分类头:全连接层输出分类结果
超参数配置
{
"d_model": 512, # 模型维度
"d_hidden": 2048, # FFN隐藏层维度
"q": 64, # Query维度
"v": 64, # Value维度
"h": 8, # 注意力头数
"N": 6, # Encoder层数
"dropout": 0.1, # Dropout率
"EPOCH": 100, # 训练轮数
"BATCH_SIZE": 32, # 批大小
"LR": 1e-4, # 学习率
"optimizer": "Adam" # 优化器
}
项目结构
Gated-Transformer-on-MTS/
├── dataset_process.py # 数据集处理
├── module/ # 模型模块
├── utils/ # 工具类
│ ├── random_seed.py # 随机种子设置
│ ├── heatMap.py # 热力图绘制
│ ├── visualization.py # 训练曲线可视化
│ └── TSNE.py # 降维聚类可视化
├── run.py # 训练脚本
├── run_with_saved_model.py # 测试脚本
├── saved_model/ # 模型保存目录
└── result_figure/ # 结果图目录
使用说明
-
环境配置:
- Python 3.7
- PyTorch ≥1.6
- 支持CPU/GPU
-
数据集准备:
- 从百度云下载.mat格式数据集
- 路径: https://pan.baidu.com/s/1u2HN6tfygcQvzuEK5XBa2A (提取码: dxq6)
-
训练模型:
python run.py --dataset ECG --d_model 512 --h 8 --N 6
-
测试模型:
python run_with_saved_model.py --dataset ECG --model_path saved_model/ECG_best_model.pkl
可视化工具
项目提供了多种可视化工具:
扫描二维码关注公众号,回复:
17617920 查看本文章

- 注意力权重热力图(对比DTW和欧氏距离)
- 训练过程中的loss/accuracy曲线
- t-SNE降维聚类图
- 自定义折线图绘制
注意事项
- 模型保存使用PyTorch 1.6+格式,加载时需兼容版本
- 数据集路径和结果保存路径不建议修改
- 可视化工具中的颜色映射可能需要根据具体数据集调整
该项目为多维时间序列分类任务提供了一种有效的Transformer改良方案,通过创新的双塔结构和门控机制,在多个数据集上展现了优越性能。