KdTree原理及实现

Kd-tree是一种基于二叉搜索树的数据结构,用于高效地查找多维空间中的数据点。Kd-tree把多维空间划分成一系列的超矩形(即“单元格”),并将数据点存储在这些单元格中。每个单元格根据其所属的维度选择一个坐标轴,将其划分为两个子单元格,并将数据点存储在相应的子单元格中。

以下是Kd-tree的原理图示例,在二维平面上展示了一个Kd-tree的构建过程:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Necvbggz-1685103426668)(null)]

在这个例子中,我们需要构建一个包含6个数据点的Kd-tree。首先,我们将数据点集合放在一个最小矩形框中(即包围盒),并选择一个坐标轴来划分这个矩形。在本例中,我们选择了x轴,并将数据点按照x轴坐标进行排序。我们选择了x坐标在中间的一个点(即第4个点,也就是坐标为(6, 7)的点)作为根节点,并将其所在的矩形划分为两个子矩形:一个左子矩形(包含前3个数据点)和一个右子矩形(包含后3个数据点)。接下来,我们选择y轴来划分左右子矩形。在左子矩形中,我们选择y坐标在中间的点(即第2个点,也就是坐标为(2, 3)的点)作为左子节点,并将左子矩形划分为两个子矩形:一个左子矩形(包含第1个数据点)和一个右子矩形(包含第2个数据点)。在右子矩形中,我们选择y坐标在中间的点(即第5个点,也就是坐标为(8, 2)的点)作为右子节点,并将右子矩形划分为两个子矩形:一个左子矩形(包含第4个数据点)和一个右子矩形(包含第6个数据点)。最终,我们得到了一个包含6个数据点的Kd-tree。

下面是一个简单的KD-Tree Python实现:

import numpy as np

class Node:
    def __init__(self, point, d, left=None, right=None):
        self.point = point
        self.d = d
        self.left = left
        self.right = right

class KDTree:
    def __init__(self, points):
        def buildTree(points, d):
            if len(points) == 0:
                return None
            points = sorted(points, key=lambda x: x[d])
            mid = len(points) // 2
            return Node(
                point=points[mid],
                d=d,
                left=buildTree(points[:mid], (d + 1) % len(points[0])),
                right=buildTree(points[mid + 1:], (d + 1) % len(points[0]))
            )

        self.root = buildTree(points, 0)

    def search(self, target):
        def searchHelper(node, target, best):
            if node is None:
                return best, np.inf
            if node.point == target:
                return node.point, 0
            dist = np.linalg.norm(np.array(node.point) - np.array(target))
            if dist < best[1]:
                best = node.point, dist
            diff = target[node.d] - node.point[node.d]
            if diff <= best[1]:
                best, _ = searchHelper(node.left if diff < 0 else node.right, target, best)
            return searchHelper(node.right if diff < 0 else node.left, target, best)

        return searchHelper(self.root, target, (None, np.inf))[0]

这里的Node类代表了KD-Tree中的一个节点,每个节点都具有一个点和一个分割维度。KDTree类则包含了根节点。其中的buildTree方法是递归的建树函数,它根据分割维度将点集排序,然后选择中位数作为根节点,递归构建左右子树。search方法则是搜索函数,采用递归的方式寻找目标点的最近邻点。在寻找过程中,首先计算目标点到当前节点的距离,如果小于已知的最小距离,则更新最小距离和最近邻点。然后按照当前节点的分割维度将目标点与当前节点的点进行比较,选择相应的子树继续递归搜索,直到遇到空节点为止。最后返回最近邻点即可。

KD树是一种用于快速查找k维空间中最近邻点的数据结构。它是对空间的一种划分,每一次划分会将空间划分为两个部分,同时将数据点也分为两部分,以此递归下去直至划分完成。在实际应用中,KD树可以用于空间中最近点查询、range查询和k近邻查询。

下面是一个简单的KD树实现:

首先,定义一个Node结构体表示一个节点:

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

struct Node {
    vector<double> point; // 点的坐标
    int axis; // 划分轴
    Node *left; // 左子树
    Node *right; // 右子树
    
    Node(vector<double> p, int a) : point(p), axis(a), left(nullptr), right(nullptr) {}
};

接着,定义一个KDTree结构体表示整棵KD树:

struct KDTree {
    Node *root;
    
    KDTree() : root(nullptr) {}
    
    // 插入一个点
    void insert(vector<double> point);
    
    // 最近邻查询
    double nearestNeighbor(vector<double> point);
    
    // 内部函数
    Node *insertNode(Node *node, vector<double> point, int depth);
    double distance(vector<double> p1, vector<double> p2);
    double distanceSquared(vector<double> p1, vector<double> p2);
    double nearestNeighbor(Node *node, vector<double> point, double best);
};

void KDTree::insert(vector<double> point) {
    root = insertNode(root, point, 0);
}

Node *KDTree::insertNode(Node *node, vector<double> point, int depth) {
    if (!node) {
        return new Node(point, depth % point.size());
    }
    if (point[node->axis] < node->point[node->axis]) {
        node->left = insertNode(node->left, point, depth + 1);
    } else {
        node->right = insertNode(node->right, point, depth + 1);
    }
    return node;
}

double KDTree::nearestNeighbor(vector<double> point) {
    return nearestNeighbor(root, point, DBL_MAX);
}

double KDTree::distanceSquared(vector<double> p1, vector<double> p2) {
    double distance = 0;
    for (int i = 0; i < p1.size(); ++i) {
        distance += (p1[i] - p2[i]) * (p1[i] - p2[i]);
    }
    return distance;
}

double KDTree::distance(vector<double> p1, vector<double> p2) {
    return sqrt(distanceSquared(p1, p2));
}

double KDTree::nearestNeighbor(Node *node, vector<double> point, double best) {
    if (!node) {
        return best;
    }
    double dist = distance(node->point, point);
    if (dist < best) {
        best = dist;
    }
    if (point[node->axis] < node->point[node->axis]) {
        best = nearestNeighbor(node->left, point, best);
        if (point[node->axis] + best > node->point[node->axis]) {
            best = nearestNeighbor(node->right, point, best);
        }
    } else {
        best = nearestNeighbor(node->right, point, best);
        if (point[node->axis] - best < node->point[node->axis]) {
            best = nearestNeighbor(node->left, point, best);
        }
    }
    return best;
}

下面是一个简单的测试:

int main() {
    KDTree kdTree;
    kdTree.insert({-3, 2});
    kdTree.insert({-2, 5});
    kdTree.insert({0, 0});
    kdTree.insert({4, 3});
    kdTree.insert({1, 1});
    
    cout << kdTree.nearestNeighbor({2, 3}) << endl; // 输出2.23607
    return 0;
}

在这个例子中,我们插入了5个二维点,然后查找离点(2,3)最近的点,输出结果为2.23607。

猜你喜欢

转载自blog.csdn.net/qq_39506862/article/details/127978420