diff --git a/notebooks/kdtree.ipynb b/notebooks/kdtree.ipynb
index 9fc6ad23f0f984b7b1727ab132122e499e77346a..e23c16a230c713530899c9213281e1376296358b 100644
--- a/notebooks/kdtree.ipynb
+++ b/notebooks/kdtree.ipynb
@@ -20,6 +20,7 @@
    "outputs": [],
    "source": [
     "import os\n",
+    "\n",
     "os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n",
     "\n",
     "from datetime import datetime\n",
@@ -174,9 +175,7 @@
     "            ).alias(\"_a\")\n",
     "        )\n",
     "        .with_columns(\n",
-    "            (2 * R * (pl.col(\"_a\").sqrt().arcsin()))\n",
-    "            .round(2)\n",
-    "            .alias(\"_dist\")\n",
+    "            (2 * R * (pl.col(\"_a\").sqrt().arcsin())).round(2).alias(\"_dist\")\n",
     "        )\n",
     "        .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n",
     "    )\n",
@@ -223,8 +222,7 @@
     "    check_cols(df, required_cols, \"df\")\n",
     "\n",
     "    return (\n",
-    "        df\n",
-    "        .pipe(\n",
+    "        df.pipe(\n",
     "            haversine_df,\n",
     "            lon=lon,\n",
     "            lat=lat,\n",
@@ -233,7 +231,7 @@
     "        )\n",
     "        .filter(pl.col(\"_dist\").eq(pl.col(\"_dist\").min()))\n",
     "        .drop([\"_dist\"])\n",
-    "    )\n"
+    "    )"
    ]
   },
   {
@@ -439,15 +437,15 @@
     "test_records = [\n",
     "    Record(\n",
     "        random.choice(range(-179, 180)) + randnum(),\n",
-    "        random.choice(range(-89, 90)) + randnum()\n",
-    "    ) for _ in range(n_samples)\n",
+    "        random.choice(range(-89, 90)) + randnum(),\n",
+    "    )\n",
+    "    for _ in range(n_samples)\n",
     "]\n",
     "kd_res = [kt.query(r) for r in test_records]\n",
     "kd_recs = [_[0][0] for _ in kd_res]\n",
     "kd_dists = [_[1] for _ in kd_res]\n",
     "tr_recs = [\n",
-    "    records[np.argmin([r.distance(p) for p in records])]\n",
-    "    for r in test_records\n",
+    "    records[np.argmin([r.distance(p) for p in records])] for r in test_records\n",
     "]\n",
     "tr_dists = [min([r.distance(p) for p in records]) for r in test_records]\n",
     "\n",
@@ -498,16 +496,18 @@
     "tr_lons = [r.lon for r in tr_recs]\n",
     "tr_lats = [r.lat for r in tr_recs]\n",
     "\n",
-    "df = pl.DataFrame({\n",
-    "    \"test_lon\": test_lons,\n",
-    "    \"test_lat\": test_lats,\n",
-    "    \"kd_dist\": kd_dists,\n",
-    "    \"kd_lon\": kd_lons,\n",
-    "    \"kd_lat\": kd_lats,\n",
-    "    \"tr_dist\": tr_dists,\n",
-    "    \"tr_lon\": tr_lons,\n",
-    "    \"tr_lat\": tr_lats,\n",
-    "}).filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n",
+    "df = pl.DataFrame(\n",
+    "    {\n",
+    "        \"test_lon\": test_lons,\n",
+    "        \"test_lat\": test_lats,\n",
+    "        \"kd_dist\": kd_dists,\n",
+    "        \"kd_lon\": kd_lons,\n",
+    "        \"kd_lat\": kd_lats,\n",
+    "        \"tr_dist\": tr_dists,\n",
+    "        \"tr_lon\": tr_lons,\n",
+    "        \"tr_lat\": tr_lats,\n",
+    "    }\n",
+    ").filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n",
     "df"
    ]
   }
diff --git a/notebooks/octtree.ipynb b/notebooks/octtree.ipynb
index fac3bbbed731880958c9b560dc8869b77a292db3..a0422bc381c3bf4c73e44448c9fa6726f3c56d48 100644
--- a/notebooks/octtree.ipynb
+++ b/notebooks/octtree.ipynb
@@ -20,6 +20,7 @@
    "outputs": [],
    "source": [
     "import os\n",
+    "\n",
     "os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n",
     "\n",
     "from datetime import datetime, timedelta\n",
@@ -33,7 +34,7 @@
     "from GeoSpatialTools.octtree import (\n",
     "    OctTree,\n",
     "    SpaceTimeRecord as Record,\n",
-    "    SpaceTimeRectangle as Rectangle\n",
+    "    SpaceTimeRectangle as Rectangle,\n",
     ")"
    ]
   },
@@ -152,9 +153,7 @@
     "            ).alias(\"_a\")\n",
     "        )\n",
     "        .with_columns(\n",
-    "            (2 * R * (pl.col(\"_a\").sqrt().arcsin()))\n",
-    "            .round(2)\n",
-    "            .alias(\"_dist\")\n",
+    "            (2 * R * (pl.col(\"_a\").sqrt().arcsin())).round(2).alias(\"_dist\")\n",
     "        )\n",
     "        .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n",
     "    )\n",
@@ -244,8 +243,7 @@
     "    )\n",
     "\n",
     "    return (\n",
-    "        pool\n",
-    "        .pipe(\n",
+    "        pool.pipe(\n",
     "            haversine_df,\n",
     "            lon=lon,\n",
     "            lat=lat,\n",
@@ -254,7 +252,7 @@
     "        )\n",
     "        .filter(pl.col(\"_dist\").le(max_dist))\n",
     "        .drop([\"_dist\"])\n",
-    "    )\n"
+    "    )"
    ]
   },
   {
@@ -362,12 +360,14 @@
    "source": [
     "_df = df.clone()\n",
     "for i in range(100):\n",
-    "    df2 = pl.DataFrame([\n",
-    "        _df[\"lon\"].shuffle(),\n",
-    "        _df[\"lat\"].shuffle(),\n",
-    "        _df[\"datetime\"].shuffle(),\n",
-    "        _df[\"uid\"].shuffle(),\n",
-    "    ]).with_columns(\n",
+    "    df2 = pl.DataFrame(\n",
+    "        [\n",
+    "            _df[\"lon\"].shuffle(),\n",
+    "            _df[\"lat\"].shuffle(),\n",
+    "            _df[\"datetime\"].shuffle(),\n",
+    "            _df[\"uid\"].shuffle(),\n",
+    "        ]\n",
+    "    ).with_columns(\n",
     "        pl.concat_str([pl.col(\"uid\"), pl.lit(f\"{i:03d}\")]).alias(\"uid\")\n",
     "    )\n",
     "    df = df.vstack(df2)\n",
@@ -658,7 +658,16 @@
     "dist = 250\n",
     "for _ in range(250):\n",
     "    rec = Record(*random.choice(df.rows()))\n",
-    "    orig = nearby_ships(lon=rec.lon, lat=rec.lat, dt=rec.datetime, max_dist=dist, dt_gap=dt, date_col=\"datetime\", pool=df, filter_datetime=True) # noqa\n",
+    "    orig = nearby_ships(\n",
+    "        lon=rec.lon,\n",
+    "        lat=rec.lat,\n",
+    "        dt=rec.datetime,\n",
+    "        max_dist=dist,\n",
+    "        dt_gap=dt,\n",
+    "        date_col=\"datetime\",\n",
+    "        pool=df,\n",
+    "        filter_datetime=True,\n",
+    "    ) # noqa\n",
     "    tree = otree.nearby_points(rec, dist=dist, t_dist=dt)\n",
     "    if orig.height > 0:\n",
     "        if not tree:\n",
@@ -666,7 +675,16 @@
     "            print(\"NO TREE!\")\n",
     "            print(f\"{orig = }\")\n",
     "        else:\n",
-    "            tree = pl.from_records([(r.lon, r.lat, r.datetime, r.uid) for r in tree], orient=\"row\").rename({\"column_0\": \"lon\", \"column_1\": \"lat\", \"column_2\": \"datetime\", \"column_3\": \"uid\"}) # noqa\n",
+    "            tree = pl.from_records(\n",
+    "                [(r.lon, r.lat, r.datetime, r.uid) for r in tree], orient=\"row\"\n",
+    "            ).rename(\n",
+    "                {\n",
+    "                    \"column_0\": \"lon\",\n",
+    "                    \"column_1\": \"lat\",\n",
+    "                    \"column_2\": \"datetime\",\n",
+    "                    \"column_3\": \"uid\",\n",
+    "                }\n",
+    "            ) # noqa\n",
     "            if tree.height != orig.height:\n",
     "                print(\"Tree and Orig Heights Do Not Match\")\n",
     "                print(f\"{orig = }\")\n",
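
Note for reviewers: the cells reformatted above all funnel through the notebooks' haversine_df helper before filtering on _dist. The sketch below is a reading aid only, not code from this patch: the (2 * R * ...) distance expression, the scratch-column names, and the pipe(...) / .filter(pl.col("_dist").le(max_dist)) pattern are taken from the hunks above, while the function signature (lon_col/lat_col), the construction of the _lat0/_lat1/_dlon/_dlat columns, and the Earth radius value R are assumptions.

import polars as pl

R = 6371.0  # assumed mean Earth radius in km; the notebooks define their own R


def haversine_df(df, lon, lat, lon_col="lon", lat_col="lat"):
    # Append a `_dist` column (km): haversine distance from (lon, lat) to each
    # row's (lon_col, lat_col). Scratch columns are dropped before returning.
    return (
        df.with_columns(
            pl.lit(lat).radians().alias("_lat0"),
            pl.col(lat_col).radians().alias("_lat1"),
            (pl.col(lon_col) - lon).radians().alias("_dlon"),
            (pl.col(lat_col) - lat).radians().alias("_dlat"),
        )
        .with_columns(
            (
                (pl.col("_dlat") / 2).sin().pow(2)
                + pl.col("_lat0").cos()
                * pl.col("_lat1").cos()
                * (pl.col("_dlon") / 2).sin().pow(2)
            ).alias("_a")
        )
        .with_columns(
            (2 * R * (pl.col("_a").sqrt().arcsin())).round(2).alias("_dist")
        )
        .drop(["_lat0", "_lat1", "_dlon", "_dlat", "_a"])
    )


# Usage mirroring the reformatted `nearby_ships` body: keep rows within max_dist km.
# nearby = (
#     pool.pipe(haversine_df, lon=-30.0, lat=45.0)
#     .filter(pl.col("_dist").le(250))
#     .drop(["_dist"])
# )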