Commit 2fb05223 authored by Joseph Siddons
Browse files

Merge branch '3-potential-infinite-loop-in-kdtree' into 'main'

Resolve "Potential Infinite Loop in KDTree"

Closes #3

See merge request josidd/geospatialtools!5
parents f9bc5809 98f65e45
......@@ -49,18 +49,18 @@ class KDTree:
points.sort(key=lambda p: getattr(p, self.variable))
split_index = n_points // 2
self.partition_value = getattr(points[split_index - 1], self.variable)
while (
split_index < n_points
and getattr(points[split_index], self.variable)
== self.partition_value
):
split_index += 1
# while (
# split_index < n_points
# and getattr(points[split_index], self.variable)
# == self.partition_value
# ):
# split_index += 1
self.split = True
# Left is <= median
# Left is points left of midpoint
self.child_left = KDTree(points[:split_index], depth + 1)
# Right is > median
# Right is points right of midpoint
self.child_right = KDTree(points[split_index:], depth + 1)
return None
......@@ -79,8 +79,25 @@ class KDTree:
if getattr(point, self.variable) < self.partition_value:
return self.child_left.insert(point)
else:
elif getattr(point, self.variable) > self.partition_value:
return self.child_right.insert(point)
else:
r, _ = self.query(point)
if point in r:
return False
self.child_left._insert(point)
return True
def _insert(self, point: Record) -> None:
"""Insert a point even if it already exists in the KDTree"""
if not self.split:
self.points.append(point)
return
if getattr(point, self.variable) <= self.partition_value:
self.child_left._insert(point)
else:
self.child_right._insert(point)
return
def delete(self, point: Record) -> bool:
"""Delete a Record from the KDTree. May unbalance the KDTree"""
......@@ -91,12 +108,15 @@ class KDTree:
except ValueError:
return False
if getattr(point, self.variable) < self.partition_value:
return self.child_left.delete(point)
else:
return self.child_right.delete(point)
if getattr(point, self.variable) <= self.partition_value:
if self.child_left.delete(point):
return True
if getattr(point, self.variable) >= self.partition_value:
if self.child_right.delete(point):
return True
return False
def query(self, point) -> tuple[Record | None, float]:
def query(self, point) -> tuple[list[Record], float]:
"""Find the nearest Record within the KDTree to a query Record"""
if point.lon < 0:
point2 = Record(point.lon + 360, point.lat)
......@@ -106,32 +126,35 @@ class KDTree:
r1, d1 = self._query(point)
r2, d2 = self._query(point2)
if d1 <= d2:
r = r1
return r1, d1
else:
r = r2
return r, point.distance(r)
return r2, d2
def _query(
self,
point: Record,
current_best: Record | None = None,
current_best: list[Record] | None = None,
best_distance: float = inf,
) -> tuple[Record | None, float]:
) -> tuple[list[Record], float]:
if current_best is None:
current_best = list()
if not self.split:
for p in self.points:
dist = point.distance(p)
if dist < best_distance:
current_best = p
current_best = [p]
best_distance = dist
elif dist == best_distance:
current_best.append(p)
return current_best, best_distance
if getattr(point, self.variable) < self.partition_value:
if getattr(point, self.variable) <= self.partition_value:
current_best, best_distance = self.child_left._query(
point, current_best, best_distance
)
if (
point.distance(self._get_partition_record(point))
< best_distance
<= best_distance
):
current_best, best_distance = self.child_right._query(
point, current_best, best_distance
......@@ -142,7 +165,7 @@ class KDTree:
)
if (
point.distance(self._get_partition_record(point))
< best_distance
<= best_distance
):
current_best, best_distance = self.child_left._query(
point, current_best, best_distance
......
......@@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 11,
"id": "c60b30de-f864-477a-a09a-5f1caa4d9b9a",
"metadata": {},
"outputs": [
......@@ -56,18 +56,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(14222, 2)\n",
"(16000, 2)\n",
"shape: (5, 2)\n",
"┌──────┬─────┐\n",
"│ lon ┆ lat │\n",
"│ --- ┆ --- │\n",
"│ i64 ┆ i64 │\n",
"╞══════╪═════╡\n",
"│ -30 ┆ -41 │\n",
"│ -149 ┆ 56 │\n",
"│ 7 ┆ -68 │\n",
"│ -48 ┆ 83 │\n",
"│ -126 ┆ -35 │\n",
"│ 16 ┆ -75 │\n",
"│ 144 ┆ -77 │\n",
"│ -173 ┆ -83 │\n",
"│ 142 ┆ -81 │\n",
"│ -50 ┆ -38 │\n",
"└──────┴─────┘\n"
]
}
......@@ -83,14 +83,14 @@
"# dates_use = dates.sample(N, with_replacement=True).alias(\"datetime\")\n",
"# uids = pl.Series(\"uid\", [generate_uid(8) for _ in range(N)])\n",
"\n",
"df = pl.DataFrame([lons_use, lats_use]).unique()\n",
"df = pl.DataFrame([lons_use, lats_use])\n",
"print(df.shape)\n",
"print(df.head())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 12,
"id": "875f2a67-49fe-476f-add1-b1d76c6cd8f9",
"metadata": {},
"outputs": [],
......@@ -100,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 13,
"id": "1e883e5a-5086-4c29-aff2-d308874eae16",
"metadata": {},
"outputs": [
......@@ -108,8 +108,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 43.5 ms, sys: 3.43 ms, total: 46.9 ms\n",
"Wall time: 46.8 ms\n"
"CPU times: user 82 ms, sys: 4.14 ms, total: 86.1 ms\n",
"Wall time: 84.3 ms\n"
]
}
],
......@@ -120,7 +120,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 14,
"id": "69022ad1-5ec8-4a09-836c-273ef452451f",
"metadata": {},
"outputs": [
......@@ -128,7 +128,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"173 μs ± 1.36 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
"188 μs ± 3.45 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
......@@ -140,7 +140,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 15,
"id": "28031966-c7d0-4201-a467-37590118e851",
"metadata": {},
"outputs": [
......@@ -148,7 +148,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"7.71 ms ± 38.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"8.72 ms ± 74.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
......@@ -160,7 +160,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 16,
"id": "0d10b2ba-57b2-475c-9d01-135363423990",
"metadata": {},
"outputs": [
......@@ -168,26 +168,27 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 15.4 s, sys: 37.8 ms, total: 15.5 s\n",
"Wall time: 15.5 s\n"
"CPU times: user 17.3 s, sys: 31.6 ms, total: 17.3 s\n",
"Wall time: 17.3 s\n"
]
}
],
"source": [
"%%time\n",
"n_samples = 1000\n",
"tol = 1e-8\n",
"test_records = [Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum()) for _ in range(n_samples)]\n",
"kd_res = [kt.query(r) for r in test_records]\n",
"kd_recs = [_[0] for _ in kd_res]\n",
"kd_recs = [_[0][0] for _ in kd_res]\n",
"kd_dists = [_[1] for _ in kd_res]\n",
"tr_recs = [records[np.argmin([r.distance(p) for p in records])] for r in test_records]\n",
"tr_dists = [min([r.distance(p) for p in records]) for r in test_records]\n",
"assert kd_dists == tr_dists, \"NOT MATCHING?\""
"assert all([abs(k - t) < tol for k, t in zip(kd_dists, tr_dists)]), \"NOT MATCHING?\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 17,
"id": "a6aa6926-7fd5-4fff-bd20-7bc0305b948d",
"metadata": {},
"outputs": [
......@@ -213,7 +214,7 @@
"└──────────┴──────────┴─────────┴────────┴────────┴─────────┴────────┴────────┘"
]
},
"execution_count": 10,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
......@@ -228,7 +229,7 @@
"tr_lons = [r.lon for r in tr_recs]\n",
"tr_lats = [r.lat for r in tr_recs]\n",
"\n",
"pl.DataFrame({\n",
"df = pl.DataFrame({\n",
" \"test_lon\": test_lons, \n",
" \"test_lat\": test_lats,\n",
" \"kd_dist\": kd_dists,\n",
......@@ -237,7 +238,8 @@
" \"tr_dist\": tr_dists,\n",
" \"tr_lon\": tr_lons,\n",
" \"tr_lat\": tr_lats, \n",
"}).filter(pl.col(\"kd_dist\").ne(pl.col(\"tr_dist\")))"
"}).filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n",
"df"
]
}
],
......
......@@ -24,6 +24,18 @@ class TestKDTree(unittest.TestCase):
assert kt.delete(delete_rec)
assert delete_rec not in kt.child_right.child_right.points
def test_delete_dup(self):
    """Deleting one of several co-located records succeeds exactly once.

    Four records share the same position and differ only in ``uid``. The
    first ``delete`` of a given record must report success; a second
    ``delete`` of that same record must report failure.
    """
    test_records = [
        Record(45, -23, uid="1"),
        Record(45, -23, uid="2"),
        Record(45, -23, uid="3"),
        Record(45, -23, uid="4"),
    ]
    kt = KDTree(test_records, max_depth=3)
    # unittest assertion methods are used instead of bare ``assert`` so the
    # checks still run under ``python -O`` and report clearer failures.
    self.assertTrue(kt.delete(test_records[1]))
    # TEST: Cannot delete same record twice!
    self.assertFalse(kt.delete(test_records[1]))
def test_query(self):
kt = KDTree(self.records)
test_record = Record(-6, 35)
......@@ -33,7 +45,57 @@ class TestKDTree(unittest.TestCase):
true_record = self.records[true_ind]
self.assertAlmostEqual(true_dist, best_dist)
assert best_record == true_record
assert len(best_record) == 1
assert best_record[0] == true_record
def test_duplicated_pos(self):
    """Records at an identical position are spread evenly over the leaves.

    Builds a tree from four records with the same coordinates and
    ``max_depth=3`` and checks that each of the four grandchild nodes holds
    exactly one record.
    """
    # TEST: That equal records get partitioned equally
    test_records = [
        Record(45, -23, uid="1"),
        Record(45, -23, uid="2"),
        Record(45, -23, uid="3"),
        Record(45, -23, uid="4"),
    ]
    kt = KDTree(test_records, max_depth=3)
    leaves = {
        "left.left": kt.child_left.child_left,
        "left.right": kt.child_left.child_right,
        "right.left": kt.child_right.child_left,
        "right.right": kt.child_right.child_right,
    }
    # assertEqual (rather than bare ``assert``) survives ``python -O`` and
    # prints the actual length; subTest labels which leaf failed.
    for label, leaf in leaves.items():
        with self.subTest(leaf=label):
            self.assertEqual(len(leaf.points), 1)
def test_insert_dup(self):
    """``insert`` rejects records already present and accepts new ones.

    Covers three situations on a tree of four co-located records:
    re-inserting each existing record is refused, a record with a new
    ``uid`` at the same position is accepted once and then refused, and a
    record can be inserted again after it has been deleted.
    """
    test_records = [
        Record(45, -23, uid="1"),
        Record(45, -23, uid="2"),
        Record(45, -23, uid="3"),
        Record(45, -23, uid="4"),
    ]
    kt = KDTree(test_records, max_depth=3)
    # unittest assertion methods are used instead of bare ``assert`` so the
    # checks still run under ``python -O`` and report clearer failures.
    for existing in test_records:
        with self.subTest(uid=existing.uid):
            self.assertFalse(kt.insert(existing))
    self.assertTrue(kt.insert(Record(45, -23, uid="5")))
    self.assertFalse(kt.insert(Record(45, -23, uid="5")))
    # TEST: Can insert after deleting
    self.assertTrue(kt.delete(Record(45, -23, uid="5")))
    self.assertTrue(kt.insert(Record(45, -23, uid="5")))
def test_get_multiple_neighbours(self):
    """``query`` returns every record tied for the nearest distance.

    Two records with different ``uid`` values are inserted at the same
    position; querying from a nearby point is expected to return both.
    """
    kt = KDTree(self.records)
    kt.insert(Record(45, -21, uid="1"))
    kt.insert(Record(45, -21, uid="2"))
    # Only the list of nearest records matters here; the distance is unused.
    r, _ = kt.query(Record(44, -21, uid="3"))
    # assertEqual (rather than bare ``assert``) survives ``python -O`` and
    # reports the actual count on failure.
    self.assertEqual(len(r), 2)
def test_wrap(self):
    """``query`` accounts for the longitude wrap at -180 / 180.

    After inserting record "G" at (-160, -64), a query from (-178, -79)
    must return exactly one record, the fixture record with uid "C".
    NOTE(review): presumably "C" lies on the other side of the antimeridian
    in ``self.records`` -- confirm against the fixture in ``setUp``.
    """
    # TEST: Accounts for wrap at -180, 180
    kt = KDTree(self.records)
    kt.insert(Record(-160, -64, uid="G"))
    query_rec = Record(-178, -79, uid="E")
    r, _ = kt.query(query_rec)
    # unittest assertion methods are used instead of bare ``assert`` so the
    # checks still run under ``python -O`` and show the mismatching values.
    self.assertEqual(len(r), 1)
    self.assertEqual(r[0].uid, "C")
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment