PyTorch | 深度学习库能做什么?

PyTorch | 深度学习库能做什么?

1. GPU 加速

import torch
import time
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.backends.cudnn.enabled)

a = torch.randn(10000,1000)
b = torch.randn(1000,2000)
t0 = time.time()
c = torch.matmul(a,b)
t1 = time.time()
print(a.device,t1-t0,c.norm(2))

device = torch.device('cuda')
a = a.to(device)
b = b.to(device)

t0 = time.time()
c = torch.matmul(a,b)
t1 = time.time()
print(a.device,t1-t0,c.norm(2))

在这里插入图片描述

2. 自动求导

import torch
from torch import autograd

x = torch.tensor(1.)
a = torch.tensor(1.,requires_grad=True)
b = torch.tensor(2.,requires_grad=True)
c = torch.tensor(3.,requires_grad=True)

y = a**2 * x + b * x + c

print('before:',a.grad,b.grad,c.grad)
grads = autograd.grad(y,[a,b,c])
print('after:',grads[0],grads[1],grads[2])

在这里插入图片描述

3. 常用 API

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/m0_52650517/article/details/120019897