"""Functions for finding nearest neighbours using bisection."""

from numpy import argmin
from bisect import bisect
from typing import TypeVar
from datetime import date, datetime
from warnings import warn


# Value types accepted by the search functions. Values must be mutually
# orderable (for bisection) and support subtraction whose result supports
# abs() (to measure the distance between a query and its neighbours).
Numeric = TypeVar("Numeric", int, float, datetime, date)


class SortedWarning(Warning):
    """Warning category emitted when the sortedness check on input values is skipped."""


class SortedError(ValueError):
    """Raised when values that must be sorted are found not to be.

    Subclasses ``ValueError`` (and therefore ``Exception``), so callers can
    catch it specifically or as a generic invalid-argument error.
    """


def _find_nearest(vals: list[Numeric], test: Numeric) -> int:
    i = bisect(vals, test)  # Position that test would be inserted

    # Handle edges
    if i == 0 and test <= vals[0]:
        return 0
    elif i == len(vals) and test >= vals[-1]:
        return len(vals) - 1

    test_idx = [i - 1, i]
    return test_idx[argmin([abs(test - vals[j]) for j in test_idx])]


def find_nearest(
    vals: list[Numeric],
    test: list[Numeric] | tuple[Numeric, ...] | Numeric,
    check_sorted: bool = True,
) -> list[int] | int:
    """
    Find the index of the nearest value in ``vals`` for each test value.

    Uses bisection, so each query is O(log n) in the length of ``vals``.

    Parameters
    ----------
    vals : list[Numeric]
        Pool of values in which a nearest match is sought. This list MUST be
        sorted; it is never sorted here.
    test : list[Numeric] | tuple[Numeric, ...] | Numeric
        A single query value, or a list/tuple of query values.
    check_sorted : bool
        If True (default), verify that ``vals`` is sorted before searching and
        raise ``SortedError`` if it is not. If False, skip that O(n) check and
        emit a ``SortedWarning`` instead.

    Returns
    -------
    list[int] | int
        The index in ``vals`` of the nearest neighbour of each test value: a
        list when ``test`` is a list or tuple, otherwise a single int.

    Raises
    ------
    SortedError
        If ``check_sorted`` is True and ``vals`` is not sorted.
    """
    if check_sorted:
        if not _check_sorted(vals):
            raise SortedError("Input values are not sorted")
    else:
        # stacklevel=2 attributes the warning to the caller, which is where
        # check_sorted=False was chosen.
        warn("Not checking sortedness of data", SortedWarning, stacklevel=2)

    if isinstance(test, (list, tuple)):
        return [_find_nearest(vals, t) for t in test]
    return _find_nearest(vals, test)


def _check_sorted(vals: list[Numeric]) -> bool:
    return all(vals[i + 1] >= vals[i] for i in range(len(vals) - 1)) or all(
        vals[i + 1] <= vals[i] for i in range(len(vals) - 1)
    )