一欧元滤波器:用于缓解YOLOv8-pose关键点抖动

参考文章

一欧元滤波器(OneEuroFilter)

推荐开源项目:1€ Filter(一欧元滤波器)

1.基本原理

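一欧元滤波器本质上是一个截止频率自适应的一阶低通滤波器:对当前观测值与上一时刻的滤波结果做指数加权平均,截止频率随信号变化速度自动调整——运动慢时截止频率低、去抖强;运动快时截止频率升高、延迟小。设相邻两次采样的时间间隔为 $T_e$:

$$\alpha(f_c) = \frac{r}{r+1},\qquad r = 2\pi f_c T_e$$

$$\dot{x}_i = \frac{x_i - \hat{x}_{i-1}}{T_e},\qquad \hat{\dot{x}}_i = \alpha(f_d)\,\dot{x}_i + \big(1-\alpha(f_d)\big)\,\hat{\dot{x}}_{i-1}$$

$$f_c = f_{c,\min} + \beta\,\lvert\hat{\dot{x}}_i\rvert,\qquad \hat{x}_i = \alpha(f_c)\,x_i + \big(1-\alpha(f_c)\big)\,\hat{x}_{i-1}$$

其中 $f_{c,\min}$、$\beta$、$f_d$ 分别对应代码中的 `min_cutoff`、`beta`、`d_cutoff`。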

2.YOLOv8-pose进行姿态估计

关键点处理脚本(本文主要关注手臂关键点),用于提取指定的关键点并将其绘制到图片上:

import cv2
import torch
import numpy as np


# COCO-17 keypoint layout (as output by YOLOv8-pose): index -> keypoint name.
COCO_keypoint_indices = dict(enumerate((
    'nose',
    'left_eye', 'right_eye',
    'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder',
    'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist',
    'left_hip', 'right_hip',
    'left_knee', 'right_knee',
    'left_ankle', 'right_ankle',
)))

# COCO indices of the upper-body (arm) keypoints: shoulders, elbows, wrists.
COCO_DEFAULT_UPPER_BODY_KEYPOINT_INDICES = (5, 6, 7, 8, 9, 10)

# Skeleton edges as (start, end) pairs. The indices are positions inside the
# upper-body subset (see default_up_body_indices), NOT COCO indices:
# left wrist-elbow-shoulder, shoulder-shoulder, right shoulder-elbow-wrist.
COCO_DEFAULT_CONNECTIONS = ((4, 2), (2, 0), (0, 1), (1, 3), (3, 5))

# Upper-body subset: position in the subset -> keypoint name.
default_up_body_indices = {
    subset_pos: COCO_keypoint_indices[coco_idx]
    for subset_pos, coco_idx in enumerate(COCO_DEFAULT_UPPER_BODY_KEYPOINT_INDICES)
}


def get_upper_body_keypoint(data, keypoint_indices=COCO_DEFAULT_UPPER_BODY_KEYPOINT_INDICES):
    """
    Select a subset of keypoints (by default the COCO upper-body/arm keypoints) for every person.

    Args:
        data: Keypoints indexed as ``data[person, keypoint, value]``, e.g. a numpy
            array or torch tensor of shape (num_persons, num_keypoints, 2 or 3).
            May be None.
        keypoint_indices: Keypoint indices to keep, in output order.

    Returns:
        ``data[:, keypoint_indices, :]`` (same array/tensor type as ``data``).
        When ``data`` is None or has no persons/keypoints, an empty numpy array of
        shape (0, len(keypoint_indices), C), where C is ``data``'s last dimension
        when it is known and 3 otherwise.

    Raises:
        IndexError: If any index lies outside [-num_keypoints, num_keypoints).
    """
    indices = list(keypoint_indices)

    if data is None or data.shape[0] == 0 or data.shape[1] == 0:
        # Keep the value dimension (2 for xy, 3 for xy+conf) so empty frames have
        # the same trailing shape as non-empty ones.
        num_values = data.shape[2] if data is not None and len(data.shape) > 2 else 3
        return np.empty((0, len(indices), num_values))

    num_keypoints = data.shape[1]
    if any(not -num_keypoints <= i < num_keypoints for i in indices):
        raise IndexError("Keypoint indices are out of bounds for the given data")

    return data[:, indices, :]


def image_read(image):
    """
    Return an image array from either a file path or an existing array.

    Args:
        image: A file path (str), loaded with ``cv2.imread``, or a numpy array,
            returned unchanged (not copied).

    Raises:
        ValueError: If the path cannot be read as an image.
        TypeError: If ``image`` is neither a str nor a numpy array.
    """
    if isinstance(image, np.ndarray):
        return image
    if not isinstance(image, str):
        raise TypeError("image 参数应为字符串路径或 numpy 数组")

    loaded = cv2.imread(image)
    if loaded is None:
        raise ValueError(f"无法读取图像路径: {image}")
    return loaded


def image_show(image, desc="KeyPoint"):
    """Show ``image`` in a window titled ``desc``, wait for a key press, then close all OpenCV windows."""
    window_name = desc
    cv2.imshow(window_name, image)
    cv2.waitKey(0)  # delay 0 -> block until a key is pressed
    cv2.destroyAllWindows()


def plot_keypoint(image, data, connections=COCO_DEFAULT_CONNECTIONS, point_color=(0, 0, 255), point_radius=4,
                  line_color=(0, 255, 0), line_thickness=2):
    """
    Draw keypoints and the skeleton lines connecting them onto an image.

    Args:
        image: Image path (str) or image array. An array is drawn on in place.
        data: Keypoints per person, indexed as ``data[person][keypoint]``; only the
            first two values (x, y) of each keypoint are used. A torch tensor is
            moved to the CPU and converted to numpy.
        connections: (start, end) keypoint-index pairs to join with lines; falsy
            to skip lines.
        point_color: Keypoint colour passed to ``cv2.circle``.
        point_radius: Keypoint circle radius in pixels.
        line_color: Line colour passed to ``cv2.line``.
        line_thickness: Line thickness in pixels.

    Returns:
        The image with keypoints drawn on it.
    """
    img = image_read(image)
    data = data.cpu().numpy() if torch.is_tensor(data) else np.array(data)

    def _is_valid(point):
        # Undetected keypoints are reported as (0, 0); the same rule is used for
        # points and for both ends of every line.
        return point[0] > 0 or point[1] > 0

    def _to_pixel(point):
        return int(point[0]), int(point[1])

    for person in data:
        # Lines first, so the keypoint circles are painted on top of them.
        for start_idx, end_idx in connections or ():
            start, end = person[start_idx], person[end_idx]
            if _is_valid(start) and _is_valid(end):
                cv2.line(img, _to_pixel(start), _to_pixel(end), line_color, line_thickness)

        for point in person:
            if _is_valid(point):
                cv2.circle(img, _to_pixel(point), point_radius, point_color, -1)

    return img

获取一张图片的关键点

# test_pose.py
# Demo: run YOLOv8-pose on a single image, keep the upper-body (arm) keypoints
# and display them drawn on the image.

from ultralytics.task_bank.pose.predict import PosePredictor
from ultralytics.task_bank.pose.utils import get_upper_body_keypoint, plot_keypoint, image_show

video_path = r"path/to/*.mp4"   # placeholder; not used in this single-image demo
img_path = r"path/to/*.jpg"     # placeholder; replace with a real image path

# Predictor settings passed to PosePredictor as overrides.
overrides = {"task": "pose",
             "mode": "predict",
             "model": r'./weights/yolov8m-pose.pt',
             "verbose": False,
             "classes": [0]     # keep only class id 0 (person in COCO)
             }

pose_predictor = PosePredictor(overrides=overrides)
res = pose_predictor(source=img_path)
# get_upper_body_keypoint indexes data[:, idx, :], so keypoints.data is expected to be
# 3-D (num_persons, 17, 3); the shapes in the note below are per person.
# NOTE(review): ultralytics exposes normalized coordinates as `keypoints.xyn`; the note
# below calls it `nxy` - confirm against the installed version.
data = get_upper_body_keypoint(res[0].keypoints.data)
"""
keypoints.data
[17, 3]: 17个关键点的x, y, confidence
keypoints.xy
[17, 2]: 仅含xy坐标
keypoints.nxy
[17, 2]: 根据图片大小归一化的xy坐标
"""

image_show(plot_keypoint(img_path, data))

3.使用一欧元滤波器

一欧元滤波

超参数(尤其是 min_cutoff)要取小一点,去抖效果才明显;代价是关键点跟随运动的延迟会变大。

相对原始实现的修改:将输入改为多维数组/张量(每行对应一个关键点的坐标),并处理关键点丢失的问题——YOLOv8 检测不到某个关键点时,会将其坐标置为 (0, 0)。

处理规则:如果当前帧某个关键点是无效坐标,滤波输出直接置为 (0, 0);如果上一帧该关键点是无效坐标,则直接使用当前帧的检测结果(相当于对该关键点重新初始化滤波)。

这种处理方式比较粗暴,但个人觉得比较有效:难以检测到的关键点,大概率是置信度低或姿态遮挡导致的,这样得到的坐标本身就不可靠。

import numpy as np


def smoothing_factor(t_e, cutoff):
    """
    Return the smoothing weight ``alpha`` of a first-order low-pass filter.

    ``alpha = r / (r + 1)`` with ``r = 2 * pi * cutoff * t_e``.

    Args:
        t_e: Time elapsed since the previous sample.
        cutoff: Cutoff frequency (scalar or array; computed elementwise).
    """
    r = 2 * np.pi * cutoff * t_e
    return r / (r + 1)


def exponential_smoothing(a, x, x_prev):
    """Blend ``x`` with ``x_prev`` as ``a * x + (1 - a) * x_prev`` (elementwise)."""
    return a * x + (1 - a) * x_prev


class OneEuroFilter:
    """
    One Euro filter: exponential smoothing whose cutoff frequency grows with the
    smoothed speed of the signal. Slow motion is smoothed strongly and fast
    motion lags little.

    The signal is an array whose last axis holds one keypoint's coordinates,
    e.g. shape (num_keypoints, 2). A keypoint whose coordinates are all 0 is
    treated as missing (YOLOv8 reports undetected keypoints as (0, 0)):

    * missing in the current sample  -> its output stays all 0;
    * missing in the previous sample -> its output is the raw current value,
      i.e. the filter restarts for that keypoint.

    In both cases that keypoint's derivative estimate is reset to 0. A jump to or
    from the (0, 0) placeholder is not real motion, and keeping it would distort
    the speed-dependent cutoff in later frames.
    """

    def __init__(self, t0, x0, dx0=0.0, min_cutoff=0.1, beta=0.0, d_cutoff=0.1):
        """
        Initialize the filter state.

        Args:
            t0: Timestamp of the initial sample (e.g. a frame index).
            x0: Initial sample; stored as float32 and used as the first estimate.
            dx0: Initial derivative estimate (None -> zeros shaped like ``x0``).
            min_cutoff: Minimum cutoff frequency; smaller -> stronger smoothing.
            beta: Speed coefficient; larger -> cutoff rises faster with speed.
            d_cutoff: Cutoff frequency used when smoothing the derivative.
        """
        # The parameters.
        self.min_cutoff = float(min_cutoff)
        self.beta = float(beta)
        self.d_cutoff = float(d_cutoff)
        # Previous values.
        self.x_prev = np.array(x0, dtype=np.float32)
        self.dx_prev = np.zeros_like(self.x_prev) if dx0 is None else np.array(dx0, dtype=np.float32)
        self.t_prev = float(t0)

    def __call__(self, t, x):
        """
        Filter sample ``x`` taken at time ``t`` and return the smoothed estimate.

        If ``t`` is not later than the previous timestamp, the state is left
        unchanged and a copy of the previous estimate is returned. The update
        would otherwise divide by a zero or negative time step and store inf/NaN.
        """
        x = np.asarray(x)
        t_e = t - self.t_prev
        if t_e <= 0:
            return self.x_prev.copy()

        # The filtered derivative of the signal.
        a_d = smoothing_factor(t_e, self.d_cutoff)
        dx = (x - self.x_prev) / t_e
        dx_hat = exponential_smoothing(a_d, dx, self.dx_prev)

        # The filtered signal: the cutoff rises with the estimated speed.
        cutoff = self.min_cutoff + self.beta * np.abs(dx_hat)
        a = smoothing_factor(t_e, cutoff)
        x_hat = exponential_smoothing(a, x, self.x_prev)

        # Missing-keypoint handling (see class docstring).
        cur_missing = np.all(x == 0, axis=-1)
        prev_missing = np.all(self.x_prev == 0, axis=-1)
        x_hat[cur_missing] = 0
        x_hat[prev_missing] = x[prev_missing]
        dx_hat[cur_missing | prev_missing] = 0

        # Memorize the previous values.
        self.x_prev = x_hat
        self.dx_prev = dx_hat
        self.t_prev = t

        return x_hat

使用目标跟踪(ByteTrack)区分每个人,并为每个跟踪 ID 单独维护一个一欧元滤波器进行滤波:

from ultralytics.task_bank.byte_tracker_modify import BYTETracker
from ultralytics.task_bank.pose.ops import filter_boxes_ioa
from ultralytics.task_bank.pose.utils import get_upper_body_keypoint
from ultralytics.task_bank.pose.one_euro_filter import OneEuroFilter
from easydict import EasyDict

# ByteTrack hyper-parameters. They are wrapped in an EasyDict below, presumably so
# BYTETracker can read them as attributes (args.track_buffer, ...) - TODO confirm.
bytetrack_config = dict(
    track_high_thresh=0.5,   # threshold for the first association
    track_low_thresh=0.1,    # threshold for the second association
    new_track_thresh=0.6,    # threshold for init new track if the detection does not match any tracks
    track_buffer=30,         # buffer to calculate the time when to remove tracks
    match_thresh=0.8,        # threshold for matching tracks
)

bytetrack_args = EasyDict(bytetrack_config)


class Person:
    """
    Track people across frames and smooth each track's upper-body keypoints.

    A BYTETracker assigns track ids. Every track id gets its own OneEuroFilter,
    kept in ``self.filter`` ({track_id: OneEuroFilter}). Filters of tracks not
    seen for more than ``max_filter_age`` frames are discarded, so the dict does
    not grow without bound over a long video.
    """

    def __init__(self, max_filter_age=None):
        """
        Args:
            max_filter_age: Drop a track's filter once it has not been updated for
                more than this many ``idx_frame`` units. Defaults to
                ``bytetrack_args.track_buffer``; pass ``float('inf')`` to never drop.
        """
        self.bytetrack = BYTETracker(bytetrack_args)
        self.filter = dict()    # {track_id: OneEuroFilter}
        self.max_filter_age = bytetrack_args.track_buffer if max_filter_age is None else max_filter_age

    def update(self, idx_frame, pose):
        """
        Track the people in one frame and return their smoothed upper-body keypoints.

        Args:
            idx_frame: Frame index, used as the filter timestamp.
            pose: One pose-prediction result exposing ``boxes`` and ``keypoints``.

        Returns:
            (track_res, keypoints): the tracker output, and the filtered upper-body
            keypoints reordered to match the rows of ``track_res``.
        """
        boxes = pose.boxes
        xyxy = boxes.xyxy.cpu().numpy()
        conf = boxes.conf.view(-1, 1).cpu().numpy()
        valid_index = filter_boxes_ioa(xyxy, conf)

        conf = conf[valid_index]
        xywh = boxes.xywh.cpu().numpy()[valid_index]
        cls = boxes.cls.view(-1, 1).cpu().numpy()[valid_index]

        track_res = self.bytetrack.update(xywh, conf.reshape(-1), cls.reshape(-1))
        # Last column of track_res: index of the matched detection in this frame.
        track_index = track_res[:, -1].astype(int) if track_res.shape[0] > 0 else []

        keypoint_data = get_upper_body_keypoint(pose.keypoints.xy.cpu().numpy()[valid_index])
        for det in track_res:
            track_id, order = int(det[4]), int(det[-1])
            track_filter = self.filter.get(track_id)
            if track_filter is None:
                # First sighting: initialise the filter with the raw keypoints.
                self.filter[track_id] = OneEuroFilter(idx_frame, keypoint_data[order])
            else:
                keypoint_data[order] = track_filter(idx_frame, keypoint_data[order])

        self._prune_filters(idx_frame)
        return track_res, keypoint_data[track_index]

    def _prune_filters(self, idx_frame):
        """Forget filters whose track has not been updated for more than ``max_filter_age`` frames."""
        stale_ids = [tid for tid, flt in self.filter.items()
                     if idx_frame - flt.t_prev > self.max_filter_age]
        for tid in stale_ids:
            del self.filter[tid]

4.姿态估计的滤波结果

滤波前后的坐标变化:

在上面的超参数下,修正量大多只有几个像素,最大约 11 像素(见下方第三个人的 -11.068)。

# 使用滤波器前后关键点位置差值
[[[  -0.041962      1.8142]
  [   -0.13663    -0.95331]
  [     2.6572      1.7637]
  [    0.60077    -0.88373]
  [    0.28513     0.74567]
  [   -0.81116     0.50821]]

 [[     1.7402     0.43713]
  [     4.3165      3.4028]
  [     2.1268    -0.97229]
  [     3.4227      3.2691]
  [     1.6316     -3.0969]
  [     2.6895      2.7548]]

 [[    0.89527    -0.46735]
  [   0.042648    -0.77277]
  [     1.2155     -5.7796]
  [    -1.4052    -0.48135]
  [     2.3723     -11.068]
  [    0.81208    0.051392]]]

可视化结果:

修正量最多约 11 像素,单帧看不太出来,但在视频中关键点会明显更平稳。

如下图红圈中,一帧前后的变化,加了滤波后抖动减小。