diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py index f99ea9069c312e061fb61e89d2a5dbd3005a0169..dba22f1d5980190ad4ac468fe0562b23dcd4fc56 100644 --- a/GeoSpatialTools/kdtree.py +++ b/GeoSpatialTools/kdtree.py @@ -48,18 +48,32 @@ class KDTree: points.sort(key=lambda p: getattr(p, self.variable)) split_index = n_points // 2 + self.partition_value = getattr(points[split_index - 1], self.variable) + while ( + split_index < n_points + and getattr(points[split_index], self.variable) + == self.partition_value + ): + split_index += 1 - self.partition_value = getattr(points[split_index], self.variable) self.split = True + # Left is <= median self.child_left = KDTree(points[:split_index], depth + 1) + # Right is > median self.child_right = KDTree(points[split_index:], depth + 1) return None def insert(self, point: Record) -> bool: - """Insert a Record into the KDTree. May unbalance the KDTree""" + """ + Insert a Record into the KDTree. May unbalance the KDTree. + + The point will not be inserted if it is already in the KDTree. + """ if not self.split: + if point in self.points: + return False self.points.append(point) return True @@ -83,7 +97,7 @@ class KDTree: return self.child_right.delete(point) def query(self, point) -> tuple[Record | None, float]: - """Find the nearest Record within the KDTree to a _query Record""" + """Find the nearest Record within the KDTree to a query Record""" if point.lon < 0: point2 = Record(point.lon + 360, point.lat) else: