# kdtree.py
"""
An implementation of KDTree using Haversine Distance for GeoSpatial analysis.
Useful tool for quickly searching for nearest neighbours.
"""

from . import Record
from numpy import inf


class KDTree:
    """
    A Haversine distance implementation of a balanced KDTree.

    This implementation is a _balanced_ KDTree, each leaf node should have the
    same number of points (or differ by 1 depending on the number of points
    the KDTree is initialised with).

    The KDTree partitions in each of the lon and lat dimensions alternatively
    in sequence by splitting at the median of the dimension of the points
    assigned to the branch. Points whose coordinate is less than or equal to
    the partition value belong to the left child, all others to the right
    child.

    Parameters
    ----------
    points : list[Record]
        A list of GeoSpatialTools.Record instances. The list is sorted in
        place while the tree is built, and a tree that consists of a single
        leaf keeps a reference to this list.
    depth : int
        The current depth of the KDTree, you should set this to 0, it is used
        internally.
    max_depth : int
        The maximum depth of the KDTree. The leaf nodes will have depth no
        larger than this value. Leaf nodes will not be created if there is
        only 1 point in the branch.
    """

    def __init__(
        self, points: list[Record], depth: int = 0, max_depth: int = 20
    ) -> None:
        self.depth = depth
        n_points = len(points)

        # `>=` (not `==`) guarantees termination even when the tree is
        # started at a depth that already exceeds max_depth.
        if self.depth >= max_depth or n_points < 2:
            self.points = points
            self.split = False
            return None

        # Alternate the splitting dimension with depth: lon, lat, lon, ...
        self.axis = depth % 2
        self.variable = "lon" if self.axis == 0 else "lat"

        points.sort(key=lambda p: getattr(p, self.variable))
        split_index = n_points // 2
        self.partition_value = getattr(points[split_index - 1], self.variable)
        # Move the split past every point that ties with the median so that
        # all points equal to the partition value end up in the left child.
        while (
            split_index < n_points
            and getattr(points[split_index], self.variable)
            == self.partition_value
        ):
            split_index += 1

        self.split = True

        # Left is <= median
        self.child_left = KDTree(points[:split_index], depth + 1, max_depth)
        # Right is > median
        self.child_right = KDTree(points[split_index:], depth + 1, max_depth)

        return None

    def insert(self, point: Record) -> bool:
        """
        Insert a Record into the KDTree. May unbalance the KDTree.

        The point will not be inserted if it is already in the KDTree.

        Parameters
        ----------
        point : Record
            The Record to add.

        Returns
        -------
        bool
            True if the point was added, False if an equal point was already
            stored in the leaf the point belongs to.
        """
        if not self.split:
            if point in self.points:
                return False
            self.points.append(point)
            return True

        # Must use the same `<=` rule as construction, otherwise a point equal
        # to the partition value is looked up in the wrong subtree.
        if getattr(point, self.variable) <= self.partition_value:
            return self.child_left.insert(point)
        return self.child_right.insert(point)

    def delete(self, point: Record) -> bool:
        """
        Delete a Record from the KDTree. May unbalance the KDTree.

        Parameters
        ----------
        point : Record
            The Record to remove.

        Returns
        -------
        bool
            True if the point was found and removed, False otherwise.
        """
        if not self.split:
            try:
                self.points.remove(point)
            except ValueError:
                return False
            return True

        # Same `<=` rule as construction (left child holds <= partition).
        if getattr(point, self.variable) <= self.partition_value:
            return self.child_left.delete(point)
        return self.child_right.delete(point)

    def query(self, point: Record) -> tuple[list[Record], float]:
        """
        Find the nearest Record within the KDTree to a query Record.

        The tree is searched twice: once with the query point as given and
        once with its longitude shifted by 360 degrees (presumably so that
        neighbours on the other side of the longitude wrap-around are not
        missed). The result with the smaller distance is returned.

        Parameters
        ----------
        point : Record
            The query Record.

        Returns
        -------
        tuple[list[Record], float]
            The Records found at the smallest distance (more than one if
            several are equally distant) and that distance.
        """
        if point.lon < 0:
            point2 = Record(point.lon + 360, point.lat)
        else:
            point2 = Record(point.lon - 360, point.lat)

        r1, d1 = self._query(point)
        r2, d2 = self._query(point2)
        if d1 <= d2:
            return r1, d1
        return r2, d2

    def _query(
        self,
        point: Record,
        current_best: list[Record] | None = None,
        best_distance: float = inf,
    ) -> tuple[list[Record], float]:
        """
        Recursively search this branch for the Records nearest to `point`.

        Parameters
        ----------
        point : Record
            The query Record.
        current_best : list[Record] | None
            The nearest Records found so far, None if nothing was searched
            yet.
        best_distance : float
            The distance of the Records in `current_best` to `point`.

        Returns
        -------
        tuple[list[Record], float]
            The updated list of nearest Records and their distance.
        """
        if current_best is None:
            current_best = list()

        if not self.split:
            for p in self.points:
                dist = point.distance(p)
                if dist < best_distance:
                    # Strictly closer: start a new list of best candidates.
                    current_best = [p]
                    best_distance = dist
                elif dist == best_distance:
                    # Tie: keep every Record at the best distance.
                    current_best.append(p)
            return current_best, best_distance

        # Search the side of the partition containing the query first.
        if getattr(point, self.variable) <= self.partition_value:
            near, far = self.child_left, self.child_right
        else:
            near, far = self.child_right, self.child_left

        current_best, best_distance = near._query(
            point, current_best, best_distance
        )
        # Only search the other side if the partition is close enough for it
        # to possibly hold a Record at least as near as the current best.
        # NOTE(review): the partition is represented by the point sharing the
        # query's other coordinate; assuming Record.distance is a great-circle
        # distance this may not be the closest point on the partition away
        # from the equator - confirm the pruning is acceptable.
        if (
            point.distance(self._get_partition_record(point))
            <= best_distance
        ):
            current_best, best_distance = far._query(
                point, current_best, best_distance
            )

        return current_best, best_distance

    def _get_partition_record(self, point: Record) -> Record:
        """
        Build the Record on this node's partition that shares the other
        coordinate with `point`, used to measure the distance from `point`
        to the partition.
        """
        if self.variable == "lon":
            return Record(self.partition_value, point.lat)
        return Record(point.lon, self.partition_value)