diff --git a/GeoSpatialTools/kdtree.py b/GeoSpatialTools/kdtree.py index dba22f1d5980190ad4ac468fe0562b23dcd4fc56..c77524ad3a9d27f2d756cd17c292bfd330740793 100644 --- a/GeoSpatialTools/kdtree.py +++ b/GeoSpatialTools/kdtree.py @@ -49,18 +49,18 @@ class KDTree: points.sort(key=lambda p: getattr(p, self.variable)) split_index = n_points // 2 self.partition_value = getattr(points[split_index - 1], self.variable) - while ( - split_index < n_points - and getattr(points[split_index], self.variable) - == self.partition_value - ): - split_index += 1 + # while ( + # split_index < n_points + # and getattr(points[split_index], self.variable) + # == self.partition_value + # ): + # split_index += 1 self.split = True - # Left is <= median + # Left is points left of midpoint self.child_left = KDTree(points[:split_index], depth + 1) - # Right is > median + # Right is points right of midpoint self.child_right = KDTree(points[split_index:], depth + 1) return None @@ -79,8 +79,25 @@ class KDTree: if getattr(point, self.variable) < self.partition_value: return self.child_left.insert(point) - else: + elif getattr(point, self.variable) > self.partition_value: return self.child_right.insert(point) + else: + r, _ = self.query(point) + if point in r: + return False + self.child_left._insert(point) + return True + + def _insert(self, point: Record) -> None: + """Insert a point even if it already exists in the KDTree""" + if not self.split: + self.points.append(point) + return + if getattr(point, self.variable) <= self.partition_value: + self.child_left._insert(point) + else: + self.child_right._insert(point) + return def delete(self, point: Record) -> bool: """Delete a Record from the KDTree. May unbalance the KDTree""" @@ -91,12 +108,15 @@ class KDTree: except ValueError: return False - if getattr(point, self.variable) < self.partition_value: - return self.child_left.delete(point) - else: - return self.child_right.delete(point) + if getattr(point, self.variable) <= self.partition_value: + if self.child_left.delete(point): + return True + if getattr(point, self.variable) >= self.partition_value: + if self.child_right.delete(point): + return True + return False - def query(self, point) -> tuple[Record | None, float]: + def query(self, point) -> tuple[list[Record], float]: """Find the nearest Record within the KDTree to a query Record""" if point.lon < 0: point2 = Record(point.lon + 360, point.lat) @@ -106,32 +126,35 @@ class KDTree: r1, d1 = self._query(point) r2, d2 = self._query(point2) if d1 <= d2: - r = r1 + return r1, d1 else: - r = r2 - return r, point.distance(r) + return r2, d2 def _query( self, point: Record, - current_best: Record | None = None, + current_best: list[Record] | None = None, best_distance: float = inf, - ) -> tuple[Record | None, float]: + ) -> tuple[list[Record], float]: + if current_best is None: + current_best = list() if not self.split: for p in self.points: dist = point.distance(p) if dist < best_distance: - current_best = p + current_best = [p] best_distance = dist + elif dist == best_distance: + current_best.append(p) return current_best, best_distance - if getattr(point, self.variable) < self.partition_value: + if getattr(point, self.variable) <= self.partition_value: current_best, best_distance = self.child_left._query( point, current_best, best_distance ) if ( point.distance(self._get_partition_record(point)) - < best_distance + <= best_distance ): current_best, best_distance = self.child_right._query( point, current_best, best_distance @@ -142,7 +165,7 @@ class KDTree: ) if ( point.distance(self._get_partition_record(point)) - < best_distance + <= best_distance ): current_best, best_distance = self.child_left._query( point, current_best, best_distance diff --git a/notebooks/kdtree.ipynb b/notebooks/kdtree.ipynb index 9a69a4ef9e5990b8a55f33dc23190f0d673dc285..3ef3161b291da3b62a137efce907d3f40c55d184 100644 --- a/notebooks/kdtree.ipynb +++ b/notebooks/kdtree.ipynb @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 11, "id": "c60b30de-f864-477a-a09a-5f1caa4d9b9a", "metadata": {}, "outputs": [ @@ -56,18 +56,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "(14222, 2)\n", + "(16000, 2)\n", "shape: (5, 2)\n", "┌──────┬─────â”\n", "│ lon ┆ lat │\n", "│ --- ┆ --- │\n", "│ i64 ┆ i64 │\n", "â•žâ•â•â•â•â•â•â•ªâ•â•â•â•â•â•¡\n", - "│ -30 ┆ -41 │\n", - "│ -149 ┆ 56 │\n", - "│ 7 ┆ -68 │\n", - "│ -48 ┆ 83 │\n", - "│ -126 ┆ -35 │\n", + "│ 16 ┆ -75 │\n", + "│ 144 ┆ -77 │\n", + "│ -173 ┆ -83 │\n", + "│ 142 ┆ -81 │\n", + "│ -50 ┆ -38 │\n", "└──────┴─────┘\n" ] } @@ -83,14 +83,14 @@ "# dates_use = dates.sample(N, with_replacement=True).alias(\"datetime\")\n", "# uids = pl.Series(\"uid\", [generate_uid(8) for _ in range(N)])\n", "\n", - "df = pl.DataFrame([lons_use, lats_use]).unique()\n", + "df = pl.DataFrame([lons_use, lats_use])\n", "print(df.shape)\n", "print(df.head())" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 12, "id": "875f2a67-49fe-476f-add1-b1d76c6cd8f9", "metadata": {}, "outputs": [], @@ -100,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 13, "id": "1e883e5a-5086-4c29-aff2-d308874eae16", "metadata": {}, "outputs": [ @@ -108,8 +108,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 43.5 ms, sys: 3.43 ms, total: 46.9 ms\n", - "Wall time: 46.8 ms\n" + "CPU times: user 82 ms, sys: 4.14 ms, total: 86.1 ms\n", + "Wall time: 84.3 ms\n" ] } ], @@ -120,7 +120,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 14, "id": "69022ad1-5ec8-4a09-836c-273ef452451f", "metadata": {}, "outputs": [ @@ -128,7 +128,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "173 μs ± 1.36 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + "188 μs ± 3.45 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], @@ -140,7 +140,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 15, "id": "28031966-c7d0-4201-a467-37590118e851", "metadata": {}, "outputs": [ @@ -148,7 +148,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "7.71 ms ± 38.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "8.72 ms ± 74.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -160,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 16, "id": "0d10b2ba-57b2-475c-9d01-135363423990", "metadata": {}, "outputs": [ @@ -168,26 +168,27 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 15.4 s, sys: 37.8 ms, total: 15.5 s\n", - "Wall time: 15.5 s\n" + "CPU times: user 17.3 s, sys: 31.6 ms, total: 17.3 s\n", + "Wall time: 17.3 s\n" ] } ], "source": [ "%%time\n", "n_samples = 1000\n", + "tol = 1e-8\n", "test_records = [Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum()) for _ in range(n_samples)]\n", "kd_res = [kt.query(r) for r in test_records]\n", - "kd_recs = [_[0] for _ in kd_res]\n", + "kd_recs = [_[0][0] for _ in kd_res]\n", "kd_dists = [_[1] for _ in kd_res]\n", "tr_recs = [records[np.argmin([r.distance(p) for p in records])] for r in test_records]\n", "tr_dists = [min([r.distance(p) for p in records]) for r in test_records]\n", - "assert kd_dists == tr_dists, \"NOT MATCHING?\"" + "assert all([abs(k - t) < tol for k, t in zip(kd_dists, tr_dists)]), \"NOT MATCHING?\"" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 17, "id": "a6aa6926-7fd5-4fff-bd20-7bc0305b948d", "metadata": {}, "outputs": [ @@ -213,7 +214,7 @@ "└──────────┴──────────┴─────────┴────────┴────────┴─────────┴────────┴────────┘" ] }, - "execution_count": 10, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -228,7 +229,7 @@ "tr_lons = [r.lon for r in tr_recs]\n", "tr_lats = [r.lat for r in tr_recs]\n", "\n", - "pl.DataFrame({\n", + "df = pl.DataFrame({\n", " \"test_lon\": test_lons, \n", " \"test_lat\": test_lats,\n", " \"kd_dist\": kd_dists,\n", @@ -237,7 +238,8 @@ " \"tr_dist\": tr_dists,\n", " \"tr_lon\": tr_lons,\n", " \"tr_lat\": tr_lats, \n", - "}).filter(pl.col(\"kd_dist\").ne(pl.col(\"tr_dist\")))" + "}).filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n", + "df" ] } ], diff --git a/test/test_kdtree.py b/test/test_kdtree.py index 63a4ceb24761279eb261ab65faa921b73a8d6888..85631bf649742949043fa347017ac33c3fdd6f4f 100644 --- a/test/test_kdtree.py +++ b/test/test_kdtree.py @@ -24,6 +24,18 @@ class TestKDTree(unittest.TestCase): assert kt.delete(delete_rec) assert delete_rec not in kt.child_right.child_right.points + def test_delete_dup(self): + test_records = [ + Record(45, -23, uid="1"), + Record(45, -23, uid="2"), + Record(45, -23, uid="3"), + Record(45, -23, uid="4"), + ] + kt = KDTree(test_records, max_depth=3) + assert kt.delete(test_records[1]) + # TEST: Cannot delete same record twice! + assert not kt.delete(test_records[1]) + def test_query(self): kt = KDTree(self.records) test_record = Record(-6, 35) @@ -33,7 +45,57 @@ class TestKDTree(unittest.TestCase): true_record = self.records[true_ind] self.assertAlmostEqual(true_dist, best_dist) - assert best_record == true_record + assert len(best_record) == 1 + assert best_record[0] == true_record + + def test_duplicated_pos(self): + # TEST: That equal records get partitioned equally + test_records = [ + Record(45, -23, uid="1"), + Record(45, -23, uid="2"), + Record(45, -23, uid="3"), + Record(45, -23, uid="4"), + ] + kt = KDTree(test_records, max_depth=3) + assert len(kt.child_left.child_left.points) == 1 + assert len(kt.child_left.child_right.points) == 1 + assert len(kt.child_right.child_left.points) == 1 + assert len(kt.child_right.child_right.points) == 1 + + def test_insert_dup(self): + test_records = [ + Record(45, -23, uid="1"), + Record(45, -23, uid="2"), + Record(45, -23, uid="3"), + Record(45, -23, uid="4"), + ] + kt = KDTree(test_records, max_depth=3) + assert not kt.insert(test_records[0]) + assert not kt.insert(test_records[1]) + assert not kt.insert(test_records[2]) + assert not kt.insert(test_records[3]) + assert kt.insert(Record(45, -23, uid="5")) + assert not kt.insert(Record(45, -23, uid="5")) + # TEST: Can insert after deleting + assert kt.delete(Record(45, -23, uid="5")) + assert kt.insert(Record(45, -23, uid="5")) + + def test_get_multiple_neighbours(self): + kt = KDTree(self.records) + kt.insert(Record(45, -21, uid="1")) + kt.insert(Record(45, -21, uid="2")) + + r, d = kt.query(Record(44, -21, uid="3")) + assert len(r) == 2 + + def test_wrap(self): + # TEST: Accounts for wrap at -180, 180 + kt = KDTree(self.records) + kt.insert(Record(-160, -64, uid="G")) + query_rec = Record(-178, -79, uid="E") + r, _ = kt.query(query_rec) + assert len(r) == 1 + assert r[0].uid == "C" if __name__ == "__main__":