"""
KDTree
------
An implementation of KDTree using Haversine Distance for GeoSpatial analysis.
Useful tool for quickly searching for nearest neighbours. The implementation is
a K=2 or 2DTree as only 2 dimensions (longitude and latitude) are used.

Haversine distances are used for comparisons, so that the spherical geometry
of the earth is accounted for.
"""

from operator import attrgetter
from typing import List, Optional, Tuple

from numpy import inf

from . import Record


class KDTree:
    """
    A Haversine distance implementation of a balanced KDTree.

    This implementation is a _balanced_ KDTree, each leaf node should have the
    same number of points (or differ by 1 depending on the number of points
    the KDTree is initialised with).

    The KDTree partitions in each of the lon and lat dimensions alternately
    in sequence by splitting at the median of the dimension of the points
    assigned to the branch.

    The list passed in is neither reordered nor stored; the tree works on its
    own copies, so later inserts and deletes never touch the caller's list.

    Parameters
    ----------
    points : list[Record]
        A list of GeoSpatialTools.Record instances. Each must expose ``lon``
        and ``lat`` attributes and a ``distance(other)`` method.
    depth : int
        The current depth of the KDTree, you should set this to 0, it is used
        internally.
    max_depth : int
        The maximum depth of the KDTree. No node is split once it reaches this
        depth, and the limit applies to every sub-tree. A branch holding fewer
        than 2 points becomes a leaf regardless of depth.

    Attributes
    ----------
    depth : int
        Depth of this node.
    max_depth : int
        Depth limit shared by the whole tree.
    split : bool
        ``True`` for an internal node, ``False`` for a leaf.
    points : list[Record]
        Records held by this node (leaf nodes only).
    axis : int
        0 when the node splits on longitude, 1 for latitude (internal nodes
        only).
    variable : str
        ``"lon"`` or ``"lat"``, the attribute the node splits on (internal
        nodes only).
    partition_value : float
        Largest ``variable`` value assigned to the left child at build time
        (internal nodes only).
    child_left, child_right : KDTree
        Sub-trees either side of ``partition_value`` (internal nodes only).
    """

    def __init__(
        self, points: List[Record], depth: int = 0, max_depth: int = 20
    ) -> None:
        self.depth = depth
        self.max_depth = max_depth
        n_points = len(points)

        # ">=" rather than "==" so that a node created below the limit still
        # terminates as a leaf instead of splitting until it runs out of points
        if self.depth >= max_depth or n_points < 2:
            # Copy so that insert/delete never mutate the caller's list
            self.points = list(points)
            self.split = False
            return

        self.axis = depth % 2
        self.variable = "lon" if self.axis == 0 else "lat"

        # sorted() leaves the caller's list in its original order
        ordered = sorted(points, key=attrgetter(self.variable))
        split_index = n_points // 2
        # The partition value is the largest value in the left half, so values
        # equal to it can legitimately sit in either child
        self.partition_value = getattr(ordered[split_index - 1], self.variable)

        self.split = True

        # max_depth must be forwarded, otherwise every sub-tree silently falls
        # back to the default limit
        # Left is points left of midpoint
        self.child_left = KDTree(ordered[:split_index], depth + 1, max_depth)
        # Right is points right of midpoint
        self.child_right = KDTree(ordered[split_index:], depth + 1, max_depth)

    def insert(self, point: Record) -> bool:
        """
        Insert a Record into the KDTree. May unbalance the KDTree.

        The point will not be inserted if it is already in the KDTree.

        Parameters
        ----------
        point : Record
            The Record to add.

        Returns
        -------
        bool
            ``True`` if the point was added, ``False`` if an equal Record was
            already present.
        """
        if not self.split:
            if point in self.points:
                return False
            self.points.append(point)
            return True

        value = getattr(point, self.variable)
        if value < self.partition_value:
            return self.child_left.insert(point)
        if value > self.partition_value:
            return self.child_right.insert(point)

        # The point lies exactly on the partition, so an equal Record could be
        # in either child. Look for it with a nearest-neighbour query before
        # adding to the left child.
        matches, _ = self.query(point)
        if point in matches:
            return False
        self.child_left._insert(point)
        return True

    def _insert(self, point: Record) -> None:
        """Insert a point even if it already exists in the KDTree"""
        if not self.split:
            self.points.append(point)
            return
        if getattr(point, self.variable) <= self.partition_value:
            self.child_left._insert(point)
        else:
            self.child_right._insert(point)

    def delete(self, point: Record) -> bool:
        """
        Delete a Record from the KDTree. May unbalance the KDTree.

        Parameters
        ----------
        point : Record
            The Record to remove.

        Returns
        -------
        bool
            ``True`` if a matching Record was found and removed, ``False``
            otherwise. At most one occurrence is removed per call.
        """
        if not self.split:
            try:
                self.points.remove(point)
            except ValueError:
                return False
            return True

        value = getattr(point, self.variable)
        # A value equal to the partition may be stored on either side, so both
        # children are tried in that case (left first)
        if value <= self.partition_value and self.child_left.delete(point):
            return True
        if value >= self.partition_value and self.child_right.delete(point):
            return True
        return False

    def query(self, point: Record) -> Tuple[List[Record], float]:
        """
        Find the nearest Record within the KDTree to a query Record.

        The tree is searched twice: once with the query point as given and
        once with its longitude shifted by 360 degrees (added when the
        longitude is negative, subtracted otherwise). The result with the
        smaller distance is returned, the unshifted search winning ties.
        Presumably this is to find neighbours across the longitude wrap-around.

        Parameters
        ----------
        point : Record
            The query Record.

        Returns
        -------
        tuple[list[Record], float]
            The Records tied for nearest and their distance to the query, as
            computed by ``Record.distance``. An empty tree gives ``([], inf)``.
        """
        if point.lon < 0:
            point2 = Record(point.lon + 360, point.lat, fix_lon=False)
        else:
            point2 = Record(point.lon - 360, point.lat, fix_lon=False)

        r1, d1 = self._query(point)
        r2, d2 = self._query(point2)
        if d1 <= d2:
            return r1, d1
        return r2, d2

    def _query(
        self,
        point: Record,
        current_best: Optional[List[Record]] = None,
        best_distance: float = inf,
    ) -> Tuple[List[Record], float]:
        """
        Recursive nearest-neighbour search below this node.

        Parameters
        ----------
        point : Record
            The query Record.
        current_best : list[Record], optional
            Records tied for nearest so far. A new list is created if omitted.
        best_distance : float
            Distance of ``current_best`` to the query.

        Returns
        -------
        tuple[list[Record], float]
            The updated list of nearest Records and their distance.
        """
        if current_best is None:
            current_best = list()

        if not self.split:
            for p in self.points:
                dist = point.distance(p)
                if dist < best_distance:
                    current_best = [p]
                    best_distance = dist
                elif dist == best_distance:
                    current_best.append(p)
            return current_best, best_distance

        # Search the side of the partition containing the query first
        if getattr(point, self.variable) <= self.partition_value:
            near, far = self.child_left, self.child_right
        else:
            near, far = self.child_right, self.child_left

        current_best, best_distance = near._query(
            point, current_best, best_distance
        )
        # Only visit the other side when the partition is no further away than
        # the best match so far.
        # NOTE(review): the bound is the distance to the point on the
        # partition that shares the query's other coordinate. For a longitude
        # split this may exceed the true shortest distance to the meridian, so
        # a closer Record could in principle be pruned - confirm acceptable.
        if (
            point.distance(self._get_partition_record(point))
            <= best_distance
        ):
            current_best, best_distance = far._query(
                point, current_best, best_distance
            )

        return current_best, best_distance

    def _get_partition_record(self, point: Record) -> Record:
        """
        Build the Record on this node's partition used for pruning.

        The returned Record has the partition value on the split axis and the
        query point's value on the other axis.
        """
        if self.variable == "lon":
            return Record(self.partition_value, point.lat)
        return Record(point.lon, self.partition_value)