Commit 023ffa3e authored by Joseph Siddons's avatar Joseph Siddons
Browse files

feat(1d_neighbours): add option to test sortedness

- raise error if not sorted.
- warn if not testing sortedness.
parent 3c6aed6a
...@@ -4,11 +4,20 @@ from numpy import argmin ...@@ -4,11 +4,20 @@ from numpy import argmin
from bisect import bisect from bisect import bisect
from typing import TypeVar from typing import TypeVar
from datetime import date, datetime from datetime import date, datetime
from warnings import warn
Numeric = TypeVar("Numeric", int, float, datetime, date) Numeric = TypeVar("Numeric", int, float, datetime, date)
class SortedWarning(Warning):
pass
class SortedError(Exception):
pass
def _find_nearest(vals: list[Numeric], test: Numeric) -> int: def _find_nearest(vals: list[Numeric], test: Numeric) -> int:
i = bisect(vals, test) # Position that test would be inserted i = bisect(vals, test) # Position that test would be inserted
...@@ -22,7 +31,11 @@ def _find_nearest(vals: list[Numeric], test: Numeric) -> int: ...@@ -22,7 +31,11 @@ def _find_nearest(vals: list[Numeric], test: Numeric) -> int:
return test_idx[argmin([abs(test - vals[j]) for j in test_idx])] return test_idx[argmin([abs(test - vals[j]) for j in test_idx])]
def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]: def find_nearest(
vals: list[Numeric],
test: list[Numeric] | Numeric,
check_sorted: bool = True,
) -> list[int] | int:
""" """
Find the nearest value in a list of values for each test value. Find the nearest value in a list of values for each test value.
...@@ -36,10 +49,29 @@ def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]: ...@@ -36,10 +49,29 @@ def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]:
checked, nor is the list sorted. checked, nor is the list sorted.
test : list[Numeric] test : list[Numeric]
List of query values List of query values
check_sorted : bool
Optionally check that the input vals is sorted. Raises an error if set
to True (default), displays a warning if set to False.
Returns Returns
------- -------
A list containing the index of the nearest neighbour in vals for each value A list containing the index of the nearest neighbour in vals for each value
in test. in test.
""" """
if check_sorted:
s = _check_sorted(vals)
if not s:
raise SortedError("Input values are not sorted")
else:
warn("Not checking sortedness of data", SortedWarning)
if not isinstance(test, list):
return _find_nearest(vals, test)
return [_find_nearest(vals, t) for t in test] return [_find_nearest(vals, t) for t in test]
def _check_sorted(vals: list[Numeric]) -> bool:
return all(vals[i + 1] >= vals[i] for i in range(len(vals) - 1)) or all(
vals[i + 1] <= vals[i] for i in range(len(vals) - 1)
)
...@@ -3,6 +3,7 @@ from numpy import argmin ...@@ -3,6 +3,7 @@ from numpy import argmin
from random import choice, sample from random import choice, sample
from datetime import datetime, timedelta from datetime import datetime, timedelta
from GeoSpatialTools import find_nearest from GeoSpatialTools import find_nearest
from GeoSpatialTools.neighbours import SortedError, SortedWarning
class TestFindNearest(unittest.TestCase): class TestFindNearest(unittest.TestCase):
...@@ -27,7 +28,13 @@ class TestFindNearest(unittest.TestCase): ...@@ -27,7 +28,13 @@ class TestFindNearest(unittest.TestCase):
assert ours == greedy assert ours == greedy
pass def test_sorted_warn(self):
with self.assertWarns(SortedWarning):
find_nearest([1., 2., 3.], 2.3, check_sorted=False)
def test_sorted_error(self):
with self.assertRaises(SortedError):
find_nearest([3., 1., 2.], 2.3)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment