From d58f82b9fc1433efb614f9d100b63dbc0eb13f7e Mon Sep 17 00:00:00 2001
From: josidd <joseph.siddons@noc.ac.uk>
Date: Fri, 4 Oct 2024 08:34:49 +0100
Subject: [PATCH] fix: revert median partition change, split on index. Account
 for this in insert and delete methods.

---
 GeoSpatialTools/kdtree.py | 46 ++++++++++++++++++++++++++++-----------
 1 file changed, 33 insertions(+), 13 deletions(-)

diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py
index 467131f..c77524a 100644
--- a/GeoSpatialTools/kdtree.py
+++ b/GeoSpatialTools/kdtree.py
@@ -49,18 +49,18 @@ class KDTree:
         points.sort(key=lambda p: getattr(p, self.variable))
         split_index = n_points // 2
         self.partition_value = getattr(points[split_index - 1], self.variable)
-        while (
-            split_index < n_points
-            and getattr(points[split_index], self.variable)
-            == self.partition_value
-        ):
-            split_index += 1
+        # while (
+        #     split_index < n_points
+        #     and getattr(points[split_index], self.variable)
+        #     == self.partition_value
+        # ):
+        #     split_index += 1
 
         self.split = True
 
-        # Left is <= median
+        # Left is points left of midpoint
         self.child_left = KDTree(points[:split_index], depth + 1)
-        # Right is > median
+        # Right is points right of midpoint
         self.child_right = KDTree(points[split_index:], depth + 1)
 
         return None
@@ -79,8 +79,25 @@ class KDTree:
 
         if getattr(point, self.variable) < self.partition_value:
             return self.child_left.insert(point)
-        else:
+        elif getattr(point, self.variable) > self.partition_value:
             return self.child_right.insert(point)
+        else:
+            r, _ = self.query(point)
+            if point in r:
+                return False
+            self.child_left._insert(point)
+            return True
+
+    def _insert(self, point: Record) -> None:
+        """Insert without a duplicate check; partition ties go left."""
+        if not self.split:
+            self.points.append(point)
+            return
+        if getattr(point, self.variable) <= self.partition_value:
+            self.child_left._insert(point)
+        else:
+            self.child_right._insert(point)
+        return
 
     def delete(self, point: Record) -> bool:
         """Delete a Record from the KDTree. May unbalance the KDTree"""
@@ -91,10 +108,13 @@ class KDTree:
             except ValueError:
                 return False
 
-        if getattr(point, self.variable) < self.partition_value:
-            return self.child_left.delete(point)
-        else:
-            return self.child_right.delete(point)
+        if getattr(point, self.variable) <= self.partition_value:
+            if self.child_left.delete(point):
+                return True
+        if getattr(point, self.variable) >= self.partition_value:
+            if self.child_right.delete(point):
+                return True
+        return False
 
     def query(self, point) -> tuple[list[Record], float]:
         """Find the nearest Record within the KDTree to a query Record"""
-- 
GitLab