diff --git a/GeoSpatialTools/__init__.py b/GeoSpatialTools/__init__.py
index 4a2d1ad900512da6b40cadb362fe6f4f92ce34d0..11e1e9388c3b2a90fa76a7649632b55e77c63ef1 100644
--- a/GeoSpatialTools/__init__.py
+++ b/GeoSpatialTools/__init__.py
@@ -1,4 +1,14 @@
 from .neighbours import find_nearest
 from .distance_metrics import haversine
+from .quadtree import Ellipse, QuadTree, Record, Rectangle
+from .kdtree import KDTree
 
-__all__ = ["find_nearest", "haversine"]
+__all__ = [
+    "Ellipse",
+    "KDTree",
+    "QuadTree",
+    "Record",
+    "Rectangle",
+    "find_nearest",
+    "haversine",
+]
diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py
new file mode 100644
--- /dev/null
+++ b/GeoSpatialTools/kdtree.py
@@ -0,0 +1,147 @@
+"""
+An implementation of KDTree using Haversine Distance for GeoSpatial analysis.
+Useful tool for quickly searching for nearest neighbours.
+"""
+
+from numpy import inf
+
+from .quadtree import Record
+
+
+class KDTree:
+    """
+    A Haversine distance implementation of a balanced KDTree.
+
+    This implementation is a _balanced_ KDTree, each leaf node should have the
+    same number of points (or differ by 1 depending on the number of points
+    the KDTree is initialised with).
+
+    The KDTree partitions in each of the lon and lat dimensions alternately
+    in sequence by splitting at the median of the dimension of the points
+    assigned to the branch.
+
+    Parameters
+    ----------
+    points : list[Record]
+        A list of GeoSpatialTools.Record instances. The list is sorted in
+        place while the tree is built.
+    depth : int
+        The current depth of the KDTree, you should set this to 0, it is used
+        internally.
+    max_depth : int
+        The maximum depth of the KDTree. The leaf nodes will have depth no
+        larger than this value. A branch holding fewer than 2 points becomes
+        a leaf and is not split further.
+    """
+
+    def __init__(
+        self, points: list[Record], depth: int = 0, max_depth: int = 20
+    ) -> None:
+        self.depth = depth
+        n_points = len(points)
+
+        # Leaf node: depth limit reached, or too few points to split.
+        if self.depth >= max_depth or n_points < 2:
+            self.points = points
+            self.split = False
+            return None
+
+        # Alternate the split dimension with depth: lon, lat, lon, ...
+        self.axis = depth % 2
+        self.variable = "lon" if self.axis == 0 else "lat"
+
+        # NOTE: this sorts the caller's list in place.
+        points.sort(key=lambda p: getattr(p, self.variable))
+        split_index = n_points // 2
+
+        self.partition_value = getattr(points[split_index], self.variable)
+        self.split = True
+
+        # Pass max_depth down so the limit applies to the whole tree rather
+        # than only to the root node.
+        self.child_left = KDTree(points[:split_index], depth + 1, max_depth)
+        self.child_right = KDTree(points[split_index:], depth + 1, max_depth)
+
+        return None
+
+    def insert(self, point: Record) -> bool:
+        """Insert a Record into the KDTree. May unbalance the KDTree."""
+        if not self.split:
+            self.points.append(point)
+            return True
+
+        if getattr(point, self.variable) < self.partition_value:
+            return self.child_left.insert(point)
+        return self.child_right.insert(point)
+
+    def delete(self, point: Record) -> bool:
+        """
+        Delete a Record from the KDTree. May unbalance the KDTree.
+
+        Returns True if the Record was found and removed, False otherwise.
+        """
+        if not self.split:
+            try:
+                self.points.remove(point)
+            except ValueError:
+                return False
+            return True
+
+        value = getattr(point, self.variable)
+        if value < self.partition_value:
+            return self.child_left.delete(point)
+        if self.child_right.delete(point):
+            return True
+        # Duplicates of the median can sit in the left child after the
+        # initial build, so a Record on the partition may be on either side.
+        if value == self.partition_value:
+            return self.child_left.delete(point)
+        return False
+
+    def query(
+        self,
+        point: Record,
+        current_best: Record | None = None,
+        best_distance: float = inf,
+    ) -> tuple[Record | None, float]:
+        """
+        Find the nearest Record within the KDTree to a query Record.
+
+        Returns the nearest Record and its distance from the query Record.
+        If no Record is closer than best_distance, current_best and
+        best_distance are returned unchanged.
+        """
+        if not self.split:
+            for p in self.points:
+                dist = point.distance(p)
+                if dist < best_distance:
+                    current_best = p
+                    best_distance = dist
+            return current_best, best_distance
+
+        # Search the side of the partition that contains the query first.
+        if getattr(point, self.variable) < self.partition_value:
+            near, far = self.child_left, self.child_right
+        else:
+            near, far = self.child_right, self.child_left
+
+        current_best, best_distance = near.query(
+            point, current_best, best_distance
+        )
+
+        # Only search the other side if the partition is closer than the
+        # best match so far.
+        # NOTE(review): this test ignores longitude wrap-around at +/-180 and
+        # may not be a strict lower bound at high latitudes - confirm.
+        if point.distance(self._get_partition_record(point)) < best_distance:
+            current_best, best_distance = far.query(
+                point, current_best, best_distance
+            )
+
+        return current_best, best_distance
+
+    def _get_partition_record(self, point: Record) -> Record:
+        """Project the query Record onto this node's partition line."""
+        if self.variable == "lon":
+            return Record(self.partition_value, point.lat)
+        return Record(point.lon, self.partition_value)
diff --git a/test/test_kdtree.py b/test/test_kdtree.py
new file mode 100644
--- /dev/null
+++ b/test/test_kdtree.py
@@ -0,0 +1,53 @@
+import unittest
+from numpy import min, argmin
+from GeoSpatialTools import haversine, KDTree, Record
+
+
+class TestKDTree(unittest.TestCase):
+    records = [
+        Record(1, 2, uid="A"),
+        Record(-9, 44, uid="B"),
+        Record(174, -81, uid="C"),
+        Record(-4, 71, uid="D"),
+    ]
+
+    def test_insert(self):
+        kt = KDTree(self.records)
+        test_record = Record(175, 44)
+        assert kt.insert(test_record)
+        assert test_record in kt.child_right.child_right.points
+
+    def test_delete(self):
+        kt = KDTree(self.records)
+        delete_rec = self.records[2]
+        assert delete_rec in kt.child_right.child_right.points
+        assert kt.delete(delete_rec)
+        assert delete_rec not in kt.child_right.child_right.points
+
+    def test_delete_partition_tie(self):
+        # Records sharing the partition value can sit on either side of it
+        tied = [Record(0, lat, uid=str(lat)) for lat in (0, 10, 20, 30)]
+        kt = KDTree(list(tied))
+        for rec in tied:
+            assert kt.delete(rec)
+
+    def test_max_depth(self):
+        kt = KDTree(list(self.records), max_depth=1)
+        assert kt.split
+        assert not kt.child_left.split
+        assert not kt.child_right.split
+
+    def test_query(self):
+        kt = KDTree(self.records)
+        test_record = Record(-6, 35)
+        best_record, best_dist = kt.query(test_record)
+        true_dist = min([test_record.distance(r) for r in self.records])
+        true_ind = argmin([test_record.distance(r) for r in self.records])
+        true_record = self.records[true_ind]
+
+        self.assertAlmostEqual(true_dist, best_dist)
+        assert best_record == true_record
+
+
+if __name__ == "__main__":
+    unittest.main()