diff --git a/test/test_kdtree.py b/test/test_kdtree.py index 85631bf649742949043fa347017ac33c3fdd6f4f..edd1d826a46103b424317fbcec8ba0fb6ae3e8a7 100644 --- a/test/test_kdtree.py +++ b/test/test_kdtree.py @@ -1,4 +1,6 @@ import unittest +import random + from numpy import min, argmin from GeoSpatialTools import haversine, KDTree, Record @@ -91,12 +93,39 @@ class TestKDTree(unittest.TestCase): def test_wrap(self): # TEST: Accounts for wrap at -180, 180 kt = KDTree(self.records) - kt.insert(Record(-160, -64, uid="G")) + bad_rec = Record(-160, -64, uid="G") + kt.insert(bad_rec) query_rec = Record(-178, -79, uid="E") r, _ = kt.query(query_rec) assert len(r) == 1 assert r[0].uid == "C" + def test_near_pole_query(self): + test_records = [ + Record(-180, 89.5, uid="1"), + Record(-90, 89.9, uid="2"), + Record(0, 89.5, uid="3"), + ] + N_others = 50 + test_records.extend( + [ + Record( + random.choice(range(-180, 180)), + random.choice(range(80, 90)), + ) + for _ in range(N_others) + ] + ) + + kt = KDTree(test_records, max_depth=3) + + query_rec = Record(90, 89.8, uid="4") + r, d = kt.query(query_rec) + assert len(r) == 1 + print(r[0]) + print(d) + assert r[0].uid == "2" + if __name__ == "__main__": unittest.main()