# neighbours.py
"""
Neighbours
----------
Functions for finding nearest neighbours using bisection.
"""

from numpy import argmin
from bisect import bisect
from typing import List, TypeVar, Union
from datetime import date, datetime
from warnings import warn


# Constrained type variable for the values this module searches over:
# plain numbers as well as dates and datetimes.
Numeric = TypeVar(
    "Numeric",
    int,
    float,
    datetime,
    date,
)


class SortedWarning(Warning):
    """Warning issued when the sortedness of the input data is not checked."""


class SortedError(Exception):
    """Error raised when input data that is required to be sorted is not."""


def _find_nearest(vals: List[Numeric], test: Numeric) -> int:
    """
    Return the index of the value in a sorted list that is nearest to test.

    Parameters
    ----------
    vals : list[Numeric]
        Non-empty pool of candidate values, sorted in either ascending or
        descending order. Sortedness is NOT verified here.
    test : Numeric
        Query value. Must be orderable against, and subtractable from, the
        elements of vals.

    Returns
    -------
    The index into vals of the element closest to test. If test is exactly
    half-way between two neighbouring elements, the lower index is returned.

    Raises
    ------
    IndexError
        If vals is empty.
    """
    n_vals = len(vals)
    if n_vals == 0:
        raise IndexError("Cannot find the nearest value in an empty list")

    if vals[0] <= vals[-1]:
        # Ascending: position at which test would be inserted (to the right
        # of any equal elements).
        i = bisect(vals, test)
    else:
        # Descending: bisect assumes ascending order and would give a
        # meaningless position, so run the mirrored bisection by hand.
        # Afterwards every element before i is >= test and every element
        # from i onwards is < test.
        lo, hi = 0, n_vals
        while lo < hi:
            mid = (lo + hi) // 2
            if test > vals[mid]:
                hi = mid
            else:
                lo = mid + 1
        i = lo

    # Handle edges: test lies at or beyond one end of the list
    if i == 0:
        return 0
    if i == n_vals:
        return n_vals - 1

    # test lies between vals[i - 1] and vals[i]; pick whichever is closer,
    # preferring the lower index on a tie.
    dist_before = abs(test - vals[i - 1])
    dist_after = abs(test - vals[i])
    return i - 1 if dist_before <= dist_after else i


def find_nearest(
    vals: List[Numeric],
    test: Union[List[Numeric], Numeric],
    check_sorted: bool = True,
) -> Union[List[int], int]:
    """
    Find the nearest value in a list of values for each test value.

    Uses bisection for speediness!

    Parameters
    ----------
    vals : list[Numeric]
        List of values - this is the pool of values for which we are looking
        for a nearest match. This list MUST be sorted. It is never sorted by
        this function, and its sortedness is only verified when check_sorted
        is True.
    test : list[Numeric] | Numeric
        Query value, or list of query values. Only a ``list`` is treated as
        a collection of queries; any other object is handled as a single
        query value.
    check_sorted : bool
        If True (default), check that vals is sorted and raise an error if it
        is not. If False, skip the check and display a warning instead.

    Returns
    -------
    A list containing the index of the nearest neighbour in vals for each value
    in test. Or the index of the nearest neighbour if test is a single value.

    Raises
    ------
    SortedError
        If check_sorted is True and vals is not sorted.

    Warns
    -----
    SortedWarning
        If check_sorted is False, to flag that sortedness was not verified.
    """
    if check_sorted:
        if not _check_sorted(vals):
            raise SortedError("Input values are not sorted")
    else:
        # stacklevel=2 attributes the warning to the caller's line rather
        # than to this module.
        warn("Not checking sortedness of data", SortedWarning, stacklevel=2)

    if not isinstance(test, list):
        return _find_nearest(vals, test)

    return [_find_nearest(vals, t) for t in test]


def _check_sorted(vals: list[Numeric]) -> bool:
    """
    Report whether vals is sorted in either direction.

    Returns True when every element is >= its predecessor, or when every
    element is <= its predecessor; lists with fewer than two elements count
    as sorted. Returns False otherwise.
    """
    n_pairs = len(vals) - 1

    non_decreasing = all(vals[k + 1] >= vals[k] for k in range(n_pairs))
    if non_decreasing:
        return True

    non_increasing = all(vals[k + 1] <= vals[k] for k in range(n_pairs))
    return non_increasing