From 55a83eddee165f58d01b3bd47a60e8c2c1701caf Mon Sep 17 00:00:00 2001
From: josidd <joseph.siddons@noc.ac.uk>
Date: Thu, 3 Oct 2024 15:48:52 +0100
Subject: [PATCH] fix(kdtree): increment split index if next index is above median

---
 GeoSpatialTools/kdtree.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py
index f99ea90..dba22f1 100644
--- a/GeoSpatialTools/kdtree.py
+++ b/GeoSpatialTools/kdtree.py
@@ -48,18 +48,32 @@ class KDTree:
 
         points.sort(key=lambda p: getattr(p, self.variable))
         split_index = n_points // 2
+        self.partition_value = getattr(points[split_index - 1], self.variable)
+        while (
+            split_index < n_points
+            and getattr(points[split_index], self.variable)
+            == self.partition_value
+        ):
+            split_index += 1
 
-        self.partition_value = getattr(points[split_index], self.variable)
         self.split = True
 
+        # Left is <= median
         self.child_left = KDTree(points[:split_index], depth + 1)
+        # Right is > median
         self.child_right = KDTree(points[split_index:], depth + 1)
 
         return None
 
     def insert(self, point: Record) -> bool:
-        """Insert a Record into the KDTree. May unbalance the KDTree"""
+        """
+        Insert a Record into the KDTree. May unbalance the KDTree.
+
+        The point will not be inserted if it is already in the KDTree.
+        """
         if not self.split:
+            if point in self.points:
+                return False
             self.points.append(point)
             return True
 
@@ -83,7 +97,7 @@ class KDTree:
         return self.child_right.delete(point)
 
     def query(self, point) -> tuple[Record | None, float]:
-        """Find the nearest Record within the KDTree to a _query Record"""
+        """Find the nearest Record within the KDTree to a query Record"""
         if point.lon < 0:
             point2 = Record(point.lon + 360, point.lat)
         else:
-- 
GitLab