kdtree.py 5.8 KB
Newer Older
1 2 3 4 5 6 7
"""
An implementation of KDTree using Haversine Distance for GeoSpatial analysis.
Useful tool for quickly searching for nearest neighbours.
"""

from . import Record
from numpy import inf
8
from typing import List, Optional, Tuple
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36


class KDTree:
    """
    A Haversine distance implementation of a balanced KDTree.

    This implementation is a _balanced_ KDTree, each leaf node should have the
    same number of points (or differ by 1 depending on the number of points
    the KDTree is initialised with).

    The KDTree partitions in each of the lon and lat dimensions alternatively
    in sequence by splitting at the median of the dimension of the points
    assigned to the branch.

    Parameters
    ----------
    points : list[Record]
        A list of GeoSpatialTools.Record instances. The list is not modified;
        the tree works on its own sorted copy.
    depth : int
        The current depth of the KDTree, you should set this to 0, it is used
        internally.
    max_depth : int
        The maximum depth of the KDTree. The leaf nodes will have depth no
        larger than this value. Leaf nodes will not be created if there is
        only 1 point in the branch.
    """

    def __init__(
        self, points: List["Record"], depth: int = 0, max_depth: int = 20
    ) -> None:
        self.depth = depth
        n_points = len(points)

        if self.depth == max_depth or n_points < 2:
            # Leaf node: keep a private copy so that later insert/delete
            # calls never mutate the list object owned by the caller.
            self.points = list(points)
            self.split = False
            return

        # Alternate the splitting dimension with depth: lon, lat, lon, ...
        self.axis = depth % 2
        self.variable = "lon" if self.axis == 0 else "lat"

        # sorted() rather than list.sort() so the caller's list keeps its
        # original order.
        ordered = sorted(points, key=lambda p: getattr(p, self.variable))
        split_index = n_points // 2
        # The partition value is the largest value in the left half, so
        # every left point is <= partition_value and every right point is
        # >= partition_value (ties can end up on either side).
        self.partition_value = getattr(ordered[split_index - 1], self.variable)

        self.split = True

        # max_depth must be forwarded, otherwise the children silently fall
        # back to the default and the requested limit is ignored.
        # Left is points left of midpoint
        self.child_left = KDTree(ordered[:split_index], depth + 1, max_depth)
        # Right is points right of midpoint
        self.child_right = KDTree(ordered[split_index:], depth + 1, max_depth)

    def insert(self, point: "Record") -> bool:
        """
        Insert a Record into the KDTree. May unbalance the KDTree.

        The point will not be inserted if it is already in the KDTree.

        Parameters
        ----------
        point : Record
            The record to add.

        Returns
        -------
        bool
            True if the point was added, False if an equal point was already
            present.
        """
        if not self.split:
            if point in self.points:
                return False
            self.points.append(point)
            return True

        value = getattr(point, self.variable)
        if value < self.partition_value:
            return self.child_left.insert(point)
        if value > self.partition_value:
            return self.child_right.insert(point)

        # The value sits exactly on the partition, so an equal point could
        # live in either child. Search this whole sub-tree for it before
        # adding the point to the left branch.
        nearest, _ = self.query(point)
        if point in nearest:
            return False
        self.child_left._insert(point)
        return True

    def _insert(self, point: "Record") -> None:
        """Insert a point even if it already exists in the KDTree"""
        if not self.split:
            self.points.append(point)
            return
        # Ties go left, matching how insert() places points that lie on the
        # partition value.
        if getattr(point, self.variable) <= self.partition_value:
            self.child_left._insert(point)
        else:
            self.child_right._insert(point)

    def delete(self, point: "Record") -> bool:
        """
        Delete a Record from the KDTree. May unbalance the KDTree.

        Parameters
        ----------
        point : Record
            The record to remove.

        Returns
        -------
        bool
            True if a matching point was found and removed, False otherwise.
        """
        if not self.split:
            try:
                self.points.remove(point)
            except ValueError:
                return False
            return True

        value = getattr(point, self.variable)
        # A point equal to the partition value may be stored on either side,
        # so both children are tried in that case (left first).
        if value <= self.partition_value and self.child_left.delete(point):
            return True
        if value >= self.partition_value and self.child_right.delete(point):
            return True
        return False

    def query(self, point: "Record") -> Tuple[List["Record"], float]:
        """
        Find the nearest Record within the KDTree to a query Record.

        The tree is searched twice: once with the query point and once with
        a copy whose longitude is shifted by 360 degrees, and the closer of
        the two results is returned.

        Parameters
        ----------
        point : Record
            The query record.

        Returns
        -------
        tuple[list[Record], float]
            The records found at the smallest distance (more than one when
            several are equally close) and that distance. An empty list and
            ``inf`` are returned when the tree holds no points.
        """
        # Build the 360-degree shifted twin of the query point.
        # fix_lon=False presumably stops Record from normalising the shifted
        # longitude back into range -- confirm against Record.
        shift = 360 if point.lon < 0 else -360
        shifted = Record(point.lon + shift, point.lat, fix_lon=False)

        r1, d1 = self._query(point)
        r2, d2 = self._query(shifted)
        if d1 <= d2:
            return r1, d1
        return r2, d2

    def _query(
        self,
        point: "Record",
        current_best: Optional[List["Record"]] = None,
        best_distance: float = inf,
    ) -> Tuple[List["Record"], float]:
        """
        Recursive nearest neighbour search below this node.

        Parameters
        ----------
        point : Record
            The query record.
        current_best : list[Record], optional
            The closest records found so far.
        best_distance : float
            The distance of the records in ``current_best``.

        Returns
        -------
        tuple[list[Record], float]
            The updated closest records and their distance.
        """
        if current_best is None:
            current_best = []

        if not self.split:
            for candidate in self.points:
                dist = point.distance(candidate)
                if dist < best_distance:
                    # Strictly closer: start a fresh list of best records.
                    current_best = [candidate]
                    best_distance = dist
                elif dist == best_distance:
                    current_best.append(candidate)
            return current_best, best_distance

        # Search the side of the partition that contains the query first.
        if getattr(point, self.variable) <= self.partition_value:
            near, far = self.child_left, self.child_right
        else:
            near, far = self.child_right, self.child_left

        current_best, best_distance = near._query(
            point, current_best, best_distance
        )
        # Only visit the other side if the partition is no further away than
        # the best match found so far.
        # NOTE(review): the bound is the distance to the partition at the
        # query's own lat/lon; confirm that this is a valid lower bound for
        # the distance metric at high latitudes.
        if (
            point.distance(self._get_partition_record(point))
            <= best_distance
        ):
            current_best, best_distance = far._query(
                point, current_best, best_distance
            )

        return current_best, best_distance

    def _get_partition_record(self, point: "Record") -> "Record":
        """
        Return a Record that lies on this node's partition.

        The record takes the partition value in the split dimension and the
        query point's own value in the other dimension.
        """
        if self.variable == "lon":
            return Record(self.partition_value, point.lat)
        return Record(point.lon, self.partition_value)