10.7 使用模型预测实时路径
文件example.ipynb展示了加载预训练模型以及使用模型进行预测的过程,并将结果保存为GIF动画。通过设置不同的路径来定制数据集和标签目录,并展示了模型的预测结果。
SPLIT = 'val_qualitative_000'
SUBSAMPLE = 5
model, data, viz = setup_experiment(cfg)
dataset = data.get_split(SPLIT, loader=False)
dataset = torch.utils.data.ConcatDataset(dataset)
dataset = torch.utils.data.Subset(dataset, range(0, len(dataset), SUBSAMPLE))
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)
print(len(dataset))
from pathlib import Path
MODEL_URL = 'https://www.cs.utexas.edu/~bzhou/cvt/cvt_nuscenes_vehicles_50k.ckpt'
CHECKPOINT_PATH = '../logs/cvt_nuscenes_vehicles_50k.ckpt'
!wget $MODEL_URL -O $CHECKPOINT_PATH
if Path(CHECKPOINT_PATH).exists():
network = load_backbone(CHECKPOINT_PATH)
else:
network = model.backbone
print(f'{CHECKPOINT_PATH} not found. Using randomly initialized weights.')
import torch
import time
import imageio
import ipywidgets as widgets
GIF_PATH = './predictions.gif'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
network.to(device)
network.eval()
images = list()
with torch.no_grad():
for batch in loader:
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
pred = network(batch)
visualization = np.vstack(viz(batch=batch, pred=pred))
images.append(visualization)
# Save a gif
duration = [1 for _ in images[:-1]] + [5 for _ in images[-1:]]
imageio.mimsave(GIF_PATH, images, duration=duration)
html = f'''
<div align="center">
<img src="{GIF_PATH}?modified={time.time()}" width="80%">
</div>
'''
display(widgets.HTML(html))
在上述代码中,使用预训练的模型对地图数据进行预测,并将预测结果保存为GIF动画。这个动画显示了模型预测的车辆和道路信息,构建了预测的实时地图。效果如图10-1所示。
图10-1 根据地图数据预测生成的实时地图