Programming Assignment 5: Kd-Trees
Write a data type to represent a set of points in the unit square (all points have x- and y-coordinates between 0 and 1) using a 2d-tree to support efficient range search (find all of the points contained in a query rectangle) and nearest-neighbor search (find a closest point to a query point).
Algorithms 第四周的编程任务，主要目的是编写一个 PointSET 类和一个 KdTree 类，二者实现相同的功能：绘图、查找距离给定点最近的点，以及返回落在指定矩形范围内的点的集合。
方案
PointSET 类只需封装一个 SET 类型，并调用其相应的方法即可。
实现代码
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.Queue;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.SET;
import edu.princeton.cs.algs4.StdDraw;
/**
 * A brute-force set of points in the plane. Points are kept in an algs4 {@link SET}; range and
 * nearest-neighbour queries scan every stored point.
 *
 * @author revc
 */
public class PointSET {
    /** Backing collection; duplicates are collapsed by the set itself. */
    private final SET<Point2D> points = new SET<>();

    /**
     * Construct an empty set of points.
     */
    public PointSET() {
        // The backing set is created by the field initializer.
    }

    /**
     * Reject {@code null} arguments the same way for every public operation.
     *
     * @param argument the value to check
     * @throws IllegalArgumentException if {@code argument} is null
     */
    private static void requireArgument(Object argument) {
        if (argument == null) {
            throw new IllegalArgumentException();
        }
    }

    /**
     * Report whether the set holds no points.
     *
     * @return true if the set is empty, false otherwise.
     */
    public boolean isEmpty() {
        return points.isEmpty();
    }

    /**
     * Report how many distinct points are stored.
     *
     * @return number of points in the set.
     */
    public int size() {
        return points.size();
    }

    /**
     * Add the point to the set; adding a point that is already present has no effect.
     *
     * @param p the point to add
     * @throws IllegalArgumentException if {@code p} is null
     */
    public void insert(Point2D p) {
        requireArgument(p);
        points.add(p);
    }

    /**
     * Test membership of a point.
     *
     * @param p the point to look up
     * @return true if the set contains point p.
     * @throws IllegalArgumentException if {@code p} is null
     */
    public boolean contains(Point2D p) {
        requireArgument(p);
        return points.contains(p);
    }

    /**
     * Draw every stored point in black with pen radius 0.01 to standard draw.
     */
    public void draw() {
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        for (Point2D each : points) {
            each.draw();
        }
    }

    /**
     * Collect the points that lie inside the rectangle or on its boundary, in the set's
     * iteration order.
     *
     * @param rect the query rectangle
     * @return all points that are inside the rectangle (or on the boundary).
     * @throws IllegalArgumentException if {@code rect} is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        requireArgument(rect);
        Queue<Point2D> inside = new Queue<>();
        for (Point2D candidate : points) {
            if (!rect.contains(candidate)) {
                continue;
            }
            inside.enqueue(candidate);
        }
        return inside;
    }

    /**
     * Find a point closest to {@code p} by linear scan. On ties, the first closest point in
     * iteration order wins.
     *
     * @param p the query point
     * @return a nearest neighbor in the set to point p; null if the set is empty
     * @throws IllegalArgumentException if {@code p} is null
     */
    public Point2D nearest(Point2D p) {
        requireArgument(p);
        Point2D champion = null;
        double championDistance = Double.MAX_VALUE;
        for (Point2D candidate : points) {
            double candidateDistance = candidate.distanceSquaredTo(p);
            if (candidateDistance >= championDistance) {
                continue;
            }
            champion = candidate;
            championDistance = candidateDistance;
        }
        return champion;
    }
}
KdTree 类采用二叉树结构。range() 方法的思路如下：
从根结点开始，根结点按 x 坐标划分。如果根结点的 x 坐标位于矩形的左右两条边之间，则根结点左右两侧都可能有点落在矩形内，需要检查两侧；如果根结点的 x 坐标在矩形右侧，则只有左侧可能存在这样的点；反之，如果根结点的 x 坐标在矩形左侧，则只有右侧可能存在这样的点。下一层按 y 坐标划分：如果当前结点的 y 坐标位于矩形的上下两条边之间，则当前结点上下两侧都可能有点落在矩形内；如果当前结点的 y 坐标在矩形上方，则只有下方可能存在这样的点；反之，如果当前结点的 y 坐标在矩形下方，则只有上方可能存在这样的点。如此按 x、y 交替，逐层递归。
nearest() 方法的思路如下：
假设查询点位于根结点划分线的左侧。从根结点开始，计算查询点到根结点划分线（即根结点 x 坐标所在的竖线）的距离，同时计算查询点到根结点本身的距离，并用后者更新当前最短距离。如果当前最短距离小于查询点到划分线的距离，那么划分线右侧不可能存在更近的点，就不需要再查找右侧；否则两侧都需要查找，并且先查找与查询点同侧的子树，再查找另一侧。接下来考虑第二层（按 y 坐标划分）：计算查询点到当前结点 y 划分线的距离，同时计算查询点到当前结点的距离；如果当前最短距离小于到划分线的距离，则只需查看与查询点同侧的点。如此逐层重复。
实现代码
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.Queue;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;
/**
 * A 2d-tree of points in the plane. Nodes at even depth (starting with the root) split on the
 * x-coordinate and nodes at odd depth split on the y-coordinate. At every node, points whose
 * splitting coordinate is strictly smaller than the node's go to the left subtree; all other
 * points go to the right subtree.
 *
 * @author revc
 */
public class KdTree {
    /** Parity value of the levels that split on the x-coordinate; the root is at this parity. */
    private static final boolean EVEN = false;
    private Node root;
    private int size;

    /**
     * Initialize an empty KdTree.
     */
    public KdTree() {
    }

    /**
     * Is this KdTree empty?
     *
     * @return true if this KdTree is empty, false otherwise.
     */
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * Return the number of nodes (distinct points) in the KdTree.
     *
     * @return the number of nodes in the KdTree.
     */
    public int size() {
        return size;
    }

    /**
     * Insert a point into the KdTree. Inserting a point that is already present has no effect.
     *
     * @param p the point to insert
     * @throws IllegalArgumentException if {@code p} is null
     */
    public void insert(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        root = insert(root, p, EVEN);
    }

    /** Insert {@code point} into the subtree rooted at {@code node}; returns the new subtree root. */
    private Node insert(Node node, Point2D point, boolean parity) {
        if (node == null) {
            ++size;
            return new Node(point);
        }
        if (point.equals(node.point)) {
            return node;
        }
        if (coordinate(point, parity) < coordinate(node.point, parity)) {
            node.left = insert(node.left, point, !parity);
        } else {
            node.right = insert(node.right, point, !parity);
        }
        return node;
    }

    /**
     * Check whether the KdTree contains {@code p}.
     *
     * @param p the point to look up
     * @return true if the KdTree contains {@code p}, false otherwise.
     * @throws IllegalArgumentException if {@code p} is null
     */
    public boolean contains(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        return contains(root, p, EVEN);
    }

    /** Search the subtree rooted at {@code node}, following the same left/right rule as insert. */
    private boolean contains(Node node, Point2D point, boolean parity) {
        if (node == null) {
            return false;
        }
        if (point.equals(node.point)) {
            return true;
        }
        Node next = coordinate(point, parity) < coordinate(node.point, parity) ? node.left : node.right;
        return contains(next, point, !parity);
    }

    /**
     * Draw the points in black, then each node's splitting segment: red (vertical) for
     * x-splitting nodes and blue (horizontal) for y-splitting nodes, clipped to the node's region.
     */
    public void draw() {
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        drawPoint(root);
        StdDraw.setPenRadius();
        drawLine(root, new RectHV(0, 0, 1, 1), EVEN);
    }

    /** Draw the splitting segment of {@code node} inside {@code rect}, then recurse into both halves. */
    private void drawLine(Node node, RectHV rect, boolean parity) {
        if (node == null) {
            return;
        }
        if (parity == EVEN) {
            double x = node.point.x();
            StdDraw.setPenColor(StdDraw.RED);
            new Point2D(x, rect.ymin()).drawTo(new Point2D(x, rect.ymax()));
            drawLine(node.left, new RectHV(rect.xmin(), rect.ymin(), x, rect.ymax()), !parity);
            drawLine(node.right, new RectHV(x, rect.ymin(), rect.xmax(), rect.ymax()), !parity);
        } else {
            double y = node.point.y();
            StdDraw.setPenColor(StdDraw.BLUE);
            new Point2D(rect.xmin(), y).drawTo(new Point2D(rect.xmax(), y));
            drawLine(node.left, new RectHV(rect.xmin(), rect.ymin(), rect.xmax(), y), !parity);
            drawLine(node.right, new RectHV(rect.xmin(), y, rect.xmax(), rect.ymax()), !parity);
        }
    }

    /** Draw every point of the subtree in pre-order. */
    private void drawPoint(Node node) {
        if (node == null) {
            return;
        }
        node.point.draw();
        drawPoint(node.left);
        drawPoint(node.right);
    }

    /**
     * Return all the points inside the rectangle or on its boundary.
     *
     * @param rect the query rectangle
     * @return all the points within the range.
     * @throws IllegalArgumentException if {@code rect} is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new IllegalArgumentException();
        }
        Queue<Point2D> queue = new Queue<>();
        range(root, queue, rect, EVEN);
        return queue;
    }

    /** Enqueue points of the subtree that lie in {@code rect}, skipping subtrees that cannot intersect it. */
    private void range(Node node, Queue<Point2D> queue, RectHV rect, boolean parity) {
        if (node == null) {
            return;
        }
        if (rect.contains(node.point)) {
            queue.enqueue(node.point);
        }
        double key = coordinate(node.point, parity);
        double lo = parity == EVEN ? rect.xmin() : rect.ymin();
        double hi = parity == EVEN ? rect.xmax() : rect.ymax();
        // Left subtree keys are < key, so it can only intersect the rectangle when key > lo.
        if (key > lo) {
            range(node.left, queue, rect, !parity);
        }
        // Right subtree keys are >= key, so it can only intersect the rectangle when key <= hi.
        if (key <= hi) {
            range(node.right, queue, rect, !parity);
        }
    }

    /**
     * Return a point nearest to {@code p}.
     *
     * @param p the query point
     * @return the nearest point, or null if the KdTree is empty.
     * @throws IllegalArgumentException if {@code p} is null
     */
    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        Candidate best = new Candidate();
        nearest(root, p, best, EVEN);
        return best.point;
    }

    /**
     * Update {@code best} with points from the subtree rooted at {@code node}. The side of the
     * splitting line that holds the query is searched first; the other side is searched only if
     * the (possibly improved) best distance still reaches the splitting line.
     */
    private void nearest(Node node, Point2D query, Candidate best, boolean parity) {
        if (node == null) {
            return;
        }
        double diff = coordinate(query, parity) - coordinate(node.point, parity);
        double lineDistance = diff * diff;
        // The node itself is at least lineDistance away, so it can only win when within the bound.
        if (lineDistance <= best.distance) {
            double distance = query.distanceSquaredTo(node.point);
            if (best.point == null || distance < best.distance) {
                best.point = node.point;
                best.distance = distance;
            }
        }
        Node near = diff < 0 ? node.left : node.right;
        Node far = diff < 0 ? node.right : node.left;
        nearest(near, query, best, !parity);
        // Re-check after the near side: it may have found a point closer than the splitting line.
        if (lineDistance <= best.distance) {
            nearest(far, query, best, !parity);
        }
    }

    /** Return the coordinate of {@code point} that a node at the given parity splits on. */
    private static double coordinate(Point2D point, boolean parity) {
        return parity == EVEN ? point.x() : point.y();
    }

    /** Mutable best-so-far result threaded through the nearest-neighbour search. */
    private static final class Candidate {
        private Point2D point;
        private double distance = Double.POSITIVE_INFINITY;
    }

    /** A tree node; static so it does not retain a reference to the enclosing KdTree. */
    private static final class Node {
        private final Point2D point;
        private Node left;
        private Node right;

        Node(Point2D point) {
            this.point = point;
        }
    }
}