Deep Learning Series 50: Accelerating PyTorch with Apple's M1 chip

1. Introduction

Apple's Metal Performance Shaders (MPS) can act as a PyTorch backend to accelerate GPU training. The MPS backend extends the PyTorch framework with scripts and capabilities for setting up and running operations on the Mac. MPS optimizes compute performance with kernels fine-tuned for the unique characteristics of each Metal GPU family. The new device maps machine-learning computation graphs and primitives onto the MPS Graph framework and the tuned kernels that MPS provides.
At the time of writing, the MPS-accelerated PyTorch build is still a preview; the install command is:

conda install pytorch torchvision torchaudio -c pytorch-nightly

Usage is simple: call .to('mps:0') to move tensors and modules onto the MPS device for computation.
The tests below were run on an M1 Pro with its 16-core GPU.
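Before running anything on the device, it is worth confirming that the installed build actually supports MPS. A minimal sanity check, using the availability flags documented for the MPS backend:

import torch

# both checks must pass before 'mps' can be used as a device
print(torch.backends.mps.is_available())  # macOS 12.3+ with a Metal-capable GPU
print(torch.backends.mps.is_built())      # this PyTorch build was compiled with MPS support

dev = 'mps:0' if torch.backends.mps.is_available() else 'cpu'
x = torch.ones(3, device=dev)
print(x.device)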

2. PyTorch test 1

import torch

img = torch.randn(64, 10, 64, 64)        # batch of 64 inputs, 10 channels, 64x64

# time the convolution on the MPS device
dev = 'mps:0'
img_dev = img.to(dev)
conv = torch.nn.Conv2d(10, 10, 3).to(dev)
%timeit conv(img_dev)

# time the same convolution on the CPU
dev = 'cpu'
conv = torch.nn.Conv2d(10, 10, 3).to(dev)
%timeit conv(img)

Results:

439 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
10.6 ms ± 43.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

On this micro-benchmark, the M1 GPU is roughly 20x faster than the CPU.
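Note that MPS kernels are dispatched asynchronously, so timing the call alone with %timeit can understate the true kernel time. A more conservative measurement, sketched below assuming a PyTorch build that exposes torch.mps.synchronize(), drains the GPU queue before stopping the timer:

import time
import torch

dev = 'mps:0'
img_dev = torch.randn(64, 10, 64, 64, device=dev)
conv = torch.nn.Conv2d(10, 10, 3).to(dev)

# warm up so first-launch overhead is not measured
for _ in range(10):
    conv(img_dev)
torch.mps.synchronize()

start = time.perf_counter()
for _ in range(1000):
    conv(img_dev)
torch.mps.synchronize()                  # wait for all queued MPS work to finish
print((time.perf_counter() - start) / 1000, 's per iteration')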

3. PyTorch test 2

Reference:
https://github.com/rasbt/machine-learning-notes/tree/main/benchmark/pytorch-m1-gpu
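The scripts in that repository pick a device and run a standard training loop; the condensed sketch below (simplified, with a small stand-in MLP instead of the repo's LeNet/MLP/VGG16 models) shows the pattern being benchmarked:

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)

model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 128), torch.nn.ReLU(),
    torch.nn.Linear(128, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model.train()
for x, y in train_loader:
    x, y = x.to(device), y.to(device)    # move each batch onto the chosen device
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()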

lenet-mnist:

# cpu
Time / epoch without evaluation: 0.19 min
Epoch: 001/001 | Train: 97.32% | Validation: 97.77% | Best Validation (Ep. 001): 97.77%
Time elapsed: 0.27 min
Total Training Time: 0.27 min
Test accuracy 97.41%
Total Time: 0.29 min

# m1
Time / epoch without evaluation: 0.12 min
Epoch: 001/001 | Train: 97.32% | Validation: 97.77% | Best Validation (Ep. 001): 97.77%
Time elapsed: 0.18 min
Total Training Time: 0.18 min
Test accuracy 97.40%
Total Time: 0.20 min

mlp-mnist:

# cpu
Epoch: 001/001 | Train: 91.43% | Validation: 93.38% | Best Validation (Ep. 001): 93.38%
Time elapsed: 0.10 min
Total Training Time: 0.10 min
Test accuracy 91.99%
Total Time: 0.12 min

# m1
Time / epoch without evaluation: 0.06 min
Epoch: 001/001 | Train: 91.67% | Validation: 93.42% | Best Validation (Ep. 001): 93.42%
Time elapsed: 0.11 min
Total Training Time: 0.11 min
Test accuracy 92.20%
Total Time: 0.13 min

vgg16-cifar10:

# cpu
Epoch: 001/001 | Batch 0000/1406 | Loss: 2.5735
2022-08-29 18:48:19
Epoch: 001/001 | Batch 0100/1406 | Loss: 2.2132
2022-08-29 19:01:51
Epoch: 001/001 | Batch 0200/1406 | Loss: 2.0938
2022-08-29 19:15:27
Epoch: 001/001 | Batch 0300/1406 | Loss: 2.0561
2022-08-29 19:29:19

# m1
Epoch: 001/001 | Batch 0000/1406 | Loss: 2.6674
2022-08-29 17:58:05
Epoch: 001/001 | Batch 0100/1406 | Loss: 3.2263
2022-08-29 18:00:52
Epoch: 001/001 | Batch 0200/1406 | Loss: 2.2019
2022-08-29 18:03:29
Epoch: 001/001 | Batch 0300/1406 | Loss: 2.2948
2022-08-29 18:06:07

Over the first 300 batches, the CPU took about 41 minutes while MPS took about 8 minutes, roughly a 5x speedup.


Source: blog.csdn.net/kittyzc/article/details/126482903