diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py index dba22f1d5980190ad4ac468fe0562b23dcd4fc56..467131fda8fdcd5d371219019dc1b78ace083330 100644 --- a/GeoSpatialTools/kdtree.py +++ b/GeoSpatialTools/kdtree.py @@ -96,7 +96,7 @@ class KDTree: else: return self.child_right.delete(point) - def query(self, point) -> tuple[Record | None, float]: + def query(self, point) -> tuple[list[Record], float]: """Find the nearest Record within the KDTree to a query Record""" if point.lon < 0: point2 = Record(point.lon + 360, point.lat) @@ -106,32 +106,35 @@ class KDTree: r1, d1 = self._query(point) r2, d2 = self._query(point2) if d1 <= d2: - r = r1 + return r1, d1 else: - r = r2 - return r, point.distance(r) + return r2, d2 def _query( self, point: Record, - current_best: Record | None = None, + current_best: list[Record] | None = None, best_distance: float = inf, - ) -> tuple[Record | None, float]: + ) -> tuple[list[Record], float]: + if current_best is None: + current_best = list() if not self.split: for p in self.points: dist = point.distance(p) if dist < best_distance: - current_best = p + current_best = [p] best_distance = dist + elif dist == best_distance: + current_best.append(p) return current_best, best_distance - if getattr(point, self.variable) < self.partition_value: + if getattr(point, self.variable) <= self.partition_value: current_best, best_distance = self.child_left._query( point, current_best, best_distance ) if ( point.distance(self._get_partition_record(point)) - < best_distance + <= best_distance ): current_best, best_distance = self.child_right._query( point, current_best, best_distance @@ -142,7 +145,7 @@ class KDTree: ) if ( point.distance(self._get_partition_record(point)) - < best_distance + <= best_distance ): current_best, best_distance = self.child_left._query( point, current_best, best_distance