Algorithms, Part I - Week 5 - Kd-Trees
来自 Coursera 上普林斯顿大学的 Algorithms, Part I 课程的第五周编程作业Kd-Trees
分析
编写数据类型以表示单位正方形中的一组点 (所有点都具有0和1之间的 x 和 y 坐标)
使用2维树进行范围搜索 (查找查询矩形中包含的所有点) 和最近的邻居搜索 (查找与查询点最近的点)
2维树有许多应用,从将天文物体分类到计算机动画,再到加速神经网络,挖掘数据到图像检索。
Geometric primitives. 定义了图元的坐标表示方法
主要是完成 API 里的 draw(), range(), nearest() 方法
题目给了 Point2D, RectHV 这两种数据类型以使用,要提交的 PointSET.java 使用蛮力运算,KdTree.java 则使用 2d-Trees 来运算,题目就是完成这两个文件
Node data type
根据CheckList所给的结点参考模型进行一些修改
private static class Node {
    // subtree on the left/bottom side of this node
    private Node lb;
    // subtree on the right/top side of this node
    private Node rt;
    // the point stored at this node
    private Point2D p;
    // orientation flag recorded for this node when it is created
    private boolean isVertical;

    // Creates a leaf node holding p with the given orientation flag.
    public Node(Point2D p, boolean isVertical) {
        // lb and rt are left at their default value (null): a new node has no children
        this.isVertical = isVertical;
        this.p = p;
    }
}
Writing KdTree
可以从编写 isEmpty() 和 size() 开始,再写 insert(),再写 contains() 并测试 insert() 是否可用
要注意的是, insert() 和 contains() 的写法要使用 private helper methods(书399页),并增加一个 boolean orientation 作为这些帮助器方法的参数
2d-tree implementation. 编写一个可变数据类型 KdTree, 它使用2维树实现和蛮力运算相同的 api
需要注意的是,垂直分割线段是红色,水平分割线段是蓝色,第一条一定是垂直分割,所以用一个布尔变量 isVertical 来识别线段颜色
Range search. 思想在课件里给出了,主要就是递归搜索,而且要分4种情况(垂直线左,垂直线右,水平线下,水平线上)进行回溯剪枝,这样可以排除不可能的子树
Nearest neighbor search. 从根结点开始递归搜索,更新候选点,同样使用剪枝
答案
PointSET.java
// Brute-force set of 2D points: every query simply scans all stored points.
public class PointSET {
    // Backing set of points; SET.add ignores keys that are already present.
    private final SET<Point2D> set;

    // construct an empty set of points
    public PointSET() {
        set = new SET<Point2D>();
    }

    // is the set empty?
    public boolean isEmpty() {
        return set.isEmpty();
    }

    // number of points in the set
    public int size() {
        return set.size();
    }

    // add the point to the set (if it is not already in the set)
    // throws NullPointerException if p is null
    public void insert(Point2D p) {
        if (p == null) {
            throw new java.lang.NullPointerException("p is null");
        }
        set.add(p);
    }

    // does the set contain point p?
    // throws NullPointerException if p is null
    public boolean contains(Point2D p) {
        if (p == null) {
            throw new java.lang.NullPointerException("p is null");
        }
        return set.contains(p);
    }

    // draw all points to standard draw
    public void draw() {
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        for (Point2D point : set) {
            StdDraw.point(point.x(), point.y());
        }
    }

    // all points that are inside the rectangle (boundary handling is that of RectHV.contains)
    // throws NullPointerException if rect is null
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new java.lang.NullPointerException("rect is null");
        }
        // A fresh local list per call: no query state is kept on the instance.
        ArrayList<Point2D> matches = new ArrayList<Point2D>();
        for (Point2D point : set) {
            if (rect.contains(point)) {
                matches.add(point);
            }
        }
        return matches;
    }

    // a nearest neighbor in the set to point p; null if the set is empty
    // throws NullPointerException if p is null
    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new java.lang.NullPointerException("p is null");
        }
        if (set.isEmpty()) {
            return null;
        }
        // Squared distances are enough for comparison and avoid the square root.
        double bestDistance = Double.POSITIVE_INFINITY;
        Point2D best = null;
        for (Point2D point : set) {
            double distance = p.distanceSquaredTo(point);
            if (distance < bestDistance) {
                bestDistance = distance;
                best = point;
            }
        }
        return best;
    }

    // unit testing of the methods (optional)
    public static void main(String[] args) {
    }
}
KdTree.java
// Set of 2D points stored in a 2d-tree: a binary search tree whose levels alternate
// between splitting on the x-coordinate (vertical line) and the y-coordinate
// (horizontal line). The root splits vertically.
public class KdTree {
    // The unit square; used as the cell of the root when drawing and when searching
    // for a nearest neighbor.
    private static final RectHV RECT = new RectHV(0, 0, 1, 1);

    private static class Node {
        // true if this node splits its cell with a vertical line (compares x)
        private final boolean isVertical;
        private final Point2D p;
        // the left/bottom subtree
        private Node lb;
        // the right/top subtree
        private Node rt;

        public Node(Point2D p, boolean isVertical) {
            this.p = p;
            this.isVertical = isVertical;
        }
    }

    private int size;
    private Node root;

    // construct an empty set of points
    public KdTree() {
        this.root = null;
        this.size = 0;
    }

    // is the set empty?
    public boolean isEmpty() {
        return this.size == 0;
    }

    // number of points in the set
    public int size() {
        return this.size;
    }

    // add the point to the set (if it is not already in the set)
    // throws NullPointerException if p is null
    public void insert(Point2D p) {
        if (p == null) {
            throw new java.lang.NullPointerException("p is null");
        }
        this.root = insert(root, p, true);
    }

    // Inserts p into the subtree rooted at node and returns the (possibly new) subtree root.
    // isVertical is the orientation given to a node created at this position.
    private Node insert(Node node, Point2D p, boolean isVertical) {
        // reached an empty link: the point is not in the set, create a new node
        if (node == null) {
            size++;
            return new Node(p, isVertical);
        }
        // already in, nothing to do
        if (node.p.equals(p)) {
            return node;
        }
        if (goesLeft(node, p)) {
            node.lb = insert(node.lb, p, !node.isVertical);
        } else {
            node.rt = insert(node.rt, p, !node.isVertical);
        }
        return node;
    }

    // Single source of truth for the split rule: strictly smaller coordinates go to the
    // left/bottom subtree, ties and larger coordinates go to the right/top subtree.
    private static boolean goesLeft(Node node, Point2D p) {
        if (node.isVertical) {
            return p.x() < node.p.x();
        }
        return p.y() < node.p.y();
    }

    // does the set contain point p? (a null argument is reported as "not contained")
    public boolean contains(Point2D p) {
        if (p == null) {
            return false;
        }
        // Follow exactly the path insert() would take; falling off the tree means absent.
        Node node = root;
        while (node != null) {
            if (node.p.equals(p)) {
                return true;
            }
            node = goesLeft(node, p) ? node.lb : node.rt;
        }
        return false;
    }

    // draw all points to standard draw
    public void draw() {
        StdDraw.setScale(0, 1);
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius();
        RECT.draw();
        draw(root, RECT);
    }

    // Draws the subtree rooted at node; rect is the cell that node subdivides.
    private void draw(Node node, RectHV rect) {
        if (node == null) {
            return;
        }
        // drawing the point
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        StdDraw.point(node.p.x(), node.p.y());
        // drawing the splitting line, clipped to the node's cell
        StdDraw.setPenRadius();
        if (node.isVertical) {
            StdDraw.setPenColor(StdDraw.RED);
            StdDraw.line(node.p.x(), rect.ymin(), node.p.x(), rect.ymax());
        } else {
            StdDraw.setPenColor(StdDraw.BLUE);
            StdDraw.line(rect.xmin(), node.p.y(), rect.xmax(), node.p.y());
        }
        // recursively draw children
        draw(node.lb, leftRect(rect, node));
        draw(node.rt, rightRect(rect, node));
    }

    // Cell of node's left/bottom child, given node's own cell rect.
    private RectHV leftRect(final RectHV rect, final Node node) {
        if (node.isVertical) {
            return new RectHV(rect.xmin(), rect.ymin(), node.p.x(), rect.ymax());
        } else {
            return new RectHV(rect.xmin(), rect.ymin(), rect.xmax(), node.p.y());
        }
    }

    // Cell of node's right/top child, given node's own cell rect.
    private RectHV rightRect(final RectHV rect, final Node node) {
        if (node.isVertical) {
            return new RectHV(node.p.x(), rect.ymin(), rect.xmax(), rect.ymax());
        } else {
            return new RectHV(rect.xmin(), node.p.y(), rect.xmax(), rect.ymax());
        }
    }

    // all points that are inside the rectangle, returned in Point2D's natural order
    // throws NullPointerException if rect is null
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new java.lang.NullPointerException("rect is null");
        }
        final TreeSet<Point2D> rangeSet = new TreeSet<Point2D>();
        range(root, rect, rangeSet);
        return rangeSet;
    }

    // Collects into rangeSet every point of the subtree rooted at node that lies in rect.
    private void range(final Node node, final RectHV rect, final TreeSet<Point2D> rangeSet) {
        if (node == null) {
            return;
        }
        if (rect.contains(node.p)) {
            rangeSet.add(node.p);
        }
        // Compare the query rectangle with the splitting coordinate of this node.
        final double split;
        final double queryMin;
        final double queryMax;
        if (node.isVertical) {
            split = node.p.x();
            queryMin = rect.xmin();
            queryMax = rect.xmax();
        } else {
            split = node.p.y();
            queryMin = rect.ymin();
            queryMax = rect.ymax();
        }
        // pruning rule: the left/bottom subtree only holds coordinates < split and the
        // right/top subtree only holds coordinates >= split (see goesLeft), so a side
        // the query rectangle cannot reach is skipped, and each side is visited at most once
        if (queryMin < split) {
            range(node.lb, rect, rangeSet);
        }
        if (queryMax >= split) {
            range(node.rt, rect, rangeSet);
        }
    }

    // a nearest neighbor in the set to point p; null if the set is empty
    // throws NullPointerException if p is null
    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new java.lang.NullPointerException("p is null");
        }
        return nearest(root, RECT, p, null);
    }

    // Returns the closer of candidate and the closest point in the subtree rooted at node.
    // rect is the cell of node; candidate is the best point found so far (null if none).
    private Point2D nearest(final Node node, final RectHV rect, final Point2D p, Point2D candidate) {
        if (node == null) {
            return candidate;
        }
        // pruning rule: if the whole cell is no closer than the current best, nothing
        // inside it can improve the answer
        if (candidate != null && rect.distanceSquaredTo(p) >= p.distanceSquaredTo(candidate)) {
            return candidate;
        }
        if (candidate == null || p.distanceSquaredTo(node.p) < p.distanceSquaredTo(candidate)) {
            candidate = node.p;
        }
        // Search the side of the splitting line that contains the query point first:
        // it is more likely to tighten the best distance and so prune the other side.
        if (goesLeft(node, p)) {
            candidate = nearest(node.lb, leftRect(rect, node), p, candidate);
            candidate = nearest(node.rt, rightRect(rect, node), p, candidate);
        } else {
            candidate = nearest(node.rt, rightRect(rect, node), p, candidate);
            candidate = nearest(node.lb, leftRect(rect, node), p, candidate);
        }
        return candidate;
    }
}