diff --git a/GeoSpatialTools/__init__.py b/GeoSpatialTools/__init__.py
index 4a2d1ad900512da6b40cadb362fe6f4f92ce34d0..11e1e9388c3b2a90fa76a7649632b55e77c63ef1 100644
--- a/GeoSpatialTools/__init__.py
+++ b/GeoSpatialTools/__init__.py
@@ -1,4 +1,14 @@
 from .neighbours import find_nearest
 from .distance_metrics import haversine
+from .quadtree import Ellipse, QuadTree, Record, Rectangle
+from .kdtree import KDTree
 
-__all__ = ["find_nearest", "haversine"]
+__all__ = [
+    "Ellipse",
+    "KDTree",
+    "QuadTree",
+    "Record",
+    "Rectangle",
+    "find_nearest",
+    "haversine",
+]
diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd54c1c376b7c5aea2cc62f70ed1b96036402368
--- /dev/null
+++ b/GeoSpatialTools/kdtree.py
@@ -0,0 +1,128 @@
+"""
+An implementation of KDTree using Haversine Distance for GeoSpatial analysis.
+Useful tool for quickly searching for nearest neighbours.
+"""
+
+from . import Record
+from numpy import inf
+
+
+class KDTree:
+    """
+    A Haverine distance implementation of a balanced KDTree.
+
+    This implementation is a _balanced_ KDTree, each leaf node should have the
+    same number of points (or differ by 1 depending on the number of points
+    the KDTree is intialised with).
+
+    The KDTree partitions in each of the lon and lat dimensions alternatively
+    in sequence by splitting at the median of the dimension of the points
+    assigned to the branch.
+
+    Parameters
+    ----------
+    points : list[Record]
+        A list of GeoSpatialTools.Record instances.
+    depth : int
+        The current depth of the KDTree, you should set this to 0, it is used
+        internally.
+    max_depth : int
+        The maximium depth of the KDTree. The leaf nodes will have depth no
+        larger than this value. Leaf nodes will not be created if there is
+        only 1 point in the branch.
+    """
+
+    def __init__(
+        self, points: list[Record], depth: int = 0, max_depth: int = 20
+    ) -> None:
+        self.depth = depth
+        n_points = len(points)
+
+        if self.depth == max_depth or n_points < 2:
+            self.points = points
+            self.split = False
+            return None
+
+        self.axis = depth % 2
+        self.variable = "lon" if self.axis == 0 else "lat"
+
+        points.sort(key=lambda p: getattr(p, self.variable))
+        split_index = n_points // 2
+
+        self.partition_value = getattr(points[split_index], self.variable)
+        self.split = True
+
+        self.child_left = KDTree(points[:split_index], depth + 1)
+        self.child_right = KDTree(points[split_index:], depth + 1)
+
+        return None
+
+    def insert(self, point: Record) -> bool:
+        """Insert a Record into the KDTree. May unbalance the KDTree"""
+        if not self.split:
+            self.points.append(point)
+            return True
+
+        if getattr(point, self.variable) < self.partition_value:
+            return self.child_left.insert(point)
+        else:
+            return self.child_right.insert(point)
+
+    def delete(self, point: Record) -> bool:
+        """Delete a Record from the KDTree. May unbalance the KDTree"""
+        if not self.split:
+            try:
+                self.points.remove(point)
+                return True
+            except ValueError:
+                return False
+
+        if getattr(point, self.variable) < self.partition_value:
+            return self.child_left.delete(point)
+        else:
+            return self.child_right.delete(point)
+
+    def query(
+        self,
+        point: Record,
+        current_best: Record | None = None,
+        best_distance: float = inf,
+    ) -> tuple[Record | None, float]:
+        """Find the nearest Record within the KDTree to a query Record"""
+        if not self.split:
+            for p in self.points:
+                dist = point.distance(p)
+                if dist < best_distance:
+                    current_best = p
+                    best_distance = dist
+            return current_best, best_distance
+
+        if getattr(point, self.variable) < self.partition_value:
+            current_best, best_distance = self.child_left.query(
+                point, current_best, best_distance
+            )
+            if (
+                point.distance(self._get_partition_record(point))
+                < best_distance
+            ):
+                current_best, best_distance = self.child_right.query(
+                    point, current_best, best_distance
+                )
+        else:
+            current_best, best_distance = self.child_right.query(
+                point, current_best, best_distance
+            )
+            if (
+                point.distance(self._get_partition_record(point))
+                < best_distance
+            ):
+                current_best, best_distance = self.child_left.query(
+                    point, current_best, best_distance
+                )
+
+        return current_best, best_distance
+
+    def _get_partition_record(self, point: Record) -> Record:
+        if self.variable == "lon":
+            return Record(self.partition_value, point.lat)
+        return Record(point.lon, self.partition_value)
diff --git a/test/test_kdtree.py b/test/test_kdtree.py
new file mode 100644
index 0000000000000000000000000000000000000000..63a4ceb24761279eb261ab65faa921b73a8d6888
--- /dev/null
+++ b/test/test_kdtree.py
@@ -0,0 +1,40 @@
+import unittest
+from numpy import min, argmin
+from GeoSpatialTools import haversine, KDTree, Record
+
+
+class TestKDTree(unittest.TestCase):
+    records = [
+        Record(1, 2, uid="A"),
+        Record(-9, 44, uid="B"),
+        Record(174, -81, uid="C"),
+        Record(-4, 71, uid="D"),
+    ]
+
+    def test_insert(self):
+        kt = KDTree(self.records)
+        test_record = Record(175, 44)
+        assert kt.insert(test_record)
+        assert test_record in kt.child_right.child_right.points
+
+    def test_delete(self):
+        kt = KDTree(self.records)
+        delete_rec = self.records[2]
+        assert delete_rec in kt.child_right.child_right.points
+        assert kt.delete(delete_rec)
+        assert delete_rec not in kt.child_right.child_right.points
+
+    def test_query(self):
+        kt = KDTree(self.records)
+        test_record = Record(-6, 35)
+        best_record, best_dist = kt.query(test_record)
+        true_dist = min([test_record.distance(r) for r in self.records])
+        true_ind = argmin([test_record.distance(r) for r in self.records])
+        true_record = self.records[true_ind]
+
+        self.assertAlmostEqual(true_dist, best_dist)
+        assert best_record == true_record
+
+
+if __name__ == "__main__":
+    unittest.main()