""" Neighbours ---------- Functions for finding nearest neighbours using bisection. """ from numpy import argmin from bisect import bisect from typing import List, TypeVar, Union from datetime import date, datetime from warnings import warn Numeric = TypeVar("Numeric", int, float, datetime, date) class SortedWarning(Warning): """Warning class for Sortedness""" pass class SortedError(Exception): """Error class for Sortedness""" pass def _find_nearest(vals: List[Numeric], test: Numeric) -> int: i = bisect(vals, test) # Position that test would be inserted # Handle edges if i == 0 and test <= vals[0]: return 0 elif i == len(vals) and test >= vals[-1]: return len(vals) - 1 test_idx = [i - 1, i] return test_idx[argmin([abs(test - vals[j]) for j in test_idx])] def find_nearest( vals: List[Numeric], test: Union[List[Numeric], Numeric], check_sorted: bool = True, ) -> Union[List[int], int]: """ Find the nearest value in a list of values for each test value. Uses bisection for speediness! Parameters ---------- vals : list[Numeric] List of values - this is the pool of values for which we are looking for a nearest match. This list MUST be sorted. Sortedness is not checked, nor is the list sorted. test : list[Numeric] | Numeric List of query values check_sorted : bool Optionally check that the input vals is sorted. Raises an error if set to True (default), displays a warning if set to False. Returns ------- A list containing the index of the nearest neighbour in vals for each value in test. Or the index of the nearest neighbour if test is a single value. """ if check_sorted: s = _check_sorted(vals) if not s: raise SortedError("Input values are not sorted") else: warn("Not checking sortedness of data", SortedWarning) if not isinstance(test, list): return _find_nearest(vals, test) return [_find_nearest(vals, t) for t in test] def _check_sorted(vals: list[Numeric]) -> bool: return all(vals[i + 1] >= vals[i] for i in range(len(vals) - 1)) or all( vals[i + 1] <= vals[i] for i in range(len(vals) - 1) )