From 9496c161b42e133c592bf7986287f8c5b24e1b9c Mon Sep 17 00:00:00 2001
From: josidd <joseph.siddons@noc.ac.uk>
Date: Thu, 3 Oct 2024 10:37:28 +0100
Subject: [PATCH] fix(kdtree): account for lon wrapping

---
 GeoSpatialTools/kdtree.py | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py
index cd54c1c..f99ea90 100644
--- a/GeoSpatialTools/kdtree.py
+++ b/GeoSpatialTools/kdtree.py
@@ -82,13 +82,27 @@ class KDTree:
         else:
             return self.child_right.delete(point)
 
-    def query(
+    def query(self, point) -> tuple[Record | None, float]:
+        """Find the nearest Record within the KDTree to a _query Record"""
+        if point.lon < 0:
+            point2 = Record(point.lon + 360, point.lat)
+        else:
+            point2 = Record(point.lon - 360, point.lat)
+
+        r1, d1 = self._query(point)
+        r2, d2 = self._query(point2)
+        if d1 <= d2:
+            r = r1
+        else:
+            r = r2
+        return r, point.distance(r)
+
+    def _query(
         self,
         point: Record,
         current_best: Record | None = None,
         best_distance: float = inf,
     ) -> tuple[Record | None, float]:
-        """Find the nearest Record within the KDTree to a query Record"""
         if not self.split:
             for p in self.points:
                 dist = point.distance(p)
@@ -98,25 +112,25 @@ class KDTree:
             return current_best, best_distance
 
         if getattr(point, self.variable) < self.partition_value:
-            current_best, best_distance = self.child_left.query(
+            current_best, best_distance = self.child_left._query(
                 point, current_best, best_distance
             )
             if (
                 point.distance(self._get_partition_record(point))
                 < best_distance
             ):
-                current_best, best_distance = self.child_right.query(
+                current_best, best_distance = self.child_right._query(
                     point, current_best, best_distance
                 )
         else:
-            current_best, best_distance = self.child_right.query(
+            current_best, best_distance = self.child_right._query(
                 point, current_best, best_distance
             )
             if (
                 point.distance(self._get_partition_record(point))
                 < best_distance
             ):
-                current_best, best_distance = self.child_left.query(
+                current_best, best_distance = self.child_left._query(
                     point, current_best, best_distance
                 )
 
-- 
GitLab