Commit 01f7c77e authored by Joseph Siddons's avatar Joseph Siddons
Browse files

feat(kdtree_query): Querying KDTree now returns list of neighbours

parent f9bc5809
......@@ -96,7 +96,7 @@ class KDTree:
else:
return self.child_right.delete(point)
def query(self, point) -> tuple[Record | None, float]:
def query(self, point) -> tuple[list[Record], float]:
"""Find the nearest Record within the KDTree to a query Record"""
if point.lon < 0:
point2 = Record(point.lon + 360, point.lat)
......@@ -106,32 +106,35 @@ class KDTree:
r1, d1 = self._query(point)
r2, d2 = self._query(point2)
if d1 <= d2:
r = r1
return r1, d1
else:
r = r2
return r, point.distance(r)
return r2, d2
def _query(
self,
point: Record,
current_best: Record | None = None,
current_best: list[Record] | None = None,
best_distance: float = inf,
) -> tuple[Record | None, float]:
) -> tuple[list[Record], float]:
if current_best is None:
current_best = list()
if not self.split:
for p in self.points:
dist = point.distance(p)
if dist < best_distance:
current_best = p
current_best = [p]
best_distance = dist
elif dist == best_distance:
current_best.append(p)
return current_best, best_distance
if getattr(point, self.variable) < self.partition_value:
if getattr(point, self.variable) <= self.partition_value:
current_best, best_distance = self.child_left._query(
point, current_best, best_distance
)
if (
point.distance(self._get_partition_record(point))
< best_distance
<= best_distance
):
current_best, best_distance = self.child_right._query(
point, current_best, best_distance
......@@ -142,7 +145,7 @@ class KDTree:
)
if (
point.distance(self._get_partition_record(point))
< best_distance
<= best_distance
):
current_best, best_distance = self.child_left._query(
point, current_best, best_distance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment