Coursera Algorithms: Week 5 Programming Assignment

Programming Assignment 5: Kd-Trees

Write a data type to represent a set of points in the unit square (all points have x- and y-coordinates between 0 and 1) using a 2d-tree to support efficient range search (find all of the points contained in a query rectangle) and nearest-neighbor search (find a closest point to a query point).

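This is the week 5 programming assignment of the Algorithms course. The goal is to write a PointSET class and a KdTree class that provide the same functionality: drawing the points, finding the point nearest to a given query point, and returning the set of points that lie inside a given range.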

Approach

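The PointSET class only needs to wrap a set type (SET<Point2D>) and call the corresponding methods; range() and nearest() simply iterate over every point.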
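The brute-force idea can be sketched without the algs4 library. The following is a minimal illustration only: the class name BruteForceSketch, the nested Pt type, and the use of java.util.TreeSet in place of SET<Point2D> are all assumptions made for this example, not part of the assignment API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

// Hypothetical stand-in for the brute-force approach; not the algs4-based class.
final class BruteForceSketch {
    // Minimal immutable point, ordered by y then x so it can live in a TreeSet.
    static final class Pt implements Comparable<Pt> {
        final double x;
        final double y;

        Pt(double x, double y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public int compareTo(Pt o) {
            int c = Double.compare(y, o.y);
            return c != 0 ? c : Double.compare(x, o.x);
        }

        double distSq(Pt o) {
            double dx = x - o.x;
            double dy = y - o.y;
            return dx * dx + dy * dy;
        }
    }

    private final TreeSet<Pt> points = new TreeSet<>();

    // The set ignores a point that is already present.
    void insert(Pt p) {
        points.add(p);
    }

    int size() {
        return points.size();
    }

    // Range search: test every point against the rectangle (boundary included).
    List<Pt> range(double xmin, double ymin, double xmax, double ymax) {
        List<Pt> hits = new ArrayList<>();
        for (Pt p : points) {
            if (p.x >= xmin && p.x <= xmax && p.y >= ymin && p.y <= ymax) {
                hits.add(p);
            }
        }
        return hits;
    }

    // Nearest neighbor: scan every point and keep the smallest squared distance.
    Pt nearest(Pt q) {
        Pt best = null;
        double bestSq = Double.POSITIVE_INFINITY;
        for (Pt p : points) {
            double d = p.distSq(q);
            if (d < bestSq) {
                bestSq = d;
                best = p;
            }
        }
        return best;
    }
}
```

Both queries take time proportional to the number of points, which is the cost the 2d-tree is meant to reduce.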

Implementation

import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.Queue;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.SET;
import edu.princeton.cs.algs4.StdDraw;

/**
 * @author revc
 */
public class PointSET {
    private SET<Point2D> points;

    /**
     * Construct an empty set of points.
     */
    public PointSET() {
        points = new SET<>();
    }

    /**
     * Is the set empty?
     *
     * @return true if the set is empty, false otherwise.
     */
    public boolean isEmpty() {
        return points.isEmpty();
    }

    /**
     *  Number of points in the set
     *
     * @return number of points in the set.
     */
    public int size() {
        return points.size();
    }

    /**
     * Add the point to the set (if it is not already in the set).
     */
    public void insert(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        points.add(p);
    }

    /**
     * Does the set contain point p?
     *
     * @return true if the set contains point p.
     */
    public boolean contains(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        return points.contains(p);
    }

    /**
     * Draw all points to standard draw.
     */
    public void draw() {
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        for (Point2D point : points) {
            point.draw();
        }
    }

    /**
     * Return all points that are inside the rectangle (or on the boundary).
     *
     * @return all points that are inside the rectangle (or on the boundary).
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new IllegalArgumentException();
        }

        Queue<Point2D> queue = new Queue<>();

        for (Point2D point : points) {
            if (rect.contains(point)) {
                queue.enqueue(point);
            }
        }
        return queue;
    }

    /**
     * Return a nearest neighbor in the set to point p; null if the set is empty
     *
     * @return a nearest neighbor in the set to point p; null if the set is empty
     */
    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }

        double currentDistance;
        double nearestDistance = Double.MAX_VALUE;
        Point2D nearestPoint = null;

        for (Point2D point : points) {
            currentDistance = point.distanceSquaredTo(p);
            if (currentDistance < nearestDistance) {
                nearestDistance = currentDistance;
                nearestPoint = point;
            }
        }
        return nearestPoint;
    }
}

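KdTree uses a binary tree structure. The reasoning behind the range() method is as follows:

    Start at the root. If the root's x-coordinate lies between the left and right edges of the rectangle,
    points inside the rectangle may exist on both sides of the root, so check both the left and the right subtree.
    If the root's x-coordinate lies to the right of the rectangle, only the left side can contain matching points.
    Conversely, if the root's x-coordinate lies to the left of the rectangle, only the right side can contain matching points.

    The next level splits by y. If the current node's y-coordinate lies between the bottom and top edges of the rectangle,
    points inside the rectangle may exist both below and above the node, so check both subtrees.
    If the current node's y-coordinate lies above the rectangle, only the lower side can contain matching points.
    Conversely, if the current node's y-coordinate lies below the rectangle, only the upper side can contain matching points.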
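The pruning rule above can be sketched on a tiny self-contained 2d-tree. This is only an illustration under stated assumptions: the class RangeSketch, its Node type, the visited counter, and the plain double coordinates are hypothetical names introduced here, and the two if-tests are a compact restatement of the three cases (left only, right only, both).

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical minimal 2d-tree used only to illustrate range-search pruning.
final class RangeSketch {
    static final class Node {
        final double x;
        final double y;
        Node left;
        Node right;

        Node(double x, double y) {
            this.x = x;
            this.y = y;
        }
    }

    Node root;
    int visited; // number of nodes examined by the last range() call

    void insert(double x, double y) {
        root = insert(root, x, y, true);
    }

    // Smaller key goes left, equal or larger goes right; the key alternates x, y, x, ...
    private Node insert(Node n, double x, double y, boolean byX) {
        if (n == null) return new Node(x, y);
        if (n.x == x && n.y == y) return n;
        double cmp = byX ? x - n.x : y - n.y;
        if (cmp < 0) n.left = insert(n.left, x, y, !byX);
        else n.right = insert(n.right, x, y, !byX);
        return n;
    }

    List<double[]> range(double xmin, double ymin, double xmax, double ymax) {
        visited = 0;
        List<double[]> out = new ArrayList<>();
        range(root, xmin, ymin, xmax, ymax, true, out);
        return out;
    }

    private void range(Node n, double xmin, double ymin, double xmax, double ymax,
                       boolean byX, List<double[]> out) {
        if (n == null) return;
        visited++;
        if (n.x >= xmin && n.x <= xmax && n.y >= ymin && n.y <= ymax) {
            out.add(new double[] {n.x, n.y});
        }
        double key = byX ? n.x : n.y;
        double lo = byX ? xmin : ymin;
        double hi = byX ? xmax : ymax;
        // The left subtree holds keys < key: visit it only if the rectangle starts below the key.
        if (lo < key) range(n.left, xmin, ymin, xmax, ymax, !byX, out);
        // The right subtree holds keys >= key: visit it only if the rectangle reaches the key.
        if (hi >= key) range(n.right, xmin, ymin, xmax, ymax, !byX, out);
    }
}
```

When the splitting line misses the rectangle, exactly one of the two tests fails, so an entire subtree is skipped without being examined.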

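The reasoning behind the nearest() method is as follows:

    Assume the query point lies to the left of the root.
    Start at the root. Compute the distance from the query point to the root's vertical splitting line (the difference in x), and the distance from the query point to the root itself. If the shortest distance known so far is no greater than the distance to the splitting line, there is no need to search the points to the right of the line; record the current distance as the shortest distance. Otherwise both sides have to be searched: search the side that contains the query point first, and record the shortest distance found there.

    Next, consider the second level. Compute the distance from the query point to the current node's horizontal splitting line (the difference in y), and the distance from the query point to the current node itself. If the shortest distance so far is no greater than the distance to the splitting line, only the side that contains the query point needs to be examined.
    Repeat this at every level, alternating between x and y.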
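The search order described above can be sketched on a tiny self-contained 2d-tree. The names NearestSketch, Node, best, bestSq, and visited are hypothetical and exist only for this example. Note one assumption that goes slightly beyond the description: this sketch re-checks the distance to the splitting line after the near side has been searched, and skips the far side if the best distance has already dropped below it.

```java
// Hypothetical minimal 2d-tree used only to illustrate nearest-neighbor pruning.
final class NearestSketch {
    static final class Node {
        final double x;
        final double y;
        Node left;
        Node right;

        Node(double x, double y) {
            this.x = x;
            this.y = y;
        }
    }

    Node root;
    Node best;      // closest node found so far
    double bestSq;  // its squared distance to the query point
    int visited;    // number of nodes examined by the last nearest() call

    void insert(double x, double y) {
        root = insert(root, x, y, true);
    }

    // Smaller key goes left, equal or larger goes right; the key alternates x, y, x, ...
    private Node insert(Node n, double x, double y, boolean byX) {
        if (n == null) return new Node(x, y);
        if (n.x == x && n.y == y) return n;
        double cmp = byX ? x - n.x : y - n.y;
        if (cmp < 0) n.left = insert(n.left, x, y, !byX);
        else n.right = insert(n.right, x, y, !byX);
        return n;
    }

    // Returns the closest node, or null if the tree is empty.
    Node nearest(double qx, double qy) {
        best = null;
        bestSq = Double.POSITIVE_INFINITY;
        visited = 0;
        search(root, qx, qy, true);
        return best;
    }

    private void search(Node n, double qx, double qy, boolean byX) {
        if (n == null) return;
        visited++;
        double dx = qx - n.x;
        double dy = qy - n.y;
        double dSq = dx * dx + dy * dy;
        if (dSq < bestSq) {
            bestSq = dSq;
            best = n;
        }
        // Signed distance from the query point to this node's splitting line.
        double diff = byX ? dx : dy;
        Node near = diff < 0 ? n.left : n.right;
        Node far = diff < 0 ? n.right : n.left;
        // Search the side containing the query point first.
        search(near, qx, qy, !byX);
        // The far side can only help if the splitting line is closer than the best so far.
        if (diff * diff < bestSq) search(far, qx, qy, !byX);
    }
}
```

Squared distances are compared throughout, so no square root is needed.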

Implementation

import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.Queue;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;

/**
 * @author revc
 */
public class KdTree {
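    // Parity of the tree level: EVEN levels (the root included) split by x,
    // odd levels (!EVEN) split by y.
    private static final boolean EVEN = false;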

    private Node root;
    private int size;

    /**
     * Initialize a KdTree.
     */
    public KdTree() {

    }

    /**
     * Is this KdTree empty?
     *
     * @return true if this KdTree is empty, false otherwise.
     */
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * Return the number of nodes in the KdTree.
     *
     * @return the number of nodes in the KdTree.
     */
    public int size() {
        return size;
    }

    /**
     * Insert a point into KdTree.
     */
    public void insert(Point2D p) {
        if (p == null) {
            throw new java.lang.IllegalArgumentException();
        }

        root = insert(root, p, EVEN);
    }

    private Node insert(Node currentNode, Point2D point, boolean parity) {
        if (currentNode == null) {
            ++size;
            return new Node(point);
        }

        if (parity == EVEN) {
            if (point.equals(currentNode.point)) {
                return currentNode;
            }

            double cmp = point.x() - currentNode.point.x();
            if (cmp < 0) {
                currentNode.left = insert(currentNode.left, point, !parity);
            } else {
                currentNode.right = insert(currentNode.right, point, !parity);
            }

        } else {

            if (point.equals(currentNode.point)) {
                return currentNode;
            }

            double cmp = point.y() - currentNode.point.y();
            if (cmp < 0) {
                currentNode.left = insert(currentNode.left, point, !parity);
            } else {
                currentNode.right = insert(currentNode.right, point, !parity);
            }
        }
        return currentNode;
    }

    /**
     * Check whether the KdTree contains {@code p}
     *
     * @return true if the KdTree contains {@code p}, false otherwise.
     */
    public boolean contains(Point2D p) {
        if (p == null) {
            throw new java.lang.IllegalArgumentException();
        }
        return contains(root, p, EVEN);
    }

    private boolean contains(Node currentNode, Point2D point, boolean parity) {
        if (currentNode == null) {
            return false;
        }

        if (currentNode.point.compareTo(point) == 0) {
            return true;
        }

        if (parity == EVEN) {

            double cmp = point.x() - currentNode.point.x();
            if (cmp < 0) {
                return contains(currentNode.left, point, !parity);
            } else {
                return contains(currentNode.right, point, !parity);
            }

        } else {

            double cmp = point.y() - currentNode.point.y();
            if (cmp < 0) {
                return contains(currentNode.left, point, !parity);
            } else {
                return contains(currentNode.right, point, !parity);
            }

        }
    }

    /**
     * Draw all points and their splitting lines to standard draw.
     */
    public void draw() {
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        drawPoint(root);

        RectHV rectHV = new RectHV(0, 0, 1, 1);
        StdDraw.setPenRadius();
        drawLine(root, rectHV, EVEN);
    }

    private void drawLine(Node currentNode, RectHV rect, boolean parity) {
        if (currentNode == null) {
            return;
        }

        if (parity == EVEN) {

            StdDraw.setPenColor(StdDraw.RED);
            Point2D p1 = new Point2D(currentNode.point.x(), rect.ymin());
            Point2D p2 = new Point2D(currentNode.point.x(), rect.ymax());

            p1.drawTo(p2);

            drawLine(currentNode.left,
                    new RectHV(rect.xmin(), rect.ymin(), currentNode.point.x(), rect.ymax()), !parity);

            drawLine(currentNode.right,
                    new RectHV(currentNode.point.x(), rect.ymin(), rect.xmax(), rect.ymax()), !parity);
        } else {

            StdDraw.setPenColor(StdDraw.BLUE);
            Point2D p1 = new Point2D(rect.xmin(), currentNode.point.y());
            Point2D p2 = new Point2D(rect.xmax(), currentNode.point.y());

            p1.drawTo(p2);

            drawLine(currentNode.left,
                    new RectHV(rect.xmin(), rect.ymin(), rect.xmax(), currentNode.point.y()), !parity);
            drawLine(currentNode.right,
                    new RectHV(rect.xmin(), currentNode.point.y(), rect.xmax(), rect.ymax()), !parity);
        }
    }

    private void drawPoint(Node currentNode) {
        if (currentNode == null) {
            return;
        }
        currentNode.point.draw();
        drawPoint(currentNode.left);
        drawPoint(currentNode.right);
    }

    /**
     * Return all the points within the range.
     *
     * @return all the points within the range.
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new java.lang.IllegalArgumentException();
        }
        Queue<Point2D> queue = new Queue<>();
        range(root, queue, rect, EVEN);
        return queue;
    }

    private void range(Node currentNode, Queue<Point2D> queue, RectHV rect, boolean parity) {
        if (currentNode == null) {
            return;
        }
        if (rect.contains(currentNode.point)) {
            queue.enqueue(currentNode.point);
        }

        //by x
        if (parity == EVEN) {
            //only left
            if (currentNode.point.x() > rect.xmax()) {
                range(currentNode.left, queue, rect, !parity);
            }
            //only right
            else if (currentNode.point.x() <= rect.xmin()) {
                range(currentNode.right, queue, rect, !parity);
            }
            //both
            else {
                range(currentNode.left, queue, rect, !parity);
                range(currentNode.right, queue, rect, !parity);
            }
        }

        //by y
        else {
            //only below
            if (currentNode.point.y() > rect.ymax()) {
                range(currentNode.left, queue, rect, !parity);
            }
            //only top
            else if (currentNode.point.y() <= rect.ymin()) {
                range(currentNode.right, queue, rect, !parity);
            }
            //both
            else {
                range(currentNode.left, queue, rect, !parity);
                range(currentNode.right, queue, rect, !parity);
            }
        }
    }

    /**
     * Return the nearest point.
     *
     * @return the nearest point.
     */
    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new java.lang.IllegalArgumentException();
        }
        return nearest(root, p, Double.MAX_VALUE, null, EVEN);
    }

    private Point2D nearest(Node currentNode, Point2D p, double nearestDistance, Point2D nearestPoint, boolean parity) {
        if (currentNode == null) {
            return nearestPoint;
        }
        if (parity == EVEN) {
            double xDiff = p.x() - currentNode.point.x();

            // If the squared x-distance to the splitting line already exceeds the
            // nearest squared distance, skip computing the point-to-point distance.
            if (nearestDistance >= xDiff * xDiff) {
                double currDis = p.distanceSquaredTo(currentNode.point);
                //check and update
                if (currDis < nearestDistance) {
                    nearestDistance = currDis;
                    nearestPoint = currentNode.point;
                }
            }

            if (nearestDistance < xDiff * xDiff) {
                //only left
                if (xDiff < 0) {
                    return nearest(currentNode.left, p, nearestDistance, nearestPoint, !parity);
                }
                //only right
                else {
                    return nearest(currentNode.right, p, nearestDistance, nearestPoint, !parity);
                }
            } else {
                if (xDiff < 0) {
                    nearestPoint = nearest(currentNode.left, p, nearestDistance, nearestPoint, !parity);
                    nearestDistance = nearestPoint.distanceSquaredTo(p);
                    return nearest(currentNode.right, p, nearestDistance, nearestPoint, !parity);
                } else {
                    nearestPoint = nearest(currentNode.right, p, nearestDistance, nearestPoint, !parity);
                    nearestDistance = nearestPoint.distanceSquaredTo(p);
                    return nearest(currentNode.left, p, nearestDistance, nearestPoint, !parity);

                }
            }
        } else {
            double yDiff = p.y() - currentNode.point.y();
            if (nearestDistance >= yDiff * yDiff) {
                double currDis = p.distanceSquaredTo(currentNode.point);
                //check and update
                if (currDis < nearestDistance) {
                    nearestDistance = currDis;
                    nearestPoint = currentNode.point;
                }
            }

            if (nearestDistance < yDiff * yDiff) {
                //only below
                if (yDiff < 0) {
                    return nearest(currentNode.left, p, nearestDistance, nearestPoint, !parity);
                }
                //only top
                else {
                    return nearest(currentNode.right, p, nearestDistance, nearestPoint, !parity);
                }
            } else {
                if (yDiff < 0) {
                    nearestPoint = nearest(currentNode.left, p, nearestDistance, nearestPoint, !parity);
                    nearestDistance = nearestPoint.distanceSquaredTo(p);
                    return nearest(currentNode.right, p, nearestDistance, nearestPoint, !parity);
                } else {
                    nearestPoint = nearest(currentNode.right, p, nearestDistance, nearestPoint, !parity);
                    nearestDistance = nearestPoint.distanceSquaredTo(p);
                    return nearest(currentNode.left, p, nearestDistance, nearestPoint, !parity);

                }
            }
        }
    }

    private static class Node {
        private final Point2D point;
        private Node left;
        private Node right;

        public Node(Point2D point) {
            this.point = point;
        }
    }
}

Reposted from www.cnblogs.com/revc/p/9239294.html