From d58f82b9fc1433efb614f9d100b63dbc0eb13f7e Mon Sep 17 00:00:00 2001 From: josidd <joseph.siddons@noc.ac.uk> Date: Fri, 4 Oct 2024 08:34:49 +0100 Subject: [PATCH] fix: revert median partition change, split on index. Account for this in insert and delete methods. --- GeoSpatialTools/kdtree.py | 46 ++++++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py index 467131f..c77524a 100644 --- a/GeoSpatialTools/kdtree.py +++ b/GeoSpatialTools/kdtree.py @@ -49,18 +49,18 @@ class KDTree: points.sort(key=lambda p: getattr(p, self.variable)) split_index = n_points // 2 self.partition_value = getattr(points[split_index - 1], self.variable) - while ( - split_index < n_points - and getattr(points[split_index], self.variable) - == self.partition_value - ): - split_index += 1 + # while ( + # split_index < n_points + # and getattr(points[split_index], self.variable) + # == self.partition_value + # ): + # split_index += 1 self.split = True - # Left is <= median + # Left is points left of midpoint self.child_left = KDTree(points[:split_index], depth + 1) - # Right is > median + # Right is points right of midpoint self.child_right = KDTree(points[split_index:], depth + 1) return None @@ -79,8 +79,25 @@ class KDTree: if getattr(point, self.variable) < self.partition_value: return self.child_left.insert(point) - else: + elif getattr(point, self.variable) > self.partition_value: return self.child_right.insert(point) + else: + r, _ = self.query(point) + if point in r: + return False + self.child_left._insert(point) + return True + + def _insert(self, point: Record) -> None: + """Insert a point even if it already exists in the KDTree""" + if not self.split: + self.points.append(point) + return + if getattr(point, self.variable) <= self.partition_value: + self.child_left._insert(point) + else: + self.child_right._insert(point) + return def delete(self, point: Record) -> bool: """Delete a Record from the KDTree. May unbalance the KDTree""" @@ -91,10 +108,13 @@ class KDTree: except ValueError: return False - if getattr(point, self.variable) < self.partition_value: - return self.child_left.delete(point) - else: - return self.child_right.delete(point) + if getattr(point, self.variable) <= self.partition_value: + if self.child_left.delete(point): + return True + if getattr(point, self.variable) >= self.partition_value: + if self.child_right.delete(point): + return True + return False def query(self, point) -> tuple[list[Record], float]: """Find the nearest Record within the KDTree to a query Record""" -- GitLab