Commit 98f65e45 authored by Joseph Siddons
Browse files

chore: update and re-run KDTree notebook.

parent d58f82b9
...@@ -48,7 +48,7 @@ ...@@ -48,7 +48,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 11,
"id": "c60b30de-f864-477a-a09a-5f1caa4d9b9a", "id": "c60b30de-f864-477a-a09a-5f1caa4d9b9a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -56,18 +56,18 @@ ...@@ -56,18 +56,18 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"(14222, 2)\n", "(16000, 2)\n",
"shape: (5, 2)\n", "shape: (5, 2)\n",
"┌──────┬─────┐\n", "┌──────┬─────┐\n",
"│ lon ┆ lat │\n", "│ lon ┆ lat │\n",
"│ --- ┆ --- │\n", "│ --- ┆ --- │\n",
"│ i64 ┆ i64 │\n", "│ i64 ┆ i64 │\n",
"╞══════╪═════╡\n", "╞══════╪═════╡\n",
"│ -30 ┆ -41 │\n", "│ 16 ┆ -75 │\n",
"│ -149 ┆ 56 │\n", "│ 144 ┆ -77 │\n",
"│ 7 ┆ -68 │\n", "│ -173 ┆ -83 │\n",
"│ -48 ┆ 83 │\n", "│ 142 ┆ -81 │\n",
"│ -126 ┆ -35 │\n", "│ -50 ┆ -38 │\n",
"└──────┴─────┘\n" "└──────┴─────┘\n"
] ]
} }
...@@ -83,14 +83,14 @@ ...@@ -83,14 +83,14 @@
"# dates_use = dates.sample(N, with_replacement=True).alias(\"datetime\")\n", "# dates_use = dates.sample(N, with_replacement=True).alias(\"datetime\")\n",
"# uids = pl.Series(\"uid\", [generate_uid(8) for _ in range(N)])\n", "# uids = pl.Series(\"uid\", [generate_uid(8) for _ in range(N)])\n",
"\n", "\n",
"df = pl.DataFrame([lons_use, lats_use]).unique()\n", "df = pl.DataFrame([lons_use, lats_use])\n",
"print(df.shape)\n", "print(df.shape)\n",
"print(df.head())" "print(df.head())"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 12,
"id": "875f2a67-49fe-476f-add1-b1d76c6cd8f9", "id": "875f2a67-49fe-476f-add1-b1d76c6cd8f9",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -100,7 +100,7 @@ ...@@ -100,7 +100,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 13,
"id": "1e883e5a-5086-4c29-aff2-d308874eae16", "id": "1e883e5a-5086-4c29-aff2-d308874eae16",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -108,8 +108,8 @@ ...@@ -108,8 +108,8 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"CPU times: user 43.5 ms, sys: 3.43 ms, total: 46.9 ms\n", "CPU times: user 82 ms, sys: 4.14 ms, total: 86.1 ms\n",
"Wall time: 46.8 ms\n" "Wall time: 84.3 ms\n"
] ]
} }
], ],
...@@ -120,7 +120,7 @@ ...@@ -120,7 +120,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 14,
"id": "69022ad1-5ec8-4a09-836c-273ef452451f", "id": "69022ad1-5ec8-4a09-836c-273ef452451f",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -128,7 +128,7 @@ ...@@ -128,7 +128,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"173 μs ± 1.36 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" "188 μs ± 3.45 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
] ]
} }
], ],
...@@ -140,7 +140,7 @@ ...@@ -140,7 +140,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 15,
"id": "28031966-c7d0-4201-a467-37590118e851", "id": "28031966-c7d0-4201-a467-37590118e851",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -148,7 +148,7 @@ ...@@ -148,7 +148,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"7.71 ms ± 38.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" "8.72 ms ± 74.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
] ]
} }
], ],
...@@ -160,7 +160,7 @@ ...@@ -160,7 +160,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 16,
"id": "0d10b2ba-57b2-475c-9d01-135363423990", "id": "0d10b2ba-57b2-475c-9d01-135363423990",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -168,26 +168,27 @@ ...@@ -168,26 +168,27 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"CPU times: user 15.4 s, sys: 37.8 ms, total: 15.5 s\n", "CPU times: user 17.3 s, sys: 31.6 ms, total: 17.3 s\n",
"Wall time: 15.5 s\n" "Wall time: 17.3 s\n"
] ]
} }
], ],
"source": [ "source": [
"%%time\n", "%%time\n",
"n_samples = 1000\n", "n_samples = 1000\n",
"tol = 1e-8\n",
"test_records = [Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum()) for _ in range(n_samples)]\n", "test_records = [Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum()) for _ in range(n_samples)]\n",
"kd_res = [kt.query(r) for r in test_records]\n", "kd_res = [kt.query(r) for r in test_records]\n",
"kd_recs = [_[0] for _ in kd_res]\n", "kd_recs = [_[0][0] for _ in kd_res]\n",
"kd_dists = [_[1] for _ in kd_res]\n", "kd_dists = [_[1] for _ in kd_res]\n",
"tr_recs = [records[np.argmin([r.distance(p) for p in records])] for r in test_records]\n", "tr_recs = [records[np.argmin([r.distance(p) for p in records])] for r in test_records]\n",
"tr_dists = [min([r.distance(p) for p in records]) for r in test_records]\n", "tr_dists = [min([r.distance(p) for p in records]) for r in test_records]\n",
"assert kd_dists == tr_dists, \"NOT MATCHING?\"" "assert all([abs(k - t) < tol for k, t in zip(kd_dists, tr_dists)]), \"NOT MATCHING?\""
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 17,
"id": "a6aa6926-7fd5-4fff-bd20-7bc0305b948d", "id": "a6aa6926-7fd5-4fff-bd20-7bc0305b948d",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -213,7 +214,7 @@ ...@@ -213,7 +214,7 @@
"└──────────┴──────────┴─────────┴────────┴────────┴─────────┴────────┴────────┘" "└──────────┴──────────┴─────────┴────────┴────────┴─────────┴────────┴────────┘"
] ]
}, },
"execution_count": 10, "execution_count": 17,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -228,7 +229,7 @@ ...@@ -228,7 +229,7 @@
"tr_lons = [r.lon for r in tr_recs]\n", "tr_lons = [r.lon for r in tr_recs]\n",
"tr_lats = [r.lat for r in tr_recs]\n", "tr_lats = [r.lat for r in tr_recs]\n",
"\n", "\n",
"pl.DataFrame({\n", "df = pl.DataFrame({\n",
" \"test_lon\": test_lons, \n", " \"test_lon\": test_lons, \n",
" \"test_lat\": test_lats,\n", " \"test_lat\": test_lats,\n",
" \"kd_dist\": kd_dists,\n", " \"kd_dist\": kd_dists,\n",
...@@ -237,7 +238,8 @@ ...@@ -237,7 +238,8 @@
" \"tr_dist\": tr_dists,\n", " \"tr_dist\": tr_dists,\n",
" \"tr_lon\": tr_lons,\n", " \"tr_lon\": tr_lons,\n",
" \"tr_lat\": tr_lats, \n", " \"tr_lat\": tr_lats, \n",
"}).filter(pl.col(\"kd_dist\").ne(pl.col(\"tr_dist\")))" "}).filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n",
"df"
] ]
} }
], ],
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment