"""Functions for finding nearest neighbours using bisection."""

from numpy import argmin
from bisect import bisect
from typing import TypeVar
from datetime import date, datetime
from warnings import warn


Numeric = TypeVar("Numeric", int, float, datetime, date)


class SortedWarning(Warning):
    """Warning category used when the sortedness check on input is skipped."""


class SortedError(Exception):
    """Error raised when values that are required to be sorted are not."""


def _find_nearest(vals: list[Numeric], test: Numeric) -> int:
    i = bisect(vals, test)  # Position that test would be inserted

    # Handle edges
    if i == 0 and test <= vals[0]:
        return 0
    elif i == len(vals) and test >= vals[-1]:
        return len(vals) - 1

    test_idx = [i - 1, i]
    return test_idx[argmin([abs(test - vals[j]) for j in test_idx])]


def find_nearest(
    vals: list[Numeric],
    test: list[Numeric] | Numeric,
    check_sorted: bool = True,
) -> list[int] | int:
    """
    Find the nearest value in a list of values for each test value.

    Uses bisection for speediness!

    Parameters
    ----------
    vals : list[Numeric]
        List of values - this is the pool of values for which we are looking
        for a nearest match. This list MUST be sorted; it is never sorted by
        this function. Whether sortedness is verified is controlled by
        check_sorted.
    test : list[Numeric] | Numeric
        A single query value, or a list of query values.
    check_sorted : bool
        If True (default), verify that vals is sorted and raise a SortedError
        if it is not. If False, skip the check and emit a SortedWarning
        instead.

    Returns
    -------
    A list containing the index of the nearest neighbour in vals for each value
    in test. Or the index of the nearest neighbour if test is a single value.

    Raises
    ------
    SortedError
        If check_sorted is True and vals is not sorted.
    """
    if check_sorted:
        if not _check_sorted(vals):
            raise SortedError("Input values are not sorted")
    else:
        # stacklevel=2 attributes the warning to the line that called
        # find_nearest, which is where the decision to skip the check was made
        warn("Not checking sortedness of data", SortedWarning, stacklevel=2)

    # Anything that is not a list is treated as a single query value
    if not isinstance(test, list):
        return _find_nearest(vals, test)

    return [_find_nearest(vals, t) for t in test]


def _check_sorted(vals: list[Numeric]) -> bool:
    return all(vals[i + 1] >= vals[i] for i in range(len(vals) - 1)) or all(
        vals[i + 1] <= vals[i] for i in range(len(vals) - 1)
    )