Commit 023ffa3e authored by Joseph Siddons's avatar Joseph Siddons
Browse files

feat(1d_neighbours): add option to test sortedness

- raise error if not sorted.
- warn if not testing sortedness.
1 merge request!7feat(1d_neighbours): add option to test sortedness
......@@ -4,11 +4,20 @@ from numpy import argmin
from bisect import bisect
from typing import TypeVar
from datetime import date, datetime
from warnings import warn
Numeric = TypeVar("Numeric", int, float, datetime, date)
class SortedWarning(Warning):
pass
class SortedError(Exception):
pass
def _find_nearest(vals: list[Numeric], test: Numeric) -> int:
i = bisect(vals, test) # Position that test would be inserted
......@@ -22,7 +31,11 @@ def _find_nearest(vals: list[Numeric], test: Numeric) -> int:
return test_idx[argmin([abs(test - vals[j]) for j in test_idx])]
def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]:
def find_nearest(
vals: list[Numeric],
test: list[Numeric] | Numeric,
check_sorted: bool = True,
) -> list[int] | int:
"""
Find the nearest value in a list of values for each test value.
......@@ -36,10 +49,29 @@ def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]:
checked, nor is the list sorted.
test : list[Numeric]
List of query values
check_sorted : bool
Optionally check that the input vals is sorted. Raises an error if set
to True (default), displays a warning if set to False.
Returns
-------
A list containing the index of the nearest neighbour in vals for each value
in test.
"""
if check_sorted:
s = _check_sorted(vals)
if not s:
raise SortedError("Input values are not sorted")
else:
warn("Not checking sortedness of data", SortedWarning)
if not isinstance(test, list):
return _find_nearest(vals, test)
return [_find_nearest(vals, t) for t in test]
def _check_sorted(vals: list[Numeric]) -> bool:
return all(vals[i + 1] >= vals[i] for i in range(len(vals) - 1)) or all(
vals[i + 1] <= vals[i] for i in range(len(vals) - 1)
)
......@@ -3,6 +3,7 @@ from numpy import argmin
from random import choice, sample
from datetime import datetime, timedelta
from GeoSpatialTools import find_nearest
from GeoSpatialTools.neighbours import SortedError, SortedWarning
class TestFindNearest(unittest.TestCase):
......@@ -27,7 +28,13 @@ class TestFindNearest(unittest.TestCase):
assert ours == greedy
pass
def test_sorted_warn(self):
with self.assertWarns(SortedWarning):
find_nearest([1., 2., 3.], 2.3, check_sorted=False)
def test_sorted_error(self):
with self.assertRaises(SortedError):
find_nearest([3., 1., 2.], 2.3)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment