参考文章
1.基本原理
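一欧元滤波器(One Euro Filter)本质上是一个截止频率自适应的一阶低通滤波器:对当前观测和上一次的滤波结果做指数加权平均,$\hat{x}_t = \alpha x_t + (1-\alpha)\hat{x}_{t-1}$,其中 $\alpha = \frac{r}{r+1}$,$r = 2\pi f_c T_e$($T_e$ 为两次采样的时间间隔)。截止频率 $f_c = f_{c,\min} + \beta\,|\hat{\dot{x}}|$ 随信号变化速度自适应:运动慢时 $f_c$ 小、平滑强,抑制抖动;运动快时 $f_c$ 大、跟随快,减小延迟。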
2.YOLOv8-pose进行姿态估计
关键点处理脚本(这里主要需要手臂关键点),用于提取指定关键点并在图片上绘制:
import cv2
import torch
import numpy as np
# COCO-17 keypoint index -> keypoint name.
COCO_keypoint_indices = dict(enumerate((
    'nose',
    'left_eye', 'right_eye',
    'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder',
    'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist',
    'left_hip', 'right_hip',
    'left_knee', 'right_knee',
    'left_ankle', 'right_ankle',
)))
# COCO indices of the upper-body (arm) keypoints: shoulders, elbows, wrists.
COCO_DEFAULT_UPPER_BODY_KEYPOINT_INDICES = (5, 6, 7, 8, 9, 10)
# Skeleton edges as (start, end) pairs of positions in the 6-point upper-body subset
# (not COCO indices): left wrist-elbow-shoulder, shoulder-shoulder,
# right shoulder-elbow-wrist.
COCO_DEFAULT_CONNECTIONS = ((4, 2), (2, 0), (0, 1), (1, 3), (3, 5))
# Position in the upper-body subset -> keypoint name.
default_up_body_indices = {
    slot: COCO_keypoint_indices[coco_idx]
    for slot, coco_idx in enumerate(COCO_DEFAULT_UPPER_BODY_KEYPOINT_INDICES)
}
def get_upper_body_keypoint(data, keypoint_indices=COCO_DEFAULT_UPPER_BODY_KEYPOINT_INDICES):
    """Select a subset of keypoints for every detected person.

    Args:
        data: Keypoints indexed as ``data[person, keypoint, coord]`` (numpy array or
            torch tensor), e.g. YOLOv8 ``keypoints.data`` or ``keypoints.xy``.
        keypoint_indices: Keypoint indices to keep, in output order.

    Returns:
        ``data[:, keypoint_indices, :]``. If ``data`` is None or has no persons / no
        keypoints, an empty numpy array of shape ``(0, len(keypoint_indices), C)``,
        where ``C`` is the trailing size of ``data`` (3 when it cannot be determined),
        so xy input yields ``(0, K, 2)`` just like a non-empty result would.

    Raises:
        IndexError: If any index is out of range for ``data.shape[1]``.
    """
    keypoint_indices = tuple(keypoint_indices)
    if data is None or data.shape[0] == 0 or data.shape[1] == 0:
        # Preserve the coordinate dimension of the input instead of assuming 3.
        n_coords = data.shape[2] if data is not None and len(data.shape) >= 3 else 3
        return np.empty((0, len(keypoint_indices), n_coords))
    if not keypoint_indices:
        # Nothing selected: keep the person/coord axes, zero keypoints.
        return data[:, :0, :]
    n_keypoints = data.shape[1]
    # Check both ends so negative indices get the same clear error.
    if any(idx >= n_keypoints or idx < -n_keypoints for idx in keypoint_indices):
        raise IndexError("Keypoint indices are out of bounds for the given data")
    return data[:, keypoint_indices, :]
def image_read(image):
    """Return ``image`` as a numpy array, loading it with OpenCV when given a path.

    Args:
        image: Image file path (``str``) or an already-loaded ``numpy.ndarray``.

    Returns:
        The image loaded by ``cv2.imread``, or ``image`` itself if it is an ndarray.

    Raises:
        ValueError: If ``cv2.imread`` cannot read the path.
        TypeError: If ``image`` is neither a ``str`` nor a ``numpy.ndarray``.
    """
    if isinstance(image, np.ndarray):
        return image
    if not isinstance(image, str):
        raise TypeError("image 参数应为字符串路径或 numpy 数组")
    loaded = cv2.imread(image)
    if loaded is None:
        raise ValueError(f"无法读取图像路径: {image}")
    return loaded
def image_show(image, desc="KeyPoint"):
    """Show ``image`` in an OpenCV window titled ``desc`` until a key is pressed.

    All OpenCV windows are destroyed afterwards.
    """
    window_title = desc
    cv2.imshow(window_title, image)
    cv2.waitKey(0)  # 0 -> wait indefinitely for a key press
    cv2.destroyAllWindows()
def plot_keypoint(image, data, connections=COCO_DEFAULT_CONNECTIONS, point_color=(0, 0, 255), point_radius=4,
                  line_color=(0, 255, 0), line_thickness=2):
    """Draw keypoints and the connections between them on an image.

    A keypoint counts as valid when x > 0 or y > 0; (0, 0) marks a missing keypoint.
    Missing keypoints are not drawn, and a connection is drawn only when BOTH of its
    end points are valid (the same test is used for points and for line ends).

    Args:
        image: Image path or numpy array (passed to ``image_read``). A numpy array is
            drawn on in place and also returned.
        data: Keypoints indexed as ``data[person, keypoint, coord]`` with x, y in the
            first two coords; torch tensors are moved to CPU.
        connections: (start, end) keypoint-index pairs to join; falsy to skip lines.
        point_color: Keypoint colour.
        point_radius: Keypoint radius in pixels.
        line_color: Connection colour.
        line_thickness: Connection thickness in pixels.

    Returns:
        The image with keypoints drawn.
    """
    img = image_read(image)
    data = data.cpu().numpy() if torch.is_tensor(data) else np.asarray(data)

    def is_valid(pt):
        # Undetected keypoints are reported as (0, 0); anything else is drawable.
        return pt[0] > 0 or pt[1] > 0

    for person in data:
        # Connections first so the keypoint dots are drawn on top of the lines.
        if connections:
            for start_idx, end_idx in connections:
                start, end = person[start_idx], person[end_idx]
                if is_valid(start) and is_valid(end):
                    cv2.line(img, (int(start[0]), int(start[1])),
                             (int(end[0]), int(end[1])), line_color, line_thickness)
        for point in person:
            if is_valid(point):
                cv2.circle(img, (int(point[0]), int(point[1])), point_radius, point_color, -1)
    return img
获取一张图片的关键点:
# test_pose.py
from ultralytics.task_bank.pose.predict import PosePredictor
from ultralytics.task_bank.pose.utils import get_upper_body_keypoint, plot_keypoint, image_show
# Placeholder inputs: replace with real file paths.
video_path = r"path/to/*.mp4"  # not used in this snippet
img_path = r"path/to/*.jpg"
# Predictor settings: pose task, predict mode, detections restricted to class 0
# (person in COCO).
overrides = {
    "task": "pose",
    "mode": "predict",
    "model": r'./weights/yolov8m-pose.pt',
    "verbose": False,
    "classes": [0],
}
pose_predictor = PosePredictor(overrides=overrides)
res = pose_predictor(source=img_path)
# Keypoint fields of one image's result (N = number of detected persons):
#   keypoints.data -> [N, 17, 3]: x, y, confidence of the 17 keypoints
#   keypoints.xy   -> [N, 17, 2]: xy pixel coordinates only
#   keypoints.xyn  -> [N, 17, 2]: xy normalized by the image size
data = get_upper_body_keypoint(res[0].keypoints.data)
image_show(plot_keypoint(img_path, data))
3.使用一欧元滤波器
一欧元滤波:
超参数(min_cutoff、beta、d_cutoff)取小一点,抑制抖动的效果才明显。
相对原始实现的修改:将输入改为数组(张量),并处理关键点丢失的问题——YOLOv8 检测不到某个关键点时,会将其坐标置为 (0, 0)。
如果当前帧某个关键点是无效坐标,则滤波输出直接置为 (0, 0);如果上一帧该关键点是无效坐标,则直接使用当前帧的检测结果,不做平滑。
这种修改方式比较粗暴,但个人觉得比较有效:难以检测到的关键点,大概率是置信度低或被遮挡导致的,这样得到的坐标本身就不可靠。
import numpy as np
def smoothing_factor(t_e, cutoff):
    """Return the smoothing weight alpha of a first-order low-pass filter.

    alpha = r / (1 + r) with r = 2 * pi * cutoff * t_e, where ``t_e`` is the elapsed
    time and ``cutoff`` the cutoff frequency (scalar or numpy array).
    """
    omega_dt = 2 * np.pi * cutoff * t_e
    return omega_dt / (1 + omega_dt)
def exponential_smoothing(a, x, x_prev):
    """Blend the new sample ``x`` with the previous estimate ``x_prev``.

    Returns ``a * x + (1 - a) * x_prev``: ``a == 1`` keeps only ``x``,
    ``a == 0`` keeps only ``x_prev``.
    """
    weight_prev = 1 - a
    return a * x + weight_prev * x_prev
class OneEuroFilter:
    """One Euro filter for keypoint arrays, with handling of missing (all-zero) rows.

    The filter runs element-wise on arrays shaped like the first sample (here one
    person's (num_keypoints, 2) xy coordinates). ``t`` may be in any unit (e.g. frame
    index); cutoff frequencies are then in the inverse of that unit.

    A row (last axis) that is entirely zero marks a missing keypoint:
      * if the current sample's row is zero, the output row is forced to zero;
      * if the previous estimate's row is zero, the current sample is passed through
        unfiltered, i.e. the filter restarts from it.
    """

    def __init__(self, t0, x0, dx0=0.0, min_cutoff=0.1, beta=0.0, d_cutoff=0.1):
        """Initialize the filter state.

        Args:
            t0: Timestamp of the initial sample.
            x0: Initial sample; copied into a float32 array.
            dx0: Initial derivative estimate (``None`` -> zeros shaped like ``x0``).
            min_cutoff: Minimum cutoff frequency; lower -> stronger smoothing.
            beta: Speed coefficient; higher -> cutoff rises faster with speed (less lag).
            d_cutoff: Cutoff frequency used to smooth the derivative.
        """
        # The parameters.
        self.min_cutoff = float(min_cutoff)
        self.beta = float(beta)
        self.d_cutoff = float(d_cutoff)
        # Previous values.
        self.x_prev = np.array(x0, dtype=np.float32)
        self.dx_prev = np.zeros_like(self.x_prev) if dx0 is None else np.array(dx0, dtype=np.float32)
        self.t_prev = float(t0)

    def __call__(self, t, x):
        """Filter sample ``x`` observed at time ``t``.

        Args:
            t: Timestamp of ``x``; should be later than the previous timestamp.
            x: New sample with the same shape as ``x0``.

        Returns:
            The filtered sample as a float32 numpy array. If ``t`` is not later than
            the previous timestamp, a copy of the previous estimate is returned and
            the state is left unchanged (a zero/negative step would divide by zero and
            turn the state into NaN/inf for good).
        """
        x = np.asarray(x, dtype=np.float32)
        t_e = t - self.t_prev
        if t_e <= 0:
            return self.x_prev.copy()

        # The filtered derivative of the signal.
        a_d = smoothing_factor(t_e, self.d_cutoff)
        dx = (x - self.x_prev) / t_e
        dx_hat = exponential_smoothing(a_d, dx, self.dx_prev)

        # The filtered signal, with a cutoff that grows with the estimated speed.
        cutoff = self.min_cutoff + self.beta * np.abs(dx_hat)
        a = smoothing_factor(t_e, cutoff)
        x_hat = exponential_smoothing(a, x, self.x_prev)

        # All-zero rows are missing keypoints: smoothing across them is meaningless.
        zero_rows = np.all(x == 0, axis=-1)
        prev_zero_rows = np.all(self.x_prev == 0, axis=-1)
        x_hat[zero_rows] = 0  # missing now -> stay missing
        x_hat[prev_zero_rows] = x[prev_zero_rows]  # missing before -> restart from raw sample
        # A jump to/from (0, 0) is not real motion; don't let it inflate the speed
        # estimate (and thus the cutoff) for the following frames.
        dx_hat[zero_rows | prev_zero_rows] = 0

        # Memorize the previous values.
        self.x_prev = x_hat
        self.dx_prev = dx_hat
        self.t_prev = t
        return x_hat
使用目标跟踪(ByteTrack)区分每一个人,并为每条轨迹单独进行滤波:
from ultralytics.task_bank.byte_tracker_modify import BYTETracker
from ultralytics.task_bank.pose.ops import filter_boxes_ioa
from ultralytics.task_bank.pose.utils import get_upper_body_keypoint
from ultralytics.task_bank.pose.one_euro_filter import OneEuroFilter
from easydict import EasyDict
# ByteTrack settings; wrapped in EasyDict below so the tracker can read them as attributes.
bytetrack_config = dict(
    track_high_thresh=0.5,  # score threshold for the first association
    track_low_thresh=0.1,  # score threshold for the second association
    new_track_thresh=0.6,  # score needed to start a new track from an unmatched detection
    track_buffer=30,  # buffer used to decide when a lost track is removed
    match_thresh=0.8,  # threshold for matching tracks to detections
)
bytetrack_args = EasyDict(bytetrack_config)
class Person:
    """Track people frame by frame and smooth each track's upper-body keypoints.

    Every ByteTrack ``track_id`` gets its own ``OneEuroFilter``. Filters of tracks
    that have not been updated for more than ``filter_ttl`` frames are discarded, so
    ``self.filter`` does not grow without bound on long videos.
    """

    def __init__(self, filter_ttl=None):
        """
        Args:
            filter_ttl: Number of ``idx_frame`` units after which the filter of an
                unseen track is discarded. Defaults to ``bytetrack_args.track_buffer``.
                Discarding too early is harmless: a returning track just restarts its
                filter from the raw keypoints.
        """
        self.bytetrack = BYTETracker(bytetrack_args)
        self.filter = dict()  # {track_id: OneEuroFilter}
        self._last_seen = dict()  # {track_id: idx_frame of the track's last update}
        self.filter_ttl = bytetrack_args.track_buffer if filter_ttl is None else filter_ttl

    def update(self, idx_frame, pose):
        """Track the persons in one frame and filter their upper-body keypoints.

        Args:
            idx_frame: Increasing frame index; used as the filter timestamp.
            pose: One pose result exposing ``boxes`` (xyxy/xywh/conf/cls) and
                ``keypoints.xy``.

        Returns:
            ``(track_res, keypoints)``: the tracker output (column 4 = track id, last
            column = row index into the kept detections) and the smoothed upper-body
            keypoints of the tracked persons, in ``track_res`` row order.
        """
        xyxy = pose.boxes.xyxy.cpu().numpy()
        conf = pose.boxes.conf.view(-1, 1).cpu().numpy()
        # Keep only the detections accepted by filter_boxes_ioa
        # (presumably drops heavily overlapping boxes - TODO confirm).
        valid_index = filter_boxes_ioa(xyxy, conf)
        conf = conf[valid_index]
        xywh = pose.boxes.xywh.cpu().numpy()[valid_index]
        cls = pose.boxes.cls.view(-1, 1).cpu().numpy()[valid_index]
        track_res = self.bytetrack.update(xywh, conf.reshape(-1), cls.reshape(-1))
        track_index = track_res[:, -1].astype(int) if track_res.shape[0] > 0 else []
        keypoint_data = get_upper_body_keypoint(pose.keypoints.xy.cpu().numpy()[valid_index])

        for det in track_res:
            track_id, order = int(det[4]), int(det[-1])
            self._last_seen[track_id] = idx_frame
            if track_id not in self.filter:
                # First sighting: the raw keypoints seed the filter state.
                self.filter[track_id] = OneEuroFilter(idx_frame, keypoint_data[order])
            else:
                keypoint_data[order] = self.filter[track_id](idx_frame, keypoint_data[order])

        self._prune_filters(idx_frame)
        return track_res, keypoint_data[track_index]

    def _prune_filters(self, idx_frame):
        """Drop the filters of tracks not updated within the last ``filter_ttl`` frames."""
        stale_ids = [tid for tid, seen in self._last_seen.items() if idx_frame - seen > self.filter_ttl]
        for tid in stale_ids:
            del self._last_seen[tid]
            self.filter.pop(tid, None)
4.姿态估计的滤波结果
滤波前后的坐标变化:
在上面的超参数下,修正量最多约 11 个像素(见下表)。
# 使用滤波器前后关键点位置差值
[[[ -0.041962 1.8142]
[ -0.13663 -0.95331]
[ 2.6572 1.7637]
[ 0.60077 -0.88373]
[ 0.28513 0.74567]
[ -0.81116 0.50821]]
[[ 1.7402 0.43713]
[ 4.3165 3.4028]
[ 2.1268 -0.97229]
[ 3.4227 3.2691]
[ 1.6316 -3.0969]
[ 2.6895 2.7548]]
[[ 0.89527 -0.46735]
[ 0.042648 -0.77277]
[ 1.2155 -5.7796]
[ -1.4052 -0.48135]
[ 2.3723 -11.068]
[ 0.81208 0.051392]]]
可视化结果:
修正量最多约 11 像素,单帧看不太出来,但在视频中平稳效果会更明显。
如下图红圈中,一帧前后的变化,加了滤波后抖动减小。