From 01f7c77ea2a9c05810096fc1ae1dd444da463cce Mon Sep 17 00:00:00 2001
From: josidd <joseph.siddons@noc.ac.uk>
Date: Fri, 4 Oct 2024 08:32:14 +0100
Subject: [PATCH] feat(kdtree_query): Querying KDTree now returns list of
 neighbours

---
 GeoSpatialTools/kdtree.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py
index dba22f1..467131f 100644
--- a/GeoSpatialTools/kdtree.py
+++ b/GeoSpatialTools/kdtree.py
@@ -96,7 +96,7 @@ class KDTree:
         else:
             return self.child_right.delete(point)
 
-    def query(self, point) -> tuple[Record | None, float]:
+    def query(self, point) -> tuple[list[Record], float]:
         """Find the nearest Record within the KDTree to a query Record"""
         if point.lon < 0:
             point2 = Record(point.lon + 360, point.lat)
@@ -106,32 +106,35 @@ class KDTree:
         r1, d1 = self._query(point)
         r2, d2 = self._query(point2)
         if d1 <= d2:
-            r = r1
+            return r1, d1
         else:
-            r = r2
-        return r, point.distance(r)
+            return r2, d2
 
     def _query(
         self,
         point: Record,
-        current_best: Record | None = None,
+        current_best: list[Record] | None = None,
         best_distance: float = inf,
-    ) -> tuple[Record | None, float]:
+    ) -> tuple[list[Record], float]:
+        if current_best is None:
+            current_best = list()
         if not self.split:
             for p in self.points:
                 dist = point.distance(p)
                 if dist < best_distance:
-                    current_best = p
+                    current_best = [p]
                     best_distance = dist
+                elif dist == best_distance:
+                    current_best.append(p)
             return current_best, best_distance
 
-        if getattr(point, self.variable) < self.partition_value:
+        if getattr(point, self.variable) <= self.partition_value:
             current_best, best_distance = self.child_left._query(
                 point, current_best, best_distance
             )
             if (
                 point.distance(self._get_partition_record(point))
-                < best_distance
+                <= best_distance
             ):
                 current_best, best_distance = self.child_right._query(
                     point, current_best, best_distance
@@ -142,7 +145,7 @@ class KDTree:
             )
             if (
                 point.distance(self._get_partition_record(point))
-                < best_distance
+                <= best_distance
             ):
                 current_best, best_distance = self.child_left._query(
                     point, current_best, best_distance
-- 
GitLab