diff --git a/GeoSpatialTools/neighbours.py b/GeoSpatialTools/neighbours.py
index 82c562eddca9faee595e46ba8785431e52c7f0cc..965fca6050546b483c98cf70a2dea07a333ae3b5 100644
--- a/GeoSpatialTools/neighbours.py
+++ b/GeoSpatialTools/neighbours.py
@@ -4,11 +4,20 @@ from numpy import argmin
 from bisect import bisect
 from typing import TypeVar
 from datetime import date, datetime
+from warnings import warn
 
 
 Numeric = TypeVar("Numeric", int, float, datetime, date)
 
 
+class SortedWarning(Warning):
+    pass
+
+
+class SortedError(Exception):
+    pass
+
+
 def _find_nearest(vals: list[Numeric], test: Numeric) -> int:
     i = bisect(vals, test)  # Position that test would be inserted
 
@@ -22,7 +31,11 @@ def _find_nearest(vals: list[Numeric], test: Numeric) -> int:
     return test_idx[argmin([abs(test - vals[j]) for j in test_idx])]
 
 
-def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]:
+def find_nearest(
+    vals: list[Numeric],
+    test: list[Numeric] | Numeric,
+    check_sorted: bool = True,
+) -> list[int] | int:
     """
     Find the nearest value in a list of values for each test value.
 
@@ -36,10 +49,29 @@ def find_nearest(vals: list[Numeric], test: list[Numeric]) -> list[int]:
         checked, nor is the list sorted.
     test : list[Numeric]
         List of query values
+    check_sorted : bool
+        Optionally check that the input vals is sorted. Raises an error if set
+        to True (default), displays a warning if set to False.
 
     Returns
     -------
     A list containing the index of the nearest neighbour in vals for each value
     in test.
     """
+    if check_sorted:
+        s = _check_sorted(vals)
+        if not s:
+            raise SortedError("Input values are not sorted")
+    else:
+        warn("Not checking sortedness of data", SortedWarning)
+
+    if not isinstance(test, list):
+        return _find_nearest(vals, test)
+
     return [_find_nearest(vals, t) for t in test]
+
+
+def _check_sorted(vals: list[Numeric]) -> bool:
+    return all(vals[i + 1] >= vals[i] for i in range(len(vals) - 1)) or all(
+        vals[i + 1] <= vals[i] for i in range(len(vals) - 1)
+    )
diff --git a/test/test_neighbours.py b/test/test_neighbours.py
index fd2a6e259b92cc0e97b3a22adebacb8ca9b505e4..e9fdbc779358bf10e6e698a0bb9af3b9fdc9bdee 100644
--- a/test/test_neighbours.py
+++ b/test/test_neighbours.py
@@ -3,6 +3,7 @@ from numpy import argmin
 from random import choice, sample
 from datetime import datetime, timedelta
 from GeoSpatialTools import find_nearest
+from GeoSpatialTools.neighbours import SortedError, SortedWarning
 
 
 class TestFindNearest(unittest.TestCase):
@@ -27,7 +28,13 @@ class TestFindNearest(unittest.TestCase):
 
         assert ours == greedy
 
-    pass
+    def test_sorted_warn(self):
+        with self.assertWarns(SortedWarning):
+            find_nearest([1., 2., 3.], 2.3, check_sorted=False)
+
+    def test_sorted_error(self):
+        with self.assertRaises(SortedError):
+            find_nearest([3., 1., 2.], 2.3)
 
 
 if __name__ == "__main__":