From 55a83eddee165f58d01b3bd47a60e8c2c1701caf Mon Sep 17 00:00:00 2001
From: josidd <joseph.siddons@noc.ac.uk>
Date: Thu, 3 Oct 2024 15:48:52 +0100
Subject: [PATCH] fix(kdtree): increment split index if next index is above
 median

---
 GeoSpatialTools/kdtree.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py
index f99ea90..dba22f1 100644
--- a/GeoSpatialTools/kdtree.py
+++ b/GeoSpatialTools/kdtree.py
@@ -48,18 +48,32 @@ class KDTree:
 
         points.sort(key=lambda p: getattr(p, self.variable))
         split_index = n_points // 2
+        self.partition_value = getattr(points[split_index - 1], self.variable)
+        while (
+            split_index < n_points
+            and getattr(points[split_index], self.variable)
+            == self.partition_value
+        ):
+            split_index += 1
 
-        self.partition_value = getattr(points[split_index], self.variable)
         self.split = True
 
+        # Left holds points with value <= partition_value
         self.child_left = KDTree(points[:split_index], depth + 1)
+        # Right holds points with value > partition_value
         self.child_right = KDTree(points[split_index:], depth + 1)
 
         return None
 
     def insert(self, point: Record) -> bool:
-        """Insert a Record into the KDTree. May unbalance the KDTree"""
+        """
+        Insert a Record into the KDTree. May unbalance the KDTree.
+
+        The point will not be inserted if it is already in the KDTree.
+        """
         if not self.split:
+            if point in self.points:
+                return False
             self.points.append(point)
             return True
 
@@ -83,7 +97,7 @@ class KDTree:
             return self.child_right.delete(point)
 
     def query(self, point) -> tuple[Record | None, float]:
-        """Find the nearest Record within the KDTree to a _query Record"""
+        """Find the nearest Record within the KDTree to a query Record"""
         if point.lon < 0:
             point2 = Record(point.lon + 360, point.lat)
         else:
-- 
GitLab