From 1836b2c48d06d8deedf79babafe56db2247719d6 Mon Sep 17 00:00:00 2001
From: josidd <joseph.siddons@noc.ac.uk>
Date: Fri, 4 Oct 2024 08:33:30 +0100
Subject: [PATCH] test(kdtree): Additional tests for KDTree edge cases

+ Test that query handles longitude wrap
+ Tests for duplicate positions
+ Test insert / delete duplicates
---
 test/test_kdtree.py | 64 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 63 insertions(+), 1 deletion(-)

diff --git a/test/test_kdtree.py b/test/test_kdtree.py
index 63a4ceb..85631bf 100644
--- a/test/test_kdtree.py
+++ b/test/test_kdtree.py
@@ -24,6 +24,18 @@ class TestKDTree(unittest.TestCase):
         assert kt.delete(delete_rec)
         assert delete_rec not in kt.child_right.child_right.points
 
+    def test_delete_dup(self):
+        test_records = [
+            Record(45, -23, uid="1"),
+            Record(45, -23, uid="2"),
+            Record(45, -23, uid="3"),
+            Record(45, -23, uid="4"),
+        ]
+        kt = KDTree(test_records, max_depth=3)
+        assert kt.delete(test_records[1])
+        # TEST: Cannot delete same record twice!
+        assert not kt.delete(test_records[1])
+
     def test_query(self):
         kt = KDTree(self.records)
         test_record = Record(-6, 35)
@@ -33,7 +45,57 @@ class TestKDTree(unittest.TestCase):
         true_record = self.records[true_ind]
 
         self.assertAlmostEqual(true_dist, best_dist)
-        assert best_record == true_record
+        assert len(best_record) == 1
+        assert best_record[0] == true_record
+
+    def test_duplicated_pos(self):
+        # TEST: That equal records get partitioned equally
+        test_records = [
+            Record(45, -23, uid="1"),
+            Record(45, -23, uid="2"),
+            Record(45, -23, uid="3"),
+            Record(45, -23, uid="4"),
+        ]
+        kt = KDTree(test_records, max_depth=3)
+        assert len(kt.child_left.child_left.points) == 1
+        assert len(kt.child_left.child_right.points) == 1
+        assert len(kt.child_right.child_left.points) == 1
+        assert len(kt.child_right.child_right.points) == 1
+
+    def test_insert_dup(self):
+        test_records = [
+            Record(45, -23, uid="1"),
+            Record(45, -23, uid="2"),
+            Record(45, -23, uid="3"),
+            Record(45, -23, uid="4"),
+        ]
+        kt = KDTree(test_records, max_depth=3)
+        assert not kt.insert(test_records[0])
+        assert not kt.insert(test_records[1])
+        assert not kt.insert(test_records[2])
+        assert not kt.insert(test_records[3])
+        assert kt.insert(Record(45, -23, uid="5"))
+        assert not kt.insert(Record(45, -23, uid="5"))
+        # TEST: Can insert after deleting
+        assert kt.delete(Record(45, -23, uid="5"))
+        assert kt.insert(Record(45, -23, uid="5"))
+
+    def test_get_multiple_neighbours(self):
+        kt = KDTree(self.records)
+        kt.insert(Record(45, -21, uid="1"))
+        kt.insert(Record(45, -21, uid="2"))
+
+        r, d = kt.query(Record(44, -21, uid="3"))
+        assert len(r) == 2
+
+    def test_wrap(self):
+        # TEST: Accounts for wrap at -180, 180
+        kt = KDTree(self.records)
+        kt.insert(Record(-160, -64, uid="G"))
+        query_rec = Record(-178, -79, uid="E")
+        r, _ = kt.query(query_rec)
+        assert len(r) == 1
+        assert r[0].uid == "C"
 
 
 if __name__ == "__main__":
-- 
GitLab