简易了解Pytorch中的@ 和 * 运算符(附Demo)

1. 基本知识

在 PyTorch 中,@ 和 * 运算符用于不同类型的数学运算,具体是矩阵乘法和逐元素乘法

基本知识

运算符 功能 适用场景 示例
@ 矩阵乘法(或点乘) 用于执行线性代数中的矩阵乘法 C = A @ B,其中 A 和 B 为矩阵
* 逐元素乘法 用于对同一形状的张量进行逐元素相乘 C = A * B,其中 A 和 B 为同形状张量

两者的差异总结如下:

特点 矩阵乘法 (@) 逐元素乘法 (*)
运算类型 矩阵乘法(线性代数) 逐元素运算
适用条件 列数等于行数 形状相同
返回结果形状 (m, p) 与输入张量相同
使用示例 C = A @ B C = A * B
适用场景 线性变换、深度学习中的权重计算 图像处理、逐元素操作等
  • 使用 @ 运算符进行矩阵乘法适合线性代数操作,常用于深度学习中的层与权重的运算
  • 使用 * 运算符进行逐元素乘法适合需要对张量进行元素级操作的场景,如数据处理和图像增强等

2. @

@ 运算符用于执行矩阵乘法或向量点乘
对于两个矩阵 A 和 B,其结果 C 是一个新矩阵,其中 C[i][j] 是 A 的第 i 行与 B 的第 j 列的点积

适用条件: A 的列数必须等于 B 的行数,即 A 的形状为 (m, n),B 的形状为 (n, p),则结果 C 的形状为 (m, p)

import torch

# 创建两个矩阵
A = torch.tensor([[1, 2], [3, 4]])  # 2x2 矩阵
B = torch.tensor([[5, 6], [7, 8]])  # 2x2 矩阵

# 使用 @ 运算符进行矩阵乘法
C = A @ B
print("矩阵乘法结果:\n", C)

截图如下:

在这里插入图片描述

3. *

*运算符用于对两个相同形状的张量进行逐元素相乘
结果张量的每个元素是操作数张量中对应元素的乘积

适用条件: A 和 B 必须具有相同的形状(或能够通过广播规则兼容)

import torch

# 创建两个相同形状的张量
A = torch.tensor([[1, 2], [3, 4]])  # 2x2 矩阵
B = torch.tensor([[5, 6], [7, 8]])  # 2x2 矩阵

# 使用 * 运算符进行逐元素乘法
C = A * B
print("逐元素乘法结果:\n", C)

截图如下:

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_47872288/article/details/143212224