Commit 01f7c77e authored by Joseph Siddons's avatar Joseph Siddons
Browse files

feat(kdtree_query): Querying KDTree now returns list of neighbours

parent f9bc5809
...@@ -96,7 +96,7 @@ class KDTree: ...@@ -96,7 +96,7 @@ class KDTree:
else: else:
return self.child_right.delete(point) return self.child_right.delete(point)
def query(self, point) -> tuple[Record | None, float]: def query(self, point) -> tuple[list[Record], float]:
"""Find the nearest Record within the KDTree to a query Record""" """Find the nearest Record within the KDTree to a query Record"""
if point.lon < 0: if point.lon < 0:
point2 = Record(point.lon + 360, point.lat) point2 = Record(point.lon + 360, point.lat)
...@@ -106,32 +106,35 @@ class KDTree: ...@@ -106,32 +106,35 @@ class KDTree:
r1, d1 = self._query(point) r1, d1 = self._query(point)
r2, d2 = self._query(point2) r2, d2 = self._query(point2)
if d1 <= d2: if d1 <= d2:
r = r1 return r1, d1
else: else:
r = r2 return r2, d2
return r, point.distance(r)
def _query( def _query(
self, self,
point: Record, point: Record,
current_best: Record | None = None, current_best: list[Record] | None = None,
best_distance: float = inf, best_distance: float = inf,
) -> tuple[Record | None, float]: ) -> tuple[list[Record], float]:
if current_best is None:
current_best = list()
if not self.split: if not self.split:
for p in self.points: for p in self.points:
dist = point.distance(p) dist = point.distance(p)
if dist < best_distance: if dist < best_distance:
current_best = p current_best = [p]
best_distance = dist best_distance = dist
elif dist == best_distance:
current_best.append(p)
return current_best, best_distance return current_best, best_distance
if getattr(point, self.variable) < self.partition_value: if getattr(point, self.variable) <= self.partition_value:
current_best, best_distance = self.child_left._query( current_best, best_distance = self.child_left._query(
point, current_best, best_distance point, current_best, best_distance
) )
if ( if (
point.distance(self._get_partition_record(point)) point.distance(self._get_partition_record(point))
< best_distance <= best_distance
): ):
current_best, best_distance = self.child_right._query( current_best, best_distance = self.child_right._query(
point, current_best, best_distance point, current_best, best_distance
...@@ -142,7 +145,7 @@ class KDTree: ...@@ -142,7 +145,7 @@ class KDTree:
) )
if ( if (
point.distance(self._get_partition_record(point)) point.distance(self._get_partition_record(point))
< best_distance <= best_distance
): ):
current_best, best_distance = self.child_left._query( current_best, best_distance = self.child_left._query(
point, current_best, best_distance point, current_best, best_distance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment