Commit 37bb9314 authored by Joseph Siddons
Browse files

chore: format

parent 3edde364
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "import os\n",
"\n",
"os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n", "os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n",
"\n", "\n",
"from datetime import datetime\n", "from datetime import datetime\n",
...@@ -174,9 +175,7 @@ ...@@ -174,9 +175,7 @@
" ).alias(\"_a\")\n", " ).alias(\"_a\")\n",
" )\n", " )\n",
" .with_columns(\n", " .with_columns(\n",
" (2 * R * (pl.col(\"_a\").sqrt().arcsin()))\n", " (2 * R * (pl.col(\"_a\").sqrt().arcsin())).round(2).alias(\"_dist\")\n",
" .round(2)\n",
" .alias(\"_dist\")\n",
" )\n", " )\n",
" .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n", " .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n",
" )\n", " )\n",
...@@ -223,8 +222,7 @@ ...@@ -223,8 +222,7 @@
" check_cols(df, required_cols, \"df\")\n", " check_cols(df, required_cols, \"df\")\n",
"\n", "\n",
" return (\n", " return (\n",
" df\n", " df.pipe(\n",
" .pipe(\n",
" haversine_df,\n", " haversine_df,\n",
" lon=lon,\n", " lon=lon,\n",
" lat=lat,\n", " lat=lat,\n",
...@@ -233,7 +231,7 @@ ...@@ -233,7 +231,7 @@
" )\n", " )\n",
" .filter(pl.col(\"_dist\").eq(pl.col(\"_dist\").min()))\n", " .filter(pl.col(\"_dist\").eq(pl.col(\"_dist\").min()))\n",
" .drop([\"_dist\"])\n", " .drop([\"_dist\"])\n",
" )\n" " )"
] ]
}, },
{ {
...@@ -439,15 +437,15 @@ ...@@ -439,15 +437,15 @@
"test_records = [\n", "test_records = [\n",
" Record(\n", " Record(\n",
" random.choice(range(-179, 180)) + randnum(),\n", " random.choice(range(-179, 180)) + randnum(),\n",
" random.choice(range(-89, 90)) + randnum()\n", " random.choice(range(-89, 90)) + randnum(),\n",
" ) for _ in range(n_samples)\n", " )\n",
" for _ in range(n_samples)\n",
"]\n", "]\n",
"kd_res = [kt.query(r) for r in test_records]\n", "kd_res = [kt.query(r) for r in test_records]\n",
"kd_recs = [_[0][0] for _ in kd_res]\n", "kd_recs = [_[0][0] for _ in kd_res]\n",
"kd_dists = [_[1] for _ in kd_res]\n", "kd_dists = [_[1] for _ in kd_res]\n",
"tr_recs = [\n", "tr_recs = [\n",
" records[np.argmin([r.distance(p) for p in records])]\n", " records[np.argmin([r.distance(p) for p in records])] for r in test_records\n",
" for r in test_records\n",
"]\n", "]\n",
"tr_dists = [min([r.distance(p) for p in records]) for r in test_records]\n", "tr_dists = [min([r.distance(p) for p in records]) for r in test_records]\n",
"\n", "\n",
...@@ -498,16 +496,18 @@ ...@@ -498,16 +496,18 @@
"tr_lons = [r.lon for r in tr_recs]\n", "tr_lons = [r.lon for r in tr_recs]\n",
"tr_lats = [r.lat for r in tr_recs]\n", "tr_lats = [r.lat for r in tr_recs]\n",
"\n", "\n",
"df = pl.DataFrame({\n", "df = pl.DataFrame(\n",
" \"test_lon\": test_lons,\n", " {\n",
" \"test_lat\": test_lats,\n", " \"test_lon\": test_lons,\n",
" \"kd_dist\": kd_dists,\n", " \"test_lat\": test_lats,\n",
" \"kd_lon\": kd_lons,\n", " \"kd_dist\": kd_dists,\n",
" \"kd_lat\": kd_lats,\n", " \"kd_lon\": kd_lons,\n",
" \"tr_dist\": tr_dists,\n", " \"kd_lat\": kd_lats,\n",
" \"tr_lon\": tr_lons,\n", " \"tr_dist\": tr_dists,\n",
" \"tr_lat\": tr_lats,\n", " \"tr_lon\": tr_lons,\n",
"}).filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n", " \"tr_lat\": tr_lats,\n",
" }\n",
").filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n",
"df" "df"
] ]
} }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "import os\n",
"\n",
"os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n", "os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n",
"\n", "\n",
"from datetime import datetime, timedelta\n", "from datetime import datetime, timedelta\n",
...@@ -33,7 +34,7 @@ ...@@ -33,7 +34,7 @@
"from GeoSpatialTools.octtree import (\n", "from GeoSpatialTools.octtree import (\n",
" OctTree,\n", " OctTree,\n",
" SpaceTimeRecord as Record,\n", " SpaceTimeRecord as Record,\n",
" SpaceTimeRectangle as Rectangle\n", " SpaceTimeRectangle as Rectangle,\n",
")" ")"
] ]
}, },
...@@ -152,9 +153,7 @@ ...@@ -152,9 +153,7 @@
" ).alias(\"_a\")\n", " ).alias(\"_a\")\n",
" )\n", " )\n",
" .with_columns(\n", " .with_columns(\n",
" (2 * R * (pl.col(\"_a\").sqrt().arcsin()))\n", " (2 * R * (pl.col(\"_a\").sqrt().arcsin())).round(2).alias(\"_dist\")\n",
" .round(2)\n",
" .alias(\"_dist\")\n",
" )\n", " )\n",
" .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n", " .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n",
" )\n", " )\n",
...@@ -244,8 +243,7 @@ ...@@ -244,8 +243,7 @@
" )\n", " )\n",
"\n", "\n",
" return (\n", " return (\n",
" pool\n", " pool.pipe(\n",
" .pipe(\n",
" haversine_df,\n", " haversine_df,\n",
" lon=lon,\n", " lon=lon,\n",
" lat=lat,\n", " lat=lat,\n",
...@@ -254,7 +252,7 @@ ...@@ -254,7 +252,7 @@
" )\n", " )\n",
" .filter(pl.col(\"_dist\").le(max_dist))\n", " .filter(pl.col(\"_dist\").le(max_dist))\n",
" .drop([\"_dist\"])\n", " .drop([\"_dist\"])\n",
" )\n" " )"
] ]
}, },
{ {
...@@ -362,12 +360,14 @@ ...@@ -362,12 +360,14 @@
"source": [ "source": [
"_df = df.clone()\n", "_df = df.clone()\n",
"for i in range(100):\n", "for i in range(100):\n",
" df2 = pl.DataFrame([\n", " df2 = pl.DataFrame(\n",
" _df[\"lon\"].shuffle(),\n", " [\n",
" _df[\"lat\"].shuffle(),\n", " _df[\"lon\"].shuffle(),\n",
" _df[\"datetime\"].shuffle(),\n", " _df[\"lat\"].shuffle(),\n",
" _df[\"uid\"].shuffle(),\n", " _df[\"datetime\"].shuffle(),\n",
" ]).with_columns(\n", " _df[\"uid\"].shuffle(),\n",
" ]\n",
" ).with_columns(\n",
" pl.concat_str([pl.col(\"uid\"), pl.lit(f\"{i:03d}\")]).alias(\"uid\")\n", " pl.concat_str([pl.col(\"uid\"), pl.lit(f\"{i:03d}\")]).alias(\"uid\")\n",
" )\n", " )\n",
" df = df.vstack(df2)\n", " df = df.vstack(df2)\n",
...@@ -658,7 +658,16 @@ ...@@ -658,7 +658,16 @@
"dist = 250\n", "dist = 250\n",
"for _ in range(250):\n", "for _ in range(250):\n",
" rec = Record(*random.choice(df.rows()))\n", " rec = Record(*random.choice(df.rows()))\n",
" orig = nearby_ships(lon=rec.lon, lat=rec.lat, dt=rec.datetime, max_dist=dist, dt_gap=dt, date_col=\"datetime\", pool=df, filter_datetime=True) # noqa\n", " orig = nearby_ships(\n",
" lon=rec.lon,\n",
" lat=rec.lat,\n",
" dt=rec.datetime,\n",
" max_dist=dist,\n",
" dt_gap=dt,\n",
" date_col=\"datetime\",\n",
" pool=df,\n",
" filter_datetime=True,\n",
" ) # noqa\n",
" tree = otree.nearby_points(rec, dist=dist, t_dist=dt)\n", " tree = otree.nearby_points(rec, dist=dist, t_dist=dt)\n",
" if orig.height > 0:\n", " if orig.height > 0:\n",
" if not tree:\n", " if not tree:\n",
...@@ -666,7 +675,16 @@ ...@@ -666,7 +675,16 @@
" print(\"NO TREE!\")\n", " print(\"NO TREE!\")\n",
" print(f\"{orig = }\")\n", " print(f\"{orig = }\")\n",
" else:\n", " else:\n",
" tree = pl.from_records([(r.lon, r.lat, r.datetime, r.uid) for r in tree], orient=\"row\").rename({\"column_0\": \"lon\", \"column_1\": \"lat\", \"column_2\": \"datetime\", \"column_3\": \"uid\"}) # noqa\n", " tree = pl.from_records(\n",
" [(r.lon, r.lat, r.datetime, r.uid) for r in tree], orient=\"row\"\n",
" ).rename(\n",
" {\n",
" \"column_0\": \"lon\",\n",
" \"column_1\": \"lat\",\n",
" \"column_2\": \"datetime\",\n",
" \"column_3\": \"uid\",\n",
" }\n",
" ) # noqa\n",
" if tree.height != orig.height:\n", " if tree.height != orig.height:\n",
" print(\"Tree and Orig Heights Do Not Match\")\n", " print(\"Tree and Orig Heights Do Not Match\")\n",
" print(f\"{orig = }\")\n", " print(f\"{orig = }\")\n",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment