diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py index 467131fda8fdcd5d371219019dc1b78ace083330..c77524ad3a9d27f2d756cd17c292bfd330740793 100644 --- a/GeoSpatialTools/kdtree.py +++ b/GeoSpatialTools/kdtree.py @@ -49,18 +49,18 @@ class KDTree: points.sort(key=lambda p: getattr(p, self.variable)) split_index = n_points // 2 self.partition_value = getattr(points[split_index - 1], self.variable) - while ( - split_index < n_points - and getattr(points[split_index], self.variable) - == self.partition_value - ): - split_index += 1 + # while ( + # split_index < n_points + # and getattr(points[split_index], self.variable) + # == self.partition_value + # ): + # split_index += 1 self.split = True - # Left is <= median + # Left is points left of midpoint self.child_left = KDTree(points[:split_index], depth + 1) - # Right is > median + # Right is points right of midpoint self.child_right = KDTree(points[split_index:], depth + 1) return None @@ -79,8 +79,25 @@ class KDTree: if getattr(point, self.variable) < self.partition_value: return self.child_left.insert(point) - else: + elif getattr(point, self.variable) > self.partition_value: return self.child_right.insert(point) + else: + r, _ = self.query(point) + if point in r: + return False + self.child_left._insert(point) + return True + + def _insert(self, point: Record) -> None: + """Insert a point even if it already exists in the KDTree""" + if not self.split: + self.points.append(point) + return + if getattr(point, self.variable) <= self.partition_value: + self.child_left._insert(point) + else: + self.child_right._insert(point) + return def delete(self, point: Record) -> bool: """Delete a Record from the KDTree. May unbalance the KDTree""" @@ -91,10 +108,13 @@ class KDTree: except ValueError: return False - if getattr(point, self.variable) < self.partition_value: - return self.child_left.delete(point) - else: - return self.child_right.delete(point) + if getattr(point, self.variable) <= self.partition_value: + if self.child_left.delete(point): + return True + if getattr(point, self.variable) >= self.partition_value: + if self.child_right.delete(point): + return True + return False def query(self, point) -> tuple[list[Record], float]: """Find the nearest Record within the KDTree to a query Record"""