"""Nearest-neighbour lookup in a sorted list, together with its unit test.

``find_nearest`` is the public entry point (the package ``__init__``
re-exports it next to ``haversine``).  The brute-force reference used by the
unit test is the only place that needs ``numpy``.
"""
import unittest
from bisect import bisect
from datetime import date, datetime, timedelta
from itertools import islice
from random import Random
from typing import TypeVar

from numpy import argmin


Numeric = TypeVar("Numeric", int, float, datetime, date)


def _find_nearest(vals: list[Numeric], test: Numeric) -> int:
    """
    Return the index of the value in vals that is closest to test.

    Arguments
    =========
    vals : list[Numeric]
        Pool of candidate values, sorted in ascending order.
    test : Numeric
        The query value.

    Returns
    =======
    Index into vals of the nearest value. When test is exactly half-way
    between two neighbours the lower index is returned.

    Raises
    ======
    IndexError
        If vals is empty.
    """
    n_vals = len(vals)
    if n_vals == 0:
        raise IndexError("vals must contain at least one value")

    i = bisect(vals, test)  # Position that test would be inserted

    # Handle edges: test lies before the first or at/after the last value.
    if i == 0:
        return 0
    # The >= guard is only False for queries that do not order against vals
    # (e.g. float NaN); those fall through and fail with IndexError below
    # instead of silently matching the last element.
    if i == n_vals and test >= vals[-1]:
        return n_vals - 1

    # test sits between vals[i - 1] and vals[i]; "<=" keeps the lower index
    # on a tie.
    lower, upper = i - 1, i
    if abs(test - vals[lower]) <= abs(test - vals[upper]):
        return lower
    return upper


def find_nearest(
    vals: list[Numeric],
    test: list[Numeric],
    *,
    check_sorted: bool = False,
) -> list[int]:
    """
    Find the nearest value in a list of values for each test value.

    Uses bisection for speediness!

    Arguments
    =========
    vals : list[Numeric]
        List of values - this is the pool of values for which we are looking
        for a nearest match. This list MUST be sorted in ascending order.
        By default sortedness is not checked, nor is the list sorted; an
        unsorted list does not raise, it silently yields wrong indices.
    test : list[Numeric]
        List of query values
    check_sorted : bool
        If True, make one pass over vals first and raise ValueError when it
        is not in ascending order. Defaults to False.

    Returns
    =======
    A list containing the index of the nearest neighbour in vals for each value
    in test.

    Raises
    ======
    ValueError
        If check_sorted is True and vals is not sorted.
    IndexError
        If vals is empty and test is not.
    """
    if check_sorted and any(
        left > right for left, right in zip(vals, islice(vals, 1, None))
    ):
        raise ValueError("vals must be sorted in ascending order")
    return [_find_nearest(vals, t) for t in test]


def _make_queries(dates: list[datetime], seed: int = 0) -> list[datetime]:
    """Build the query dates for the test from a seeded random generator."""
    rng = Random(seed)  # seeded so that a failing run can be reproduced
    queries = [
        d + timedelta(seconds=60 * rng.choice(range(60)))
        for d in rng.sample(dates, 150)
    ]
    # Exact end points, plus one query before and one after the whole range.
    queries.append(dates[0])
    queries.append(dates[-1])
    queries.append(datetime(2004, 11, 15, 17, 28))
    queries.append(datetime(2013, 4, 22, 1, 41))
    return queries


class TestFindNearest(unittest.TestCase):
    """Compare find_nearest against a brute-force search over hourly dates."""

    dates = [
        datetime(2009, 1, 1, 0, 0) + timedelta(seconds=i * 3600)
        for i in range(365 * 24)
    ]
    test_dates = _make_queries(dates)

    def test_nearest(self):
        greedy = [
            int(argmin([abs(x - y) for x in self.dates]))
            for y in self.test_dates
        ]
        ours = find_nearest(self.dates, self.test_dates)

        self.assertEqual(ours, greedy)

    def test_empty_pool(self):
        with self.assertRaises(IndexError):
            find_nearest([], [self.dates[0]])

    def test_check_sorted(self):
        with self.assertRaises(ValueError):
            find_nearest([3, 1, 2], [2], check_sorted=True)


if __name__ == "__main__":
    unittest.main()