1. 函数说明
PyTorch关于torch.meshgrid()
函数的说明:torch.meshgrid
y, x = torch.meshgrid(a, b)
的功能是生成网格,可以用于生成坐标。
一般函数输入两个一维tensor a, b
,返回两个tensor -> y, x
,其中:
y, x
的行数均为a
的元素个数y, x
的列数均为b
的元素个数y
: 记录y
轴坐标x
: 记录x
轴坐标
我们一般都会通过torch.stack((x, y), dim=2)
方法将x, y
拼接在一起
2. 例子
import torch
y, x = torch.meshgrid([torch.arange(4), torch.arange(6)])
print(f"x.shape: {
x.shape}")
print(f"y.shape: {
y.shape}")
print(f"x:\n {
x}")
print(f"y:\n {
y}")
print("\n --------------- \n")
grid = torch.stack((x, y), dim=2)
print(f"grid.shape: {
grid.shape}")
print(f"grid:\n {
grid}")
"""
x.shape: torch.Size([4, 6])
y.shape: torch.Size([4, 6])
x:
tensor([[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5]])
y:
tensor([[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3]])
---------------
grid.shape: torch.Size([4, 6, 2])
grid:
tensor([[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0], [5, 0]],
[[0, 1], [1, 1], [2, 1], [3, 1], [4, 1], [5, 1]],
[[0, 2], [1, 2], [2, 2], [3, 2], [4, 2], [5, 2]],
[[0, 3], [1, 3], [2, 3], [3, 3], [4, 3], [5, 3]]])
"""
生成的tensor对应着每个grid中cell的左上角坐标,很容易理解。