diff --git a/GeoSpatialTools/neighbours.py b/GeoSpatialTools/neighbours.py index 82c562eddca9faee595e46ba8785431e52c7f0cc..965fca6050546b483c98cf70a2dea07a333ae3b5 100644 --- a/GeoSpatialTools/neighbours.py +++ b/GeoSpatialTools/neighbours.py @@ -4,11 +4,20 @@ from numpy import argmin from bisect import bisect from typing import TypeVar from datetime import date, datetime +from warnings import warn Numeric = TypeVar("Numeric", int, float, datetime, date) +class SortedWarning(Warning): + pass + + +class SortedError(Exception): + pass + + def _find_nearest(vals: list[Numeric], test: Numeric) -> int: i = bisect(vals, test) # Position that test would be inserted @@ -22,7 +31,11 @@ def _find_nearest(vals: list[Numeric], test: Numeric) -> int: return test_idx[argmin([abs(test - vals[j]) for j in test_idx])] -def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]: +def find_nearest( + vals: list[Numeric], + test: list[Numeric] | Numeric, + check_sorted: bool = True, +) -> list[int] | int: """ Find the nearest value in a list of values for each test value. @@ -36,10 +49,29 @@ def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]: checked, nor is the list sorted. test : list[Numeric] List of query values + check_sorted : bool + Optionally check that the input vals is sorted. Raises an error if set + to True (default), displays a warning if set to False. Returns ------- A list containing the index of the nearest neighbour in vals for each value in test. """ + if check_sorted: + s = _check_sorted(vals) + if not s: + raise SortedError("Input values are not sorted") + else: + warn("Not checking sortedness of data", SortedWarning) + + if not isinstance(test, list): + return _find_nearest(vals, test) + return [_find_nearest(vals, t) for t in test] + + +def _check_sorted(vals: list[Numeric]) -> bool: + return all(vals[i + 1] >= vals[i] for i in range(len(vals) - 1)) or all( + vals[i + 1] <= vals[i] for i in range(len(vals) - 1) + ) diff --git a/test/test_neighbours.py b/test/test_neighbours.py index fd2a6e259b92cc0e97b3a22adebacb8ca9b505e4..e9fdbc779358bf10e6e698a0bb9af3b9fdc9bdee 100644 --- a/test/test_neighbours.py +++ b/test/test_neighbours.py @@ -3,6 +3,7 @@ from numpy import argmin from random import choice, sample from datetime import datetime, timedelta from GeoSpatialTools import find_nearest +from GeoSpatialTools.neighbours import SortedError, SortedWarning class TestFindNearest(unittest.TestCase): @@ -27,7 +28,13 @@ class TestFindNearest(unittest.TestCase): assert ours == greedy - pass + def test_sorted_warn(self): + with self.assertWarns(SortedWarning): + find_nearest([1., 2., 3.], 2.3, check_sorted=False) + + def test_sorted_error(self): + with self.assertRaises(SortedError): + find_nearest([3., 1., 2.], 2.3) if __name__ == "__main__":