From 023ffa3eb1e8917955c1bd95ca9d2119ff52658f Mon Sep 17 00:00:00 2001 From: josidd <joseph.siddons@noc.ac.uk> Date: Tue, 8 Oct 2024 09:47:07 +0100 Subject: [PATCH] feat(1d_neighbours): add option to test sortedness - raise error if not sorted. - warn if not testing sortedness. --- GeoSpatialTools/neighbours.py | 34 +++++++++++++++++++++++++++++++++- test/test_neighbours.py | 9 ++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/GeoSpatialTools/neighbours.py b/GeoSpatialTools/neighbours.py index 82c562e..965fca6 100644 --- a/GeoSpatialTools/neighbours.py +++ b/GeoSpatialTools/neighbours.py @@ -4,11 +4,20 @@ from numpy import argmin from bisect import bisect from typing import TypeVar from datetime import date, datetime +from warnings import warn Numeric = TypeVar("Numeric", int, float, datetime, date) +class SortedWarning(Warning): + pass + + +class SortedError(Exception): + pass + + def _find_nearest(vals: list[Numeric], test: Numeric) -> int: i = bisect(vals, test) # Position that test would be inserted @@ -22,7 +31,11 @@ def _find_nearest(vals: list[Numeric], test: Numeric) -> int: return test_idx[argmin([abs(test - vals[j]) for j in test_idx])] -def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]: +def find_nearest( + vals: list[Numeric], + test: list[Numeric] | Numeric, + check_sorted: bool = True, +) -> list[int] | int: """ Find the nearest value in a list of values for each test value. @@ -36,10 +49,29 @@ def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]: checked, nor is the list sorted. test : list[Numeric] List of query values + check_sorted : bool + Optionally check that the input vals is sorted. Raises an error if set + to True (default), displays a warning if set to False. Returns ------- A list containing the index of the nearest neighbour in vals for each value in test. """ + if check_sorted: + s = _check_sorted(vals) + if not s: + raise SortedError("Input values are not sorted") + else: + warn("Not checking sortedness of data", SortedWarning) + + if not isinstance(test, list): + return _find_nearest(vals, test) + return [_find_nearest(vals, t) for t in test] + + +def _check_sorted(vals: list[Numeric]) -> bool: + return all(vals[i + 1] >= vals[i] for i in range(len(vals) - 1)) or all( + vals[i + 1] <= vals[i] for i in range(len(vals) - 1) + ) diff --git a/test/test_neighbours.py b/test/test_neighbours.py index fd2a6e2..e9fdbc7 100644 --- a/test/test_neighbours.py +++ b/test/test_neighbours.py @@ -3,6 +3,7 @@ from numpy import argmin from random import choice, sample from datetime import datetime, timedelta from GeoSpatialTools import find_nearest +from GeoSpatialTools.neighbours import SortedError, SortedWarning class TestFindNearest(unittest.TestCase): @@ -27,7 +28,13 @@ class TestFindNearest(unittest.TestCase): assert ours == greedy - pass + def test_sorted_warn(self): + with self.assertWarns(SortedWarning): + find_nearest([1., 2., 3.], 2.3, check_sorted=False) + + def test_sorted_error(self): + with self.assertRaises(SortedError): + find_nearest([3., 1., 2.], 2.3) if __name__ == "__main__": -- GitLab