(11-7)基于深度学习的实时地图导航:使用模型预测实时路径

10.7  使用模型预测实时路径

文件example.ipynb展示了加载预训练模型以及使用模型进行预测的过程,并将结果保存为GIF动画。通过设置不同的路径来定制数据集和标签目录,并展示了模型的预测结果。

SPLIT = 'val_qualitative_000'
SUBSAMPLE = 5
 
model, data, viz = setup_experiment(cfg)
 
dataset = data.get_split(SPLIT, loader=False)
dataset = torch.utils.data.ConcatDataset(dataset)
dataset = torch.utils.data.Subset(dataset, range(0, len(dataset), SUBSAMPLE))
 
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)
print(len(dataset))
 
from pathlib import Path
 
 
MODEL_URL = 'https://www.cs.utexas.edu/~bzhou/cvt/cvt_nuscenes_vehicles_50k.ckpt'
CHECKPOINT_PATH = '../logs/cvt_nuscenes_vehicles_50k.ckpt'
 
!wget $MODEL_URL -O $CHECKPOINT_PATH
 
if Path(CHECKPOINT_PATH).exists():
    network = load_backbone(CHECKPOINT_PATH)
else:
    network = model.backbone
    print(f'{CHECKPOINT_PATH} not found. Using randomly initialized weights.')
 
import torch
import time
import imageio
import ipywidgets as widgets
 
 
GIF_PATH = './predictions.gif'
 
 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
network.to(device)
network.eval()
 
images = list()
 
with torch.no_grad():
    for batch in loader:
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        pred = network(batch)
 
        visualization = np.vstack(viz(batch=batch, pred=pred))
 
        images.append(visualization)
 
 
# Save a gif
duration = [1 for _ in images[:-1]] + [5 for _ in images[-1:]]
imageio.mimsave(GIF_PATH, images, duration=duration)
 
html = f'''
<div align="center">
<img src="{GIF_PATH}?modified={time.time()}" width="80%">
</div>
'''
 
display(widgets.HTML(html))

在上述代码中,使用预训练的模型对地图数据进行预测,并将预测结果保存为GIF动画。这个动画显示了模型预测的车辆和道路信息,构建了预测的实时地图。效果如图10-1所示。

图10-1  根据地图数据预测生成的实时地图

猜你喜欢

转载自blog.csdn.net/asd343442/article/details/143407727