"""Functions for finding nearest neighbours using bisection."""

from numpy import argmin
from bisect import bisect
from typing import TypeVar
from datetime import date, datetime


Numeric = TypeVar("Numeric", int, float, datetime, date)


def _find_nearest(vals: list[Numeric], test: Numeric) -> int:
    i = bisect(vals, test)  # Position that test would be inserted

    # Handle edges
    if i == 0 and test <= vals[0]:
        return 0
    elif i == len(vals) and test >= vals[-1]:
        return len(vals) - 1

    test_idx = [i - 1, i]
    return test_idx[argmin([abs(test - vals[j]) for j in test_idx])]


def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]:
    """
    Look up, for every query value, the index of its nearest pool value.

    Each query is resolved with a bisection search of the pool.

    Parameters
    ----------
    vals : list[Numeric]
        Pool of values in which nearest matches are sought. This list MUST
        already be sorted: it is neither checked for sortedness nor sorted
        here.
    test : list[Numeric]
        Query values.

    Returns
    -------
    list[int]
        One index into vals per element of test, in the same order as test.
    """
    nearest_indices: list[int] = []
    for query in test:
        nearest_indices.append(_find_nearest(vals, query))
    return nearest_indices