"""Functions for finding nearest neighbours using bisection.""" from numpy import argmin from bisect import bisect from typing import TypeVar from datetime import date, datetime from warnings import warn Numeric = TypeVar("Numeric", int, float, datetime, date) class SortedWarning(Warning): pass class SortedError(Exception): pass def _find_nearest(vals: list[Numeric], test: Numeric) -> int: i = bisect(vals, test) # Position that test would be inserted # Handle edges if i == 0 and test <= vals[0]: return 0 elif i == len(vals) and test >= vals[-1]: return len(vals) - 1 test_idx = [i - 1, i] return test_idx[argmin([abs(test - vals[j]) for j in test_idx])] def find_nearest( vals: list[Numeric], test: list[Numeric] | Numeric, check_sorted: bool = True, ) -> list[int] | int: """ Find the nearest value in a list of values for each test value. Uses bisection for speediness! Parameters ---------- vals : list[Numeric] List of values - this is the pool of values for which we are looking for a nearest match. This list MUST be sorted. Sortedness is not checked, nor is the list sorted. test : list[Numeric] List of query values check_sorted : bool Optionally check that the input vals is sorted. Raises an error if set to True (default), displays a warning if set to False. Returns ------- A list containing the index of the nearest neighbour in vals for each value in test. """ if check_sorted: s = _check_sorted(vals) if not s: raise SortedError("Input values are not sorted") else: warn("Not checking sortedness of data", SortedWarning) if not isinstance(test, list): return _find_nearest(vals, test) return [_find_nearest(vals, t) for t in test] def _check_sorted(vals: list[Numeric]) -> bool: return all(vals[i + 1] >= vals[i] for i in range(len(vals) - 1)) or all( vals[i + 1] <= vals[i] for i in range(len(vals) - 1) )