车流量监测

车流量监测

项目展示

在这里插入图片描述

import torch
import cv2
import os
from ultralytics import YOLO

# 设置环境变量
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# 定义文件路径
project_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(project_dir, 'models', 'model.pth')
video_path = os.path.join(project_dir, 'traffic.mp4')
output_video_path = os.path.join(project_dir, 'out.avi')

# 检查文件是否存在
if not os.path.exists(model_path):
    print(f"文件 {
      
      model_path} 不存在,请检查路径是否正确。")
else:
    try:
        # 使用 torch.load 加载模型
        ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
    except Exception as e:
        print(f"加载模型时发生错误:{
      
      e}")

# 加载 YOLO 模型
model = YOLO("yolov8s.pt")
names = model.model.names

# 打开视频文件
cap = cv2.VideoCapture(video_path)
assert cap.isOpened(), "Error reading video file"

# 获取视频属性
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))

# 创建视频写入对象
video_writer = cv2.VideoWriter(output_video_path,
                               cv2.VideoWriter_fourcc(*'mp4v'),
                               fps,
                               (w, h))

# 定义临界值
boundary_line_y = 750  # 假设临界值在 y=400 的位置

# 初始化计数器
counters = {
    
    }
crossed_ids = set()

# 处理视频帧
while cap.isOpened():
    success, im0 = cap.read()
    if not success:
        break

    # 进行目标跟踪
    results = model.track(im0, persist=True, show=False)

    # 绘制临界值线
    cv2.line(im0, (0, boundary_line_y), (w, boundary_line_y), (0, 255, 0), 2)

    # 绘制跟踪框
    for r in results:
        boxes = r.boxes
        for box in boxes:
            b = box.xyxy[0]  # get box coordinates in (top, left, bottom, right) format
            c = box.cls
            t = box.id
            cv2.rectangle(im0, (int(b[0]), int(b[1])), (int(b[2]), int(b[3])), (255, 0, 0), 2)

            # 计数逻辑
            class_name = names[int(c)]
            center_y = (int(b[1]) + int(b[3])) / 2
            if t not in crossed_ids and center_y > boundary_line_y:
                if class_name not in counters:
                    counters[class_name] = 0
                counters[class_name] += 1
                crossed_ids.add(t)
            
            # 在框上显示跟踪ID
            cv2.putText(im0, f"ID: {
      
      int(t)}", (int(b[0]), int(b[1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)

    # 显示计数结果
    for class_name, count in counters.items():
        cv2.putText(im0, f"{
      
      class_name}: {
      
      count}", (10, h - 30 - 30 * list(counters.keys()).index(class_name)), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)

    # 显示图像
    cv2.imshow('Frame', im0)
    if cv2.waitKey(1) & 0xFF == ord('q'):  # 按 q 键退出
        break

    # 写入视频帧
    video_writer.write(im0)

# 释放资源
cap.release()
video_writer.release()
cv2.destroyAllWindows()

猜你喜欢

转载自blog.csdn.net/joker_man1/article/details/143351681