From 01f7c77ea2a9c05810096fc1ae1dd444da463cce Mon Sep 17 00:00:00 2001
From: josidd <joseph.siddons@noc.ac.uk>
Date: Fri, 4 Oct 2024 08:32:14 +0100
Subject: [PATCH 1/4] feat(kdtree_query): Querying KDTree now returns list of
 neighbours

---
 GeoSpatialTools/kdtree.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py
index dba22f1..467131f 100644
--- a/GeoSpatialTools/kdtree.py
+++ b/GeoSpatialTools/kdtree.py
@@ -96,7 +96,7 @@ class KDTree:
         else:
             return self.child_right.delete(point)
 
-    def query(self, point) -> tuple[Record | None, float]:
+    def query(self, point) -> tuple[list[Record], float]:
         """Find the nearest Record within the KDTree to a query Record"""
         if point.lon < 0:
             point2 = Record(point.lon + 360, point.lat)
@@ -106,32 +106,35 @@ class KDTree:
         r1, d1 = self._query(point)
         r2, d2 = self._query(point2)
         if d1 <= d2:
-            r = r1
+            return r1, d1
         else:
-            r = r2
-        return r, point.distance(r)
+            return r2, d2
 
     def _query(
         self,
         point: Record,
-        current_best: Record | None = None,
+        current_best: list[Record] | None = None,
         best_distance: float = inf,
-    ) -> tuple[Record | None, float]:
+    ) -> tuple[list[Record], float]:
+        if current_best is None:
+            current_best = list()
         if not self.split:
             for p in self.points:
                 dist = point.distance(p)
                 if dist < best_distance:
-                    current_best = p
+                    current_best = [p]
                     best_distance = dist
+                elif dist == best_distance:
+                    current_best.append(p)
             return current_best, best_distance
 
-        if getattr(point, self.variable) < self.partition_value:
+        if getattr(point, self.variable) <= self.partition_value:
             current_best, best_distance = self.child_left._query(
                 point, current_best, best_distance
             )
             if (
                 point.distance(self._get_partition_record(point))
-                < best_distance
+                <= best_distance
             ):
                 current_best, best_distance = self.child_right._query(
                     point, current_best, best_distance
@@ -142,7 +145,7 @@ class KDTree:
             )
             if (
                 point.distance(self._get_partition_record(point))
-                < best_distance
+                <= best_distance
             ):
                 current_best, best_distance = self.child_left._query(
                     point, current_best, best_distance
-- 
GitLab


From 1836b2c48d06d8deedf79babafe56db2247719d6 Mon Sep 17 00:00:00 2001
From: josidd <joseph.siddons@noc.ac.uk>
Date: Fri, 4 Oct 2024 08:33:30 +0100
Subject: [PATCH 2/4] test(kdtree): Additional tests for KDTree edge cases

+ Test that query handles longitude wrap
+ Tests for duplicate positions
+ Test insert / delete duplicates
---
 test/test_kdtree.py | 64 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 63 insertions(+), 1 deletion(-)

diff --git a/test/test_kdtree.py b/test/test_kdtree.py
index 63a4ceb..85631bf 100644
--- a/test/test_kdtree.py
+++ b/test/test_kdtree.py
@@ -24,6 +24,18 @@ class TestKDTree(unittest.TestCase):
         assert kt.delete(delete_rec)
         assert delete_rec not in kt.child_right.child_right.points
 
+    def test_delete_dup(self):
+        test_records = [
+            Record(45, -23, uid="1"),
+            Record(45, -23, uid="2"),
+            Record(45, -23, uid="3"),
+            Record(45, -23, uid="4"),
+        ]
+        kt = KDTree(test_records, max_depth=3)
+        assert kt.delete(test_records[1])
+        # TEST: Cannot delete same record twice!
+        assert not kt.delete(test_records[1])
+
     def test_query(self):
         kt = KDTree(self.records)
         test_record = Record(-6, 35)
@@ -33,7 +45,57 @@ class TestKDTree(unittest.TestCase):
         true_record = self.records[true_ind]
 
         self.assertAlmostEqual(true_dist, best_dist)
-        assert best_record == true_record
+        assert len(best_record) == 1
+        assert best_record[0] == true_record
+
+    def test_duplicated_pos(self):
+        # TEST: Records at the same position are split evenly across leaves
+        test_records = [
+            Record(45, -23, uid="1"),
+            Record(45, -23, uid="2"),
+            Record(45, -23, uid="3"),
+            Record(45, -23, uid="4"),
+        ]
+        kt = KDTree(test_records, max_depth=3)
+        assert len(kt.child_left.child_left.points) == 1
+        assert len(kt.child_left.child_right.points) == 1
+        assert len(kt.child_right.child_left.points) == 1
+        assert len(kt.child_right.child_right.points) == 1
+
+    def test_insert_dup(self):
+        test_records = [
+            Record(45, -23, uid="1"),
+            Record(45, -23, uid="2"),
+            Record(45, -23, uid="3"),
+            Record(45, -23, uid="4"),
+        ]
+        kt = KDTree(test_records, max_depth=3)
+        assert not kt.insert(test_records[0])
+        assert not kt.insert(test_records[1])
+        assert not kt.insert(test_records[2])
+        assert not kt.insert(test_records[3])
+        assert kt.insert(Record(45, -23, uid="5"))
+        assert not kt.insert(Record(45, -23, uid="5"))
+        # TEST: Can insert after deleting
+        assert kt.delete(Record(45, -23, uid="5"))
+        assert kt.insert(Record(45, -23, uid="5"))
+
+    def test_get_multiple_neighbours(self):
+        kt = KDTree(self.records)
+        kt.insert(Record(45, -21, uid="1"))
+        kt.insert(Record(45, -21, uid="2"))
+
+        r, d = kt.query(Record(44, -21, uid="3"))
+        assert len(r) == 2
+
+    def test_wrap(self):
+        # TEST: Accounts for wrap at -180, 180
+        kt = KDTree(self.records)
+        kt.insert(Record(-160, -64, uid="G"))
+        query_rec = Record(-178, -79, uid="E")
+        r, _ = kt.query(query_rec)
+        assert len(r) == 1
+        assert r[0].uid == "C"
 
 
 if __name__ == "__main__":
-- 
GitLab


From d58f82b9fc1433efb614f9d100b63dbc0eb13f7e Mon Sep 17 00:00:00 2001
From: josidd <joseph.siddons@noc.ac.uk>
Date: Fri, 4 Oct 2024 08:34:49 +0100
Subject: [PATCH 3/4] fix: revert median partition change, split on index.
 Account for this in insert and delete methods.

---
 GeoSpatialTools/kdtree.py | 46 ++++++++++++++++++++++++++++-----------
 1 file changed, 33 insertions(+), 13 deletions(-)

diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py
index 467131f..c77524a 100644
--- a/GeoSpatialTools/kdtree.py
+++ b/GeoSpatialTools/kdtree.py
@@ -49,18 +49,18 @@ class KDTree:
         points.sort(key=lambda p: getattr(p, self.variable))
         split_index = n_points // 2
         self.partition_value = getattr(points[split_index - 1], self.variable)
-        while (
-            split_index < n_points
-            and getattr(points[split_index], self.variable)
-            == self.partition_value
-        ):
-            split_index += 1
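+        # NOTE: split_index is deliberately not advanced past points equal
+        # to partition_value. Splitting purely on index keeps the two
+        # children balanced even when many points share a position, so
+        # points equal to partition_value may sit in either child.
+        # insert, delete and _query allow for this by checking both
+        # children when a value ties with partition_value.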
 
         self.split = True
 
-        # Left is <= median
+        # Left: sorted points before split_index (values <= partition_value)
         self.child_left = KDTree(points[:split_index], depth + 1)
-        # Right is > median
+        # Right: sorted points from split_index on (values >= partition_value)
         self.child_right = KDTree(points[split_index:], depth + 1)
 
         return None
@@ -79,8 +79,25 @@ class KDTree:
 
         if getattr(point, self.variable) < self.partition_value:
             return self.child_left.insert(point)
-        else:
+        elif getattr(point, self.variable) > self.partition_value:
             return self.child_right.insert(point)
+        else:
+            r, _ = self.query(point)
+            if point in r:
+                return False
+            self.child_left._insert(point)
+            return True
+
+    def _insert(self, point: Record) -> None:
+        """Insert a point even if it already exists in the KDTree"""
+        if not self.split:
+            self.points.append(point)
+            return
+        if getattr(point, self.variable) <= self.partition_value:
+            self.child_left._insert(point)
+        else:
+            self.child_right._insert(point)
+        return
 
     def delete(self, point: Record) -> bool:
         """Delete a Record from the KDTree. May unbalance the KDTree"""
@@ -91,10 +108,13 @@ class KDTree:
             except ValueError:
                 return False
 
-        if getattr(point, self.variable) < self.partition_value:
-            return self.child_left.delete(point)
-        else:
-            return self.child_right.delete(point)
+        if getattr(point, self.variable) <= self.partition_value:
+            if self.child_left.delete(point):
+                return True
+        if getattr(point, self.variable) >= self.partition_value:
+            if self.child_right.delete(point):
+                return True
+        return False
 
     def query(self, point) -> tuple[list[Record], float]:
         """Find the nearest Record within the KDTree to a query Record"""
-- 
GitLab


From 98f65e454ec0cc17f5d4ed96cc76514596109689 Mon Sep 17 00:00:00 2001
From: josidd <joseph.siddons@noc.ac.uk>
Date: Fri, 4 Oct 2024 08:35:34 +0100
Subject: [PATCH 4/4] chore: update and re-run KDTree notebook.

---
 notebooks/kdtree.ipynb | 52 ++++++++++++++++++++++--------------------
 1 file changed, 27 insertions(+), 25 deletions(-)

diff --git a/notebooks/kdtree.ipynb b/notebooks/kdtree.ipynb
index 9a69a4e..3ef3161 100644
--- a/notebooks/kdtree.ipynb
+++ b/notebooks/kdtree.ipynb
@@ -48,7 +48,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 11,
    "id": "c60b30de-f864-477a-a09a-5f1caa4d9b9a",
    "metadata": {},
    "outputs": [
@@ -56,18 +56,18 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "(14222, 2)\n",
+      "(16000, 2)\n",
       "shape: (5, 2)\n",
       "┌──────┬─────┐\n",
       "│ lon  ┆ lat │\n",
       "│ ---  ┆ --- │\n",
       "│ i64  ┆ i64 │\n",
       "╞══════╪═════╡\n",
-      "│ -30  ┆ -41 │\n",
-      "│ -149 ┆ 56  │\n",
-      "│ 7    ┆ -68 │\n",
-      "│ -48  ┆ 83  │\n",
-      "│ -126 ┆ -35 │\n",
+      "│ 16   ┆ -75 │\n",
+      "│ 144  ┆ -77 │\n",
+      "│ -173 ┆ -83 │\n",
+      "│ 142  ┆ -81 │\n",
+      "│ -50  ┆ -38 │\n",
       "└──────┴─────┘\n"
      ]
     }
@@ -83,14 +83,14 @@
     "# dates_use = dates.sample(N, with_replacement=True).alias(\"datetime\")\n",
     "# uids = pl.Series(\"uid\", [generate_uid(8) for _ in range(N)])\n",
     "\n",
-    "df = pl.DataFrame([lons_use, lats_use]).unique()\n",
+    "df = pl.DataFrame([lons_use, lats_use])\n",
     "print(df.shape)\n",
     "print(df.head())"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 12,
    "id": "875f2a67-49fe-476f-add1-b1d76c6cd8f9",
    "metadata": {},
    "outputs": [],
@@ -100,7 +100,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 13,
    "id": "1e883e5a-5086-4c29-aff2-d308874eae16",
    "metadata": {},
    "outputs": [
@@ -108,8 +108,8 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "CPU times: user 43.5 ms, sys: 3.43 ms, total: 46.9 ms\n",
-      "Wall time: 46.8 ms\n"
+      "CPU times: user 82 ms, sys: 4.14 ms, total: 86.1 ms\n",
+      "Wall time: 84.3 ms\n"
      ]
     }
    ],
@@ -120,7 +120,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 14,
    "id": "69022ad1-5ec8-4a09-836c-273ef452451f",
    "metadata": {},
    "outputs": [
@@ -128,7 +128,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "173 μs ± 1.36 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
+      "188 μs ± 3.45 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
      ]
     }
    ],
@@ -140,7 +140,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 15,
    "id": "28031966-c7d0-4201-a467-37590118e851",
    "metadata": {},
    "outputs": [
@@ -148,7 +148,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "7.71 ms ± 38.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+      "8.72 ms ± 74.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
     }
    ],
@@ -160,7 +160,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 16,
    "id": "0d10b2ba-57b2-475c-9d01-135363423990",
    "metadata": {},
    "outputs": [
@@ -168,26 +168,27 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "CPU times: user 15.4 s, sys: 37.8 ms, total: 15.5 s\n",
-      "Wall time: 15.5 s\n"
+      "CPU times: user 17.3 s, sys: 31.6 ms, total: 17.3 s\n",
+      "Wall time: 17.3 s\n"
      ]
     }
    ],
    "source": [
     "%%time\n",
     "n_samples = 1000\n",
+    "tol = 1e-8\n",
     "test_records = [Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum()) for _ in range(n_samples)]\n",
     "kd_res = [kt.query(r) for r in test_records]\n",
-    "kd_recs = [_[0] for _ in kd_res]\n",
+    "kd_recs = [_[0][0] for _ in kd_res]\n",
     "kd_dists = [_[1] for _ in kd_res]\n",
     "tr_recs = [records[np.argmin([r.distance(p) for p in records])] for r in test_records]\n",
     "tr_dists = [min([r.distance(p) for p in records]) for r in test_records]\n",
-    "assert kd_dists == tr_dists, \"NOT MATCHING?\""
+    "assert all([abs(k - t) < tol for k, t in zip(kd_dists, tr_dists)]), \"NOT MATCHING?\""
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 17,
    "id": "a6aa6926-7fd5-4fff-bd20-7bc0305b948d",
    "metadata": {},
    "outputs": [
@@ -213,7 +214,7 @@
        "└──────────┴──────────┴─────────┴────────┴────────┴─────────┴────────┴────────┘"
       ]
      },
-     "execution_count": 10,
+     "execution_count": 17,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -228,7 +229,7 @@
     "tr_lons = [r.lon for r in tr_recs]\n",
     "tr_lats = [r.lat for r in tr_recs]\n",
     "\n",
-    "pl.DataFrame({\n",
+    "df = pl.DataFrame({\n",
     "    \"test_lon\": test_lons, \n",
     "    \"test_lat\": test_lats,\n",
     "    \"kd_dist\": kd_dists,\n",
@@ -237,7 +238,8 @@
     "    \"tr_dist\": tr_dists,\n",
     "    \"tr_lon\": tr_lons,\n",
     "    \"tr_lat\": tr_lats,   \n",
-    "}).filter(pl.col(\"kd_dist\").ne(pl.col(\"tr_dist\")))"
+    "}).filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n",
+    "df"
    ]
   }
  ],
-- 
GitLab