"""Functions for finding nearest neighbours using bisection."""

from numpy import argmin
from bisect import bisect
from typing import List, TypeVar, Union
from datetime import date, datetime
from warnings import warn


# Value types accepted by the nearest-neighbour search. They must be
# orderable (for bisection) and support subtraction whose result works with
# abs() (for comparing distances to the two candidate neighbours).
Numeric = TypeVar(
    "Numeric",
    int,
    float,
    datetime,
    date,
)


class SortedWarning(Warning):
    """Warning category for sortedness-related issues.

    ``find_nearest`` emits it when the caller disables the sortedness check,
    because results are only meaningful for sorted input.
    """


class SortedError(ValueError):
    """Raised when values that are required to be sorted are not.

    Subclasses ``ValueError`` because an unsorted pool is an invalid argument
    value. Existing ``except SortedError`` handlers keep working, and generic
    ``except ValueError`` handlers now catch it as well.
    """


def _find_nearest(vals: List[Numeric], test: Numeric) -> int:
    """
    Return the index of the element of ``vals`` nearest to ``test``.

    ``vals`` must be monotonic: either non-decreasing (ascending) or
    non-increasing (descending). The direction is inferred by comparing the
    first and last elements, so each query stays O(log n).

    When ``test`` is equidistant from its two neighbours, the lower index is
    returned.

    Raises
    ------
    ValueError
        If ``vals`` is empty.
    """
    n = len(vals)
    if n == 0:
        raise ValueError("Cannot find nearest neighbour in an empty list")

    if vals[0] <= vals[-1]:
        # Ascending: index of the first element strictly greater than test.
        i = bisect(vals, test)
    else:
        # Descending: bisect assumes ascending order, so binary-search by
        # hand for the index of the first element that is <= test.
        lo, hi = 0, n
        while lo < hi:
            mid = (lo + hi) // 2
            if vals[mid] > test:
                lo = mid + 1
            else:
                hi = mid
        i = lo

    # test falls before the first or after the last element.
    if i == 0:
        return 0
    if i == n:
        return n - 1

    # Otherwise test lies between vals[i - 1] and vals[i]; pick the closer
    # one. min() keeps the first minimum, so ties go to the lower index.
    return min((i - 1, i), key=lambda j: abs(test - vals[j]))


def find_nearest(
    vals: List[Numeric],
    test: Union[List[Numeric], Numeric],
    check_sorted: bool = True,
) -> Union[List[int], int]:
    """
    Find the index of the nearest value in ``vals`` for each test value.

    Uses bisection for speediness: each query is O(log n).

    Parameters
    ----------
    vals : list[Numeric]
        Pool of values in which to look for a nearest match. This list MUST
        be sorted, in either ascending or descending order. It is never
        sorted here.
    test : list[Numeric] | tuple[Numeric, ...] | Numeric
        A single query value, or a list/tuple of query values.
    check_sorted : bool
        If True (default), verify that ``vals`` is sorted and raise
        ``SortedError`` if it is not. If False, skip the check and emit a
        ``SortedWarning`` instead.

    Returns
    -------
    list[int] | int
        A list containing the index of the nearest neighbour in ``vals`` for
        each value in ``test``, or a single index if ``test`` is a single
        value. Ties between two neighbours resolve to the lower index.

    Raises
    ------
    SortedError
        If ``check_sorted`` is True and ``vals`` is not sorted.
    ValueError
        If ``vals`` is empty and at least one query is made.
    """
    if check_sorted:
        if not _check_sorted(vals):
            raise SortedError("Input values are not sorted")
    else:
        # stacklevel=2 attributes the warning to the caller's line.
        warn("Not checking sortedness of data", SortedWarning, stacklevel=2)

    if isinstance(test, (list, tuple)):
        return [_find_nearest(vals, t) for t in test]
    return _find_nearest(vals, test)


def _check_sorted(vals: list[Numeric]) -> bool:
    """
    Return True if ``vals`` is monotonic (non-decreasing or non-increasing).

    Empty and single-element sequences count as sorted.
    """
    # Consecutive (current, next) element pairs, taken by index.
    pairs = [(vals[k], vals[k + 1]) for k in range(len(vals) - 1)]
    non_decreasing = all(nxt >= cur for cur, nxt in pairs)
    return non_decreasing or all(nxt <= cur for cur, nxt in pairs)