From 1836b2c48d06d8deedf79babafe56db2247719d6 Mon Sep 17 00:00:00 2001 From: josidd <joseph.siddons@noc.ac.uk> Date: Fri, 4 Oct 2024 08:33:30 +0100 Subject: [PATCH] test(kdtree): Additional tests for KDTree edge cases + Test that query handles longitude wrap + Tests for duplicate positions + Test insert / delete duplicates --- test/test_kdtree.py | 64 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/test/test_kdtree.py b/test/test_kdtree.py index 63a4ceb..85631bf 100644 --- a/test/test_kdtree.py +++ b/test/test_kdtree.py @@ -24,6 +24,18 @@ class TestKDTree(unittest.TestCase): assert kt.delete(delete_rec) assert delete_rec not in kt.child_right.child_right.points + def test_delete_dup(self): + test_records = [ + Record(45, -23, uid="1"), + Record(45, -23, uid="2"), + Record(45, -23, uid="3"), + Record(45, -23, uid="4"), + ] + kt = KDTree(test_records, max_depth=3) + assert kt.delete(test_records[1]) + # TEST: Cannot delete same record twice! + assert not kt.delete(test_records[1]) + def test_query(self): kt = KDTree(self.records) test_record = Record(-6, 35) @@ -33,7 +45,57 @@ class TestKDTree(unittest.TestCase): true_record = self.records[true_ind] self.assertAlmostEqual(true_dist, best_dist) - assert best_record == true_record + assert len(best_record) == 1 + assert best_record[0] == true_record + + def test_duplicated_pos(self): + # TEST: That equal records get partitioned equally + test_records = [ + Record(45, -23, uid="1"), + Record(45, -23, uid="2"), + Record(45, -23, uid="3"), + Record(45, -23, uid="4"), + ] + kt = KDTree(test_records, max_depth=3) + assert len(kt.child_left.child_left.points) == 1 + assert len(kt.child_left.child_right.points) == 1 + assert len(kt.child_right.child_left.points) == 1 + assert len(kt.child_right.child_right.points) == 1 + + def test_insert_dup(self): + test_records = [ + Record(45, -23, uid="1"), + Record(45, -23, uid="2"), + Record(45, -23, uid="3"), + Record(45, -23, uid="4"), + ] + kt = KDTree(test_records, max_depth=3) + assert not kt.insert(test_records[0]) + assert not kt.insert(test_records[1]) + assert not kt.insert(test_records[2]) + assert not kt.insert(test_records[3]) + assert kt.insert(Record(45, -23, uid="5")) + assert not kt.insert(Record(45, -23, uid="5")) + # TEST: Can insert after deleting + assert kt.delete(Record(45, -23, uid="5")) + assert kt.insert(Record(45, -23, uid="5")) + + def test_get_multiple_neighbours(self): + kt = KDTree(self.records) + kt.insert(Record(45, -21, uid="1")) + kt.insert(Record(45, -21, uid="2")) + + r, d = kt.query(Record(44, -21, uid="3")) + assert len(r) == 2 + + def test_wrap(self): + # TEST: Accounts for wrap at -180, 180 + kt = KDTree(self.records) + kt.insert(Record(-160, -64, uid="G")) + query_rec = Record(-178, -79, uid="E") + r, _ = kt.query(query_rec) + assert len(r) == 1 + assert r[0].uid == "C" if __name__ == "__main__": -- GitLab