diff --git a/notebooks/kdtree.ipynb b/notebooks/kdtree.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..9a69a4ef9e5990b8a55f33dc23190f0d673dc285 --- /dev/null +++ b/notebooks/kdtree.ipynb @@ -0,0 +1,265 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "bdfa1141-8ae0-499b-8355-927759af69d1", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import gzip\n", + "os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n", + "\n", + "from datetime import datetime, timedelta\n", + "from random import choice\n", + "from string import ascii_letters, digits\n", + "import random\n", + "import inspect\n", + "\n", + "import polars as pl\n", + "import numpy as np\n", + "\n", + "from GeoSpatialTools import Record, haversine, KDTree" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8711862a-6295-43eb-ac51-333fda638ef4", + "metadata": {}, + "outputs": [], + "source": [ + "def randnum() -> float:\n", + " return 2 * (np.random.rand() - 0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "72164093-fac1-4dfc-803b-6522cc9a4d62", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_uid(n: int) -> str:\n", + " chars = ascii_letters + digits\n", + " return \"\".join(random.choice(chars) for _ in range(n))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c60b30de-f864-477a-a09a-5f1caa4d9b9a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(14222, 2)\n", + "shape: (5, 2)\n", + "┌──────┬─────┐\n", + "│ lon ┆ lat │\n", + "│ --- ┆ --- │\n", + "│ i64 ┆ i64 │\n", + "╞══════╪═════╡\n", + "│ -30 ┆ -41 │\n", + "│ -149 ┆ 56 │\n", + "│ 7 ┆ -68 │\n", + "│ -48 ┆ 83 │\n", + "│ -126 ┆ -35 │\n", + "└──────┴─────┘\n" + ] + } + ], + "source": [ + "N = 16_000\n", + "lons = pl.int_range(-180, 180, eager=True)\n", + "lats = pl.int_range(-90, 90, eager=True)\n", + "dates = pl.datetime_range(datetime(1900, 1, 1, 
0), datetime(1900, 1, 31, 23), interval=\"1h\", eager=True)\n", + "\n", + "lons_use = lons.sample(N, with_replacement=True).alias(\"lon\")\n", + "lats_use = lats.sample(N, with_replacement=True).alias(\"lat\")\n", + "# dates_use = dates.sample(N, with_replacement=True).alias(\"datetime\")\n", + "# uids = pl.Series(\"uid\", [generate_uid(8) for _ in range(N)])\n", + "\n", + "df = pl.DataFrame([lons_use, lats_use]).unique()\n", + "print(df.shape)\n", + "print(df.head())" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "875f2a67-49fe-476f-add1-b1d76c6cd8f9", + "metadata": {}, + "outputs": [], + "source": [ + "records = [Record(**r) for r in df.rows(named=True)]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1e883e5a-5086-4c29-aff2-d308874eae16", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 43.5 ms, sys: 3.43 ms, total: 46.9 ms\n", + "Wall time: 46.8 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "kt = KDTree(records)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "69022ad1-5ec8-4a09-836c-273ef452451f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "173 μs ± 1.36 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "test_record = Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum())\n", + "kt.query(test_record)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "28031966-c7d0-4201-a467-37590118e851", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7.71 ms ± 38.7 μs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "test_record = Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum())\n", + "np.argmin([test_record.distance(p) for p in records])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "0d10b2ba-57b2-475c-9d01-135363423990", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 15.4 s, sys: 37.8 ms, total: 15.5 s\n", + "Wall time: 15.5 s\n" + ] + } + ], + "source": [ + "%%time\n", + "n_samples = 1000\n", + "test_records = [Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum()) for _ in range(n_samples)]\n", + "kd_res = [kt.query(r) for r in test_records]\n", + "kd_recs = [_[0] for _ in kd_res]\n", + "kd_dists = [_[1] for _ in kd_res]\n", + "tr_recs = [records[np.argmin([r.distance(p) for p in records])] for r in test_records]\n", + "tr_dists = [min([r.distance(p) for p in records]) for r in test_records]\n", + "assert kd_dists == tr_dists, \"NOT MATCHING?\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a6aa6926-7fd5-4fff-bd20-7bc0305b948d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div><style>\n", + ".dataframe > thead > tr,\n", + ".dataframe > tbody > tr {\n", + " text-align: right;\n", + " white-space: pre-wrap;\n", + "}\n", + "</style>\n", + "<small>shape: (0, 8)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_lon</th><th>test_lat</th><th>kd_dist</th><th>kd_lon</th><th>kd_lat</th><th>tr_dist</th><th>tr_lon</th><th>tr_lat</th></tr><tr><td>f64</td><td>f64</td><td>f64</td><td>i64</td><td>i64</td><td>f64</td><td>i64</td><td>i64</td></tr></thead><tbody></tbody></table></div>" + ], + "text/plain": [ + "shape: (0, 8)\n", + "┌──────────┬──────────┬─────────┬────────┬────────┬─────────┬────────┬────────┐\n", + "│ test_lon ┆ test_lat ┆ kd_dist ┆ kd_lon ┆ kd_lat ┆ tr_dist ┆ 
tr_lon ┆ tr_lat │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ f64 ┆ f64 ┆ f64 ┆ i64 ┆ i64 ┆ f64 ┆ i64 ┆ i64 │\n", + "╞══════════╪══════════╪═════════╪════════╪════════╪═════════╪════════╪════════╡\n", + "└──────────┴──────────┴─────────┴────────┴────────┴─────────┴────────┴────────┘" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_lons = [r.lon for r in test_records]\n", + "test_lats = [r.lat for r in test_records]\n", + "\n", + "kd_lons = [r.lon for r in kd_recs]\n", + "kd_lats = [r.lat for r in kd_recs]\n", + "\n", + "tr_lons = [r.lon for r in tr_recs]\n", + "tr_lats = [r.lat for r in tr_recs]\n", + "\n", + "pl.DataFrame({\n", + " \"test_lon\": test_lons, \n", + " \"test_lat\": test_lats,\n", + " \"kd_dist\": kd_dists,\n", + " \"kd_lon\": kd_lons,\n", + " \"kd_lat\": kd_lats,\n", + " \"tr_dist\": tr_dists,\n", + " \"tr_lon\": tr_lons,\n", + " \"tr_lat\": tr_lats, \n", + "}).filter(pl.col(\"kd_dist\").ne(pl.col(\"tr_dist\")))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "GeoSpatialTools", + "language": "python", + "name": "geospatialtools" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/octtree.ipynb b/notebooks/octtree.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..2d9339801c111896f5ac3b56b91ee6f035b02998 --- /dev/null +++ b/notebooks/octtree.ipynb @@ -0,0 +1,707 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c70ce35d-6112-4c12-9387-9c788c84a8e9", + "metadata": {}, + "source": [ + "## OctTree!\n", + "\n", + "Testing the time to look-up nearby records with the PyCOADS 
OctTree implementation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c0956916-f50a-444d-a5b6-f06d3fb9b44d", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import gzip\n", + "os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n", + "\n", + "from datetime import datetime, timedelta\n", + "from random import choice\n", + "from string import ascii_letters, digits\n", + "import random\n", + "import inspect\n", + "\n", + "import polars as pl\n", + "import numpy as np\n", + "\n", + "from GeoSpatialTools.octtree import OctTree, SpaceTimeRecord as Record, SpaceTimeRectangle as Rectangle" + ] + }, + { + "cell_type": "markdown", + "id": "99295bad-0db3-444b-8d38-acc7875cc0f0", + "metadata": {}, + "source": [ + "## Generate Data\n", + "\n", + "16,000 rows of data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d8f1e5e1-513c-4bdf-a9f9-cef9562a7cb7", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_uid(n: int) -> str:\n", + " chars = ascii_letters + digits\n", + " return \"\".join(random.choice(chars) for _ in range(n))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "986d9cc5-e610-449a-9ee7-e281b7558ca9", + "metadata": {}, + "outputs": [], + "source": [ + "N = 16_000\n", + "lons = pl.int_range(-180, 180, eager=True)\n", + "lats = pl.int_range(-90, 90, eager=True)\n", + "dates = pl.datetime_range(datetime(1900, 1, 1, 0), datetime(1900, 1, 31, 23), interval=\"1h\", eager=True)\n", + "\n", + "lons_use = lons.sample(N, with_replacement=True).alias(\"lon\")\n", + "lats_use = lats.sample(N, with_replacement=True).alias(\"lat\")\n", + "dates_use = dates.sample(N, with_replacement=True).alias(\"datetime\")\n", + "uids = pl.Series(\"uid\", [generate_uid(8) for _ in range(N)])\n", + "\n", + "df = pl.DataFrame([lons_use, lats_use, dates_use, uids]).unique()" + ] + }, + { + "cell_type": "markdown", + "id": "237096f1-093e-49f0-9a9a-2bec5231726f", + "metadata": {}, + "source": [ + "## Add extra 
rows\n", + "\n", + "For testing larger datasets. Uncomment to use." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0b8fd425-8a90-4f76-91b7-60df48aa98e4", + "metadata": {}, + "outputs": [], + "source": [ + "# _df = df.clone()\n", + "# for i in range(100):\n", + "# df2 = pl.DataFrame([\n", + "# _df[\"lon\"].shuffle(),\n", + "# _df[\"lat\"].shuffle(),\n", + "# _df[\"datetime\"].shuffle(),\n", + "# _df[\"uid\"].shuffle(),\n", + "# ]).with_columns(pl.concat_str([pl.col(\"uid\"), pl.lit(f\"{i:03d}\")]).alias(\"uid\"))\n", + "# df = df.vstack(df2)\n", + "# df.shape\n", + "# df" + ] + }, + { + "cell_type": "markdown", + "id": "c7bd16e0-96a6-426b-b00a-7c3b8a2aaddd", + "metadata": {}, + "source": [ + "## Initialise the OctTree Object" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "af06a976-ff52-49e0-a886-91bcbe540ffe", + "metadata": {}, + "outputs": [], + "source": [ + "otree = OctTree(Rectangle(0, 0, datetime(1900, 1, 16), 360, 180, timedelta(days=32)), capacity = 10, max_depth = 25)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2ba99b37-787c-4862-8075-a7596208c60e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 186 ms, sys: 191 ms, total: 377 ms\n", + "Wall time: 118 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "for r in df.rows():\n", + " otree.insert(Record(*r))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "59d38446-f7d2-4eec-bba3-c39bd7279623", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OctTree:\n", + "- boundary: SpaceTimeRectangle(x = 0, y = 0, w = 360, h = 180, t = 1900-01-16 00:00:00, dt = 32 days, 0:00:00)\n", + "- capacity: 10\n", + "- depth: 0\n", + "- max_depth: 25\n", + "- contents:\n", + "- number of elements: 10\n", + " * Record(x = 43, y = -68, datetime = 1900-01-08 13:00:00, uid = OBiqSYcn)\n", + " * Record(x = 97, y = 
-47, datetime = 1900-01-02 14:00:00, uid = w589k3Oe)\n", + " * Record(x = -68, y = 44, datetime = 1900-01-30 11:00:00, uid = XAaA7McU)\n", + " * Record(x = -170, y = 77, datetime = 1900-01-19 09:00:00, uid = x6eLi65N)\n", + " * Record(x = -2, y = 7, datetime = 1900-01-12 09:00:00, uid = CjB2Pglt)\n", + " * Record(x = -175, y = 65, datetime = 1900-01-15 01:00:00, uid = bTB9DkDI)\n", + " * Record(x = 8, y = 83, datetime = 1900-01-04 10:00:00, uid = aYCKIBl9)\n", + " * Record(x = 20, y = 60, datetime = 1900-01-24 16:00:00, uid = 8GsD19WF)\n", + " * Record(x = 161, y = 40, datetime = 1900-01-24 20:00:00, uid = FIfAABuC)\n", + " * Record(x = -69, y = -9, datetime = 1900-01-11 11:00:00, uid = uTcS5D4e)\n", + "- with children:\n", + " OctTree:\n", + " - boundary: SpaceTimeRectangle(x = -90.0, y = 45.0, w = 180.0, h = 90.0, t = 1900-01-08 00:00:00, dt = 16 days, 0:00:00)\n", + " - capacity: 10\n", + " - depth: 1\n", + " - max_depth: 25\n", + " - contents:\n", + " - number of elements: 10\n", + " * Record(x = -156, y = 57, datetime = 1900-01-08 10:00:00, uid = aFheRU2n)\n", + " * Record(x = -100, y = 61, datetime = 1900-01-15 09:00:00, uid = Sa1iavle)\n", + " * Record(x = -168, y = 88, datetime = 1900-01-03 07:00:00, uid = IlYKGW0N)\n", + " * Record(x = -80, y = 50, datetime = 1900-01-05 09:00:00, uid = Rg3GHM4d)\n", + " * Record(x = -92, y = 39, datetime = 1900-01-15 06:00:00, uid = u804YMFB)\n", + " * Record(x = -119, y = 60, datetime = 1900-01-12 22:00:00, uid = vdEPjkib)\n", + " * Record(x = -160, y = 79, datetime = 1900-01-06 08:00:00, uid = QmrPEL6h)\n", + " * Record(x = -95, y = 21, datetime = 1900-01-09 04:00:00, uid = hfjTKSCH)\n", + " * Record(x = -93, y = 61, datetime = 1900-01-09 20:00:00, uid = SzIrja9S)\n", + " * Record(x = -149, y = 34, datetime = 1900-01-05 05:00:00, uid = b02MxQjV)\n", + " - with children:\n", + " OctTree:\n", + " - boundary: SpaceTimeRectangle(x = -135.0, y = 67.5, w = 90.0, h = 45.0, t = 1900-01-04 00:00:00, dt = 8 days, 0:00:00)\n", + " 
- capacity: 10\n", + " - depth: 2\n", + " - max_depth: 25\n", + " - contents:\n", + " - number of elements: 10\n", + " * Record(x = -134, y = 79, datetime = 1900-01-05 14:00:00, uid = 7Q0FKGMk)\n", + " * Record(x = -90, y = 53, datetime = 1900-01-05 03:00:00, uid = LLx7iz2v)\n", + " * Record(x = -176, y = 50, datetime = 1900-01-06 20:00:00, uid = x6K5DlTl)\n", + " * Record(x = -141, y = 52, datetime = 1900-01-02 15:00:00, uid = xTpGPaEy)\n", + " * Record(x = -116, y = 68, datetime = 1900-01-05 16:00:00, uid = eECSkpdU)\n", + " * Record(x = -138, y = 63, datetime = 1900-01-05 02:00:00, uid = Ftf9uhH3)\n", + " * Record(x = -173, y = 71, datetime = 1900-01-03 03:00:00, uid = mu3vwHM5)\n", + " * Record(x = -148, y = 49, datetime = 1900-01-05 15:00:00, uid = 8DFDI3CJ)\n", + " * Record(x = -157, y = 63, datetime = 1900-01-06 19:00:00, uid = mVqLntgh)\n", + " * Record(x = -154, y = 45, datetime = 1900-01-07 11:00:00, uid = 1UoA1NNC)\n", + " - with children:\n", + " OctTree:\n", + " - boundary: SpaceTimeRectangle(x = -157.5, y = 78.75, w = 45.0, h = 22.5, t = 1900-01-02 00:00:00, dt = 4 days, 0:00:00)\n", + " - capacity: 10\n", + " - depth: 3\n", + " - max_depth: 25\n", + " - contents:\n", + " - number of elements: 10\n", + " * Record(x = -147, y = 83, datetime = 1900-01-01 18:00:00, uid = WaO5R7fy)\n", + " * Record(x = -136, y = 72, datetime = 1900-01-02 03:00:00, uid = OWaMqULr)\n", + " * Record(x = -176, y = 79, datetime = 1900-01-02 06:00:00, uid = NTjvqz2c)\n", + " * Record(x = -152, y = 72, datetime = 1900-01-03 18:00:00, uid = 7rtQIGtn)\n", + " * Record(x = -162, y = 78, datetime = 1900-01-02 04:00:00, uid = Wi9RsOIX)\n", + " * Record(x = -136, y = 79, datetime = 1900-01-01 11:00:00, uid = hSltzeuH)\n", + " * Record(x = -176, y = 89, datetime = 1900-01-02 09:00:00, uid = cOLgAely)\n", + " * Record(x = -141, y = 75, datetime = 1900-01-03 23:00:00, uid = gH755dC3)\n", + " * Record(x = -158, y = 72, datetime = 1900-01-02 23:00:00, uid = NUmMfw9K)\n", + " * Record(x = 
-168, y = 72, datetime = 1900-01-02 01:00:00, uid = ZFcsxYG4)\n", + " - with children:\n", + " OctTree:\n", + " - boundary: SpaceTimeRectangle(x = -168.75, y = 84.375, w = 22.5, h = 11.25, t = 1900-01-01 00:00:00, dt = 2 days, 0:00:00)\n", + " - capacity: 10\n", + " - depth: 4\n", + " - max_depth: 25\n", + " - contents:\n", + " - number of elements: 6\n", + " * Record(x = -158, y = 86, datetime = 1900-01-01 15:00:00, uid = DOD5jT2l)\n", + " * Record(x = -165, y = 88, datetime = 1900-01-01 13:00:00, uid = kdGlzz41)\n", + " * Record(x = -173, y = 82, datetime = 1900-01-01 04:00:00, uid = aWBwIP4U)\n", + " * Record(x = -180, y = 89, datetime = 1900-01-01 22:00:00, uid = HOxbaCm8)\n", + " * Record(x = -165, y = 81, datetime = 1900-01-01 16:00:00, uid = JtRn9y9e)\n", + " * Record(x = -164, y = 84, datetime = 1900-01-01 03:00:00, uid = vELpx1ij)\n", + " OctTree:\n", + " - boundary: SpaceTimeRectangle(x = -146.25, y = 84.375, w = 22.5, h = 11.25, t = 1900-01-01 00:00:00, dt = 2 days, 0:00:00)\n", + " - capacity: 10\n", + " - depth: 4\n", + " - max_depth: 25\n", + " - contents:\n", + " - number of elements: 1\n", + " * Record(x = -157, y = 84, datetime = 1900-01-01 17:00:00, uid = 6DlgVOXg)\n", + " OctTree:\n", + " - boundary: SpaceTimeRectangle(x = -168.75, y = 73.125, w = 22.5, h = 11.25, t = 1900-01-01 00:00:00, dt = 2 days, 0:00:00)\n", + " - capacity: 10\n", + " - depth: 4\n", + " - max_depth: 25\n", + " - contents:\n", + " - number of elements: 2\n" + ] + } + ], + "source": [ + "s = str(otree)\n", + "print(\"\\n\".join(s.split(\"\\n\")[:100]))" + ] + }, + { + "cell_type": "markdown", + "id": "6b02c2ea-6566-47c2-97e0-43d8b18e0713", + "metadata": {}, + "source": [ + "## Time Execution\n", + "\n", + "Testing the identification of nearby points against the original full search" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "094b588c-e938-4838-9719-1defdfff74fa", + "metadata": {}, + "outputs": [], + "source": [ + "dts = 
pl.datetime_range(datetime(1900, 1, 1), datetime(1900, 2, 1), interval=\"1h\", eager=True, closed=\"left\")\n", + "N = dts.len()\n", + "lons = 180 - 360 * np.random.rand(N)\n", + "lats = 90 - 180 * np.random.rand(N)\n", + "test_df = pl.DataFrame({\"lon\": lons, \"lat\": lats, \"datetime\": dts})\n", + "test_recs = [Record(*r) for r in test_df.rows()]\n", + "dt = timedelta(days = 1)\n", + "dist = 350" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "66a48b86-d449-45d2-9837-2b3e07f5563d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "206 μs ± 3.36 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "otree.nearby_points(random.choice(test_recs), dist=dist, t_dist=dt)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "972d4a16-39fd-4f80-8592-1c5d5cabf5be", + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "outputs": [], + "source": [ + "def check_cols(\n", + " df: pl.DataFrame | pl.LazyFrame,\n", + " cols: list[str],\n", + " var_name: str = \"dataframe\",\n", + ") -> None:\n", + " \"\"\"\n", + " Check that a dataframe contains a list of columns. 
Raises an error if not.\n", + "\n", + " Parameters\n", + " ----------\n", + " df : polars Frame\n", + " Dataframe to check\n", + " cols : list[str]\n", + " Required columns\n", + " var_name : str\n", + " Name of the Frame - used for displaying in any error.\n", + " \"\"\"\n", + " calling_func = inspect.stack()[1][3]\n", + " if isinstance(df, pl.DataFrame):\n", + " have_cols = df.columns\n", + " elif isinstance(df, pl.LazyFrame):\n", + " have_cols = df.collect_schema().names()\n", + " else:\n", + " raise TypeError(\"Input Frame is not a polars Frame\")\n", + "\n", + " cols_in_frame = intersect(cols, have_cols)\n", + " missing = [c for c in cols if c not in cols_in_frame]\n", + "\n", + " if len(missing) > 0:\n", + " err_str = f\"({calling_func}) - {var_name} missing required columns. \"\n", + " err_str += f'Require: {\", \".join(cols)}. '\n", + " err_str += f'Missing: {\", \".join(missing)}.'\n", + " logging.error(err_str)\n", + " raise ValueError(err_str)\n", + "\n", + " return\n", + "\n", + "\n", + "def haversine_df(\n", + " df: pl.DataFrame | pl.LazyFrame,\n", + " date_var: str = \"datetime\",\n", + " R: float = 6371,\n", + " reverse: bool = False,\n", + " out_colname: str = \"dist\",\n", + " lon_col: str = \"lon\",\n", + " lat_col: str = \"lat\",\n", + " lon2_col: str | None = None,\n", + " lat2_col: str | None = None,\n", + " sorted: bool = False,\n", + " rev_prefix: str = \"rev_\",\n", + ") -> pl.DataFrame | pl.LazyFrame:\n", + " \"\"\"\n", + " Compute haversine distance on earth surface between lon-lat positions.\n", + "\n", + " If only 'lon_col' and 'lat_col' are specified then this computes the\n", + " distance between consecutive points. 
If a second set of positions is\n", + " included via the optional 'lon2_col' and 'lat2_col' arguments then the\n", + " distances between the columns are computed.\n", + "\n", + " Parameters\n", + " ----------\n", + " df : polars.DataFrame\n", + " The data, containing required columns:\n", + " * lon_col\n", + " * lat_col\n", + " * date_var\n", + " date_var : str\n", + " Name of the datetime column on which to sort the positions\n", + " R : float\n", + " Radius of earth in km\n", + " reverse : bool\n", + " Compute distances in reverse\n", + " out_colname : str\n", + " Name of the output column to store distances. Prefixed with 'rev_' if\n", + " reverse is True\n", + " lon_col : str\n", + " Name of the longitude column\n", + " lat_col : str\n", + " Name of the latitude column\n", + " lon2_col : str\n", + " Name of the 2nd longitude column if present\n", + " lat2_col : str\n", + " Name of the 2nd latitude column if present\n", + " sorted : bool\n", + " Compute distances assuming that the frame is already sorted\n", + " rev_prefix : str\n", + " Prefix to use for colnames if reverse is True\n", + "\n", + " Returns\n", + " -------\n", + " polars.DataFrame\n", + " With additional column specifying distances between consecutive points\n", + " in the same units as 'R'. 
With colname defined by 'out_colname'.\n", + " \"\"\"\n", + " required_cols = [lon_col, lat_col]\n", + "\n", + " if lon2_col is not None and lat2_col is not None:\n", + " required_cols += [lon2_col, lat2_col]\n", + " check_cols(df, required_cols, \"df\")\n", + " return (\n", + " df.with_columns(\n", + " [\n", + " pl.col(lat_col).radians().alias(\"_lat0\"),\n", + " pl.col(lat2_col).radians().alias(\"_lat1\"),\n", + " (pl.col(lon_col) - pl.col(lon2_col))\n", + " .radians()\n", + " .alias(\"_dlon\"),\n", + " (pl.col(lat_col) - pl.col(lat2_col))\n", + " .radians()\n", + " .alias(\"_dlat\"),\n", + " ]\n", + " )\n", + " .with_columns(\n", + " (\n", + " (pl.col(\"_dlat\") / 2).sin().pow(2)\n", + " + pl.col(\"_lat0\").cos()\n", + " * pl.col(\"_lat1\").cos()\n", + " * (pl.col(\"_dlon\") / 2).sin().pow(2)\n", + " ).alias(\"_a\")\n", + " )\n", + " .with_columns(\n", + " (2 * R * (pl.col(\"_a\").sqrt().arcsin()))\n", + " .round(2)\n", + " .alias(out_colname)\n", + " )\n", + " .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n", + " )\n", + "\n", + " if lon2_col is not None or lat2_col is not None:\n", + " logging.warning(\n", + " \"(haversine_df) 2nd position incorrectly specified. 
\"\n", + " + \"Calculating consecutive distances.\"\n", + " )\n", + "\n", + " required_cols += [date_var]\n", + " check_cols(df, required_cols, \"df\")\n", + " if reverse:\n", + " out_colname = rev_prefix + out_colname\n", + " if not sorted:\n", + " df = df.sort(date_var, descending=reverse)\n", + " return (\n", + " df.with_columns(\n", + " [\n", + " pl.col(lat_col).radians().alias(\"_lat0\"),\n", + " pl.col(lat_col).shift(n=-1).radians().alias(\"_lat1\"),\n", + " (pl.col(lon_col).shift(n=-1) - pl.col(lon_col))\n", + " .radians()\n", + " .alias(\"_dlon\"),\n", + " (pl.col(lat_col).shift(n=-1) - pl.col(lat_col))\n", + " .radians()\n", + " .alias(\"_dlat\"),\n", + " ]\n", + " )\n", + " .with_columns(\n", + " (\n", + " (pl.col(\"_dlat\") / 2).sin().pow(2)\n", + " + pl.col(\"_lat0\").cos()\n", + " * pl.col(\"_lat1\").cos()\n", + " * (pl.col(\"_dlon\") / 2).sin().pow(2)\n", + " ).alias(\"_a\")\n", + " )\n", + " .with_columns(\n", + " (2 * R * (pl.col(\"_a\").sqrt().arcsin()))\n", + " .round(2)\n", + " .fill_null(strategy=\"forward\")\n", + " .alias(out_colname)\n", + " )\n", + " .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n", + " )\n", + "\n", + "def intersect(a, b) -> set:\n", + " return set(a) & set(b)\n", + "\n", + "def nearby_ships(\n", + " lon: float,\n", + " lat: float,\n", + " pool: pl.DataFrame,\n", + " max_dist: float,\n", + " lon_col: str = \"lon\",\n", + " lat_col: str = \"lat\",\n", + " dt: datetime | None = None,\n", + " date_col: str | None = None,\n", + " dt_gap: timedelta | None = None,\n", + " filter_datetime: bool = False,\n", + ") -> pl.DataFrame:\n", + " \"\"\"\n", + " Find observations nearby to a position in space (and optionally time).\n", + "\n", + " Get a frame of all records that are within a maximum distance of the\n", + " provided point.\n", + "\n", + " If filter_datetime is True, then only records from the same datetime will\n", + " be returned. 
If a specific filter is desired this should be performed\n", + " before calling this function and set filter_datetime to False.\n", + "\n", + " Parameters\n", + " ----------\n", + " lon : float\n", + " The longitude of the position.\n", + " lat : float\n", + " The latitude of the position.\n", + " pool : polars.DataFrame\n", + " The pool of records to search. Can be pre-filtered and filter_datetime\n", + " set to False.\n", + " max_dist : float\n", + " Will return records that have distance to the point <= this value.\n", + " lon_col : str\n", + " Name of the longitude column in the pool DataFrame\n", + " lat_col : str\n", + " Name of the latitude column in the pool DataFrame\n", + " dt : datetime | None\n", + " Datetime of the record. Must be set if filter_datetime is True.\n", + " date_col : str | None\n", + " Name of the datetime column in the pool. Must be set if filter_datetime\n", + " is True.\n", + " dt_gap : timedelta | None\n", + " Allowed time-gap for records. Records that fall between\n", + " dt - dt_gap and dt + dt_gap will be returned. If not set then only\n", + " records at dt will be returned. Applies if filter_datetime is True.\n", + " filter_datetime : bool\n", + " Only return records at the same datetime record as the input value. 
If\n", + " assessing multiple points with different datetimes, hence calling this\n", + " function frequently it will be more efficient to partition the pool\n", + " first, then set this value to False and only input the subset of data.\n", + "\n", + " Returns\n", + " -------\n", + " polars.DataFrame\n", + " Containing only records from the pool within max_dist of the input\n", + " point, optionally at the same datetime if filter_datetime is True.\n", + " \"\"\"\n", + " required_cols = [lon_col, lat_col]\n", + " check_cols(pool, required_cols, \"pool\")\n", + "\n", + " if filter_datetime:\n", + " if not dt or not date_col:\n", + " raise ValueError(\n", + " \"'dt' and 'date_col' must be provided if 'filter_datetime' \"\n", + " + \"is True\"\n", + " )\n", + " if date_col not in pool.columns:\n", + " raise ValueError(f\"'date_col' value {date_col} not found in pool.\")\n", + " if not dt_gap:\n", + " pool = pool.filter(pl.col(date_col).eq(dt))\n", + " else:\n", + " pool = pool.filter(\n", + " pl.col(date_col).is_between(\n", + " dt - dt_gap, dt + dt_gap, closed=\"both\"\n", + " )\n", + " )\n", + "\n", + " return (\n", + " pool.with_columns(\n", + " [pl.lit(lon).alias(\"_lon\"), pl.lit(lat).alias(\"_lat\")]\n", + " )\n", + " .pipe(\n", + " haversine_df,\n", + " lon_col=lon_col,\n", + " lat_col=lat_col,\n", + " out_colname=\"_dist\",\n", + " lon2_col=\"_lon\",\n", + " lat2_col=\"_lat\",\n", + " )\n", + " .filter(pl.col(\"_dist\").le(max_dist))\n", + " .drop([\"_dist\", \"_lon\", \"_lat\"])\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "8b9279ed-6f89-4423-8833-acd0b365eb7b", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5.33 ms ± 20.1 μs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "rec = random.choice(test_recs)\n", + "nearby_ships(lon=rec.lon, lat=rec.lat, dt=rec.datetime, max_dist=dist, dt_gap=dt, date_col=\"datetime\", pool=df, filter_datetime=True)" + ] + }, + { + "cell_type": "markdown", + "id": "d148f129-9d8c-4c46-8f01-3e9c1e93e81a", + "metadata": {}, + "source": [ + "## Verify\n", + "\n", + "Check that records are the same" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "11f3d73a-fbe5-4f27-88d8-d0d687bd0eac", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.52 s, sys: 237 ms, total: 2.75 s\n", + "Wall time: 2.65 s\n" + ] + } + ], + "source": [ + "%%time\n", + "dist = 250\n", + "for _ in range(250):\n", + " rec = Record(*random.choice(df.rows()))\n", + " orig = nearby_ships(lon=rec.lon, lat=rec.lat, dt=rec.datetime, max_dist=dist, dt_gap=dt, date_col=\"datetime\", pool=df, filter_datetime=True)\n", + " tree = otree.nearby_points(rec, dist=dist, t_dist=dt)\n", + " if orig.height > 0:\n", + " if not tree:\n", + " print(rec)\n", + " print(\"NO TREE!\")\n", + " print(f\"{orig = }\")\n", + " else:\n", + " tree = pl.from_records([(r.lon, r.lat, r.datetime, r.uid) for r in tree], orient=\"row\").rename({\"column_0\": \"lon\", \"column_1\": \"lat\", \"column_2\": \"datetime\", \"column_3\": \"uid\"})\n", + " if tree.height != orig.height:\n", + " print(\"Tree and Orig Heights Do Not Match\")\n", + " print(f\"{orig = }\")\n", + " print(f\"{tree = }\")\n", + " else:\n", + " # tree = tree.with_columns(pl.col(\"uid\").str.slice(0, 6))\n", + " if not tree.sort(\"uid\").equals(orig.sort(\"uid\")):\n", + " print(\"Tree and Orig Do Not Match\")\n", + " print(f\"{orig = }\")\n", + " print(f\"{tree = }\")" + ] + }, + { + "cell_type": "markdown", + "id": "1223529e-bfae-4b83-aba7-505d05e588d3", + "metadata": {}, + "source": [ + "## Check -180/180 boundary" + ] + }, + { + "cell_type": "code", + 
"execution_count": 13, + "id": "4c392292-2d9f-4301-afb5-019fde069a1e", + "metadata": {}, + "outputs": [], + "source": [ + "out = otree.nearby_points(Record(179.5, -43.1, datetime(1900, 1, 14, 13)), dist=200, t_dist=timedelta(days=3))\n", + "for o in out:\n", + " print(o)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "GeoSpatialTools", + "language": "python", + "name": "geospatialtools" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 88833e3b11dd29b5e9b368084e42d7206790a92a..19265f566b8b41525282756863b52c3bb2540a6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ description = "Tools for processing geo spatial data." readme = "README.md" license = {file = "LICENSE"} keywords = [ - "spatial", "geospatial", "quadtree", "octtree", + "spatial", "geospatial", "quadtree", "octtree", "nearest neighbour", ] classifiers = [ "Development Status :: 1 - PreAlpha", @@ -38,8 +38,9 @@ classifiers = [ ] [project.optional-dependencies] -extra = [ +notebooks = [ "ipykernel", + "polars" ] test = [ "pytest",