Commit 5cb3030c authored by Joseph Siddons's avatar Joseph Siddons
Browse files

docs: apply ruff fixes to notebooks, add extra markdown elements, re-run

parent c15fe67a
{
"cells": [
{
"cell_type": "markdown",
"id": "f7143f08-1d06-4e94-bbf6-ef35ddd11556",
"metadata": {},
"source": [
"# KDTree\n",
"\n",
"Testing the time to look-up nearby records with the `KDTree` implementation. Note that this implementation is actually a `2DTree` since it can only compute a valid distance comparison between longitude and latitude positions.\n",
"\n",
"The `KDTree` object is used for finding the closest neighbour to a position, in this implementation we use the Haversine distance to compare positions."
]
},
{
"cell_type": "code",
"execution_count": 1,
......@@ -8,11 +20,9 @@
"outputs": [],
"source": [
"import os\n",
"import gzip\n",
"os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n",
"\n",
"from datetime import datetime, timedelta\n",
"from random import choice\n",
"from datetime import datetime\n",
"from string import ascii_letters, digits\n",
"import random\n",
"import inspect\n",
......@@ -20,7 +30,17 @@
"import polars as pl\n",
"import numpy as np\n",
"\n",
"from GeoSpatialTools import Record, haversine, KDTree"
"from GeoSpatialTools import Record, KDTree"
]
},
{
"cell_type": "markdown",
"id": "ec6c6e7f-8eee-47ea-a5e9-12537bb3412d",
"metadata": {},
"source": [
"## Set-up functions\n",
"\n",
"Used for generating data, or for comparisons by doing brute-force approach."
]
},
{
......@@ -31,6 +51,7 @@
"outputs": [],
"source": [
"def randnum() -> float:\n",
" \"\"\"Get a random number between -1 and 1\"\"\"\n",
" return 2 * (np.random.rand() - 0.5)"
]
},
......@@ -42,6 +63,7 @@
"outputs": [],
"source": [
"def generate_uid(n: int) -> str:\n",
" \"\"\"Generates a psuedo uid by randomly selecting from characters\"\"\"\n",
" chars = ascii_letters + digits\n",
" return \"\".join(random.choice(chars) for _ in range(n))"
]
......@@ -49,6 +71,179 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "9e647ecd-abdc-46a0-8261-aa081fda2e1d",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def check_cols(\n",
" df: pl.DataFrame | pl.LazyFrame,\n",
" cols: list[str],\n",
" var_name: str = \"dataframe\",\n",
") -> None:\n",
" \"\"\"\n",
" Check that a dataframe contains a list of columns. Raises an error if not.\n",
"\n",
" Parameters\n",
" ----------\n",
" df : polars Frame\n",
" Dataframe to check\n",
" cols : list[str]\n",
" Requried columns\n",
" var_name : str\n",
" Name of the Frame - used for displaying in any error.\n",
" \"\"\"\n",
" calling_func = inspect.stack()[1][3]\n",
" if isinstance(df, pl.DataFrame):\n",
" have_cols = df.columns\n",
" elif isinstance(df, pl.LazyFrame):\n",
" have_cols = df.collect_schema().names()\n",
" else:\n",
" raise TypeError(\"Input Frame is not a polars Frame\")\n",
"\n",
" cols_in_frame = intersect(cols, have_cols)\n",
" missing = [c for c in cols if c not in cols_in_frame]\n",
"\n",
" if len(missing) > 0:\n",
" err_str = f\"({calling_func}) - {var_name} missing required columns. \"\n",
" err_str += f'Require: {\", \".join(cols)}. '\n",
" err_str += f'Missing: {\", \".join(missing)}.'\n",
" raise ValueError(err_str)\n",
"\n",
" return\n",
"\n",
"\n",
"def haversine_df(\n",
" df: pl.DataFrame | pl.LazyFrame,\n",
" lon: float,\n",
" lat: float,\n",
" R: float = 6371,\n",
" lon_col: str = \"lon\",\n",
" lat_col: str = \"lat\",\n",
") -> pl.DataFrame | pl.LazyFrame:\n",
" \"\"\"\n",
" Compute haversine distance on earth surface between lon-lat positions\n",
" in a polars DataFrame and a lon-lat position.\n",
"\n",
" Parameters\n",
" ----------\n",
" df : polars.DataFrame\n",
" The data, containing required columns:\n",
" * lon_col\n",
" * lat_col\n",
" * date_var\n",
" lon : float\n",
" The longitude of the position.\n",
" lat : float\n",
" The latitude of the position.\n",
" R : float\n",
" Radius of earth in km\n",
" lon_col : str\n",
" Name of the longitude column\n",
" lat_col : str\n",
" Name of the latitude column\n",
"\n",
" Returns\n",
" -------\n",
" polars.DataFrame\n",
" With additional column specifying distances between consecutive points\n",
" in the same units as 'R'. With colname defined by 'out_colname'.\n",
" \"\"\"\n",
" required_cols = [lon_col, lat_col]\n",
"\n",
" check_cols(df, required_cols, \"df\")\n",
" return (\n",
" df.with_columns(\n",
" [\n",
" pl.col(lat_col).radians().alias(\"_lat0\"),\n",
" pl.lit(lat).radians().alias(\"_lat1\"),\n",
" (pl.col(lon_col) - lon).radians().alias(\"_dlon\"),\n",
" (pl.col(lat_col) - lat).radians().alias(\"_dlat\"),\n",
" ]\n",
" )\n",
" .with_columns(\n",
" (\n",
" (pl.col(\"_dlat\") / 2).sin().pow(2)\n",
" + pl.col(\"_lat0\").cos()\n",
" * pl.col(\"_lat1\").cos()\n",
" * (pl.col(\"_dlon\") / 2).sin().pow(2)\n",
" ).alias(\"_a\")\n",
" )\n",
" .with_columns(\n",
" (2 * R * (pl.col(\"_a\").sqrt().arcsin()))\n",
" .round(2)\n",
" .alias(\"_dist\")\n",
" )\n",
" .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n",
" )\n",
"\n",
"\n",
"def intersect(a, b) -> set:\n",
" \"\"\"Intersection of a and b, items in both a and b\"\"\"\n",
" return set(a) & set(b)\n",
"\n",
"\n",
"def nearest_ship(\n",
" lon: float,\n",
" lat: float,\n",
" df: pl.DataFrame,\n",
" lon_col: str = \"lon\",\n",
" lat_col: str = \"lat\",\n",
") -> pl.DataFrame:\n",
" \"\"\"\n",
" Find the observation nearest to a position in space.\n",
"\n",
" Get a frame with only the records that is closest to the input point.\n",
"\n",
" Parameters\n",
" ----------\n",
" lon : float\n",
" The longitude of the position.\n",
" lat : float\n",
" The latitude of the position.\n",
" df : polars.DataFrame\n",
" The pool of records to search. Can be pre-filtered and filter_datetime\n",
" set to False.\n",
" lon_col : str\n",
" Name of the longitude column in the pool DataFrame\n",
" lat_col : str\n",
" Name of the latitude column in the pool DataFrame\n",
"\n",
" Returns\n",
" -------\n",
" polars.DataFrame\n",
" Containing only records from the pool within max_dist of the input\n",
" point, optionally at the same datetime if filter_datetime is True.\n",
" \"\"\"\n",
" required_cols = [lon_col, lat_col]\n",
" check_cols(df, required_cols, \"df\")\n",
"\n",
" return (\n",
" df\n",
" .pipe(\n",
" haversine_df,\n",
" lon=lon,\n",
" lat=lat,\n",
" lon_col=lon_col,\n",
" lat_col=lat_col,\n",
" )\n",
" .filter(pl.col(\"_dist\").eq(pl.col(\"_dist\").min()))\n",
" .drop([\"_dist\"])\n",
" )\n"
]
},
{
"cell_type": "markdown",
"id": "287bdc1d-1ecf-4c59-af95-d2dc639c6894",
"metadata": {},
"source": [
"## Initialise random data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c60b30de-f864-477a-a09a-5f1caa4d9b9a",
"metadata": {},
"outputs": [
......@@ -63,11 +258,11 @@
"│ --- ┆ --- │\n",
"│ i64 ┆ i64 │\n",
"╞══════╪═════╡\n",
"│ 12721 │\n",
"│ -14836 │\n",
"│ -46-15 │\n",
"│ 104 ┆ 89 │\n",
"│ -57-31 │\n",
"│ 62 -29 │\n",
"│ 146 1 │\n",
"│ 10460 │\n",
"│ -162 ┆ -66 │\n",
"│ 72 69 │\n",
"└──────┴─────┘\n"
]
}
......@@ -76,7 +271,12 @@
"N = 16_000\n",
"lons = pl.int_range(-180, 180, eager=True)\n",
"lats = pl.int_range(-90, 90, eager=True)\n",
"dates = pl.datetime_range(datetime(1900, 1, 1, 0), datetime(1900, 1, 31, 23), interval=\"1h\", eager=True)\n",
"dates = pl.datetime_range(\n",
" datetime(1900, 1, 1, 0),\n",
" datetime(1900, 1, 31, 23),\n",
" interval=\"1h\",\n",
" eager=True,\n",
")\n",
"\n",
"lons_use = lons.sample(N, with_replacement=True).alias(\"lon\")\n",
"lats_use = lats.sample(N, with_replacement=True).alias(\"lat\")\n",
......@@ -90,7 +290,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "875f2a67-49fe-476f-add1-b1d76c6cd8f9",
"metadata": {},
"outputs": [],
......@@ -98,9 +298,19 @@
"records = [Record(**r) for r in df.rows(named=True)]"
]
},
{
"cell_type": "markdown",
"id": "bd83330b-ef2c-478e-9a7b-820454d198bb",
"metadata": {},
"source": [
"## Intialise the `KDTree`\n",
"\n",
"There is an overhead to constructing a `KDTree` object, so performance improvement is only for multiple comparisons."
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "1e883e5a-5086-4c29-aff2-d308874eae16",
"metadata": {},
"outputs": [
......@@ -108,8 +318,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 151 ms, sys: 360 ms, total: 511 ms\n",
"Wall time: 57.3 ms\n"
"CPU times: user 32.7 ms, sys: 1.4 ms, total: 34.1 ms\n",
"Wall time: 33.4 ms\n"
]
}
],
......@@ -118,9 +328,30 @@
"kt = KDTree(records)"
]
},
{
"cell_type": "markdown",
"id": "0a37ef06-2691-4e01-96a9-1c1ecd582599",
"metadata": {},
"source": [
"## Compare with brute force approach"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "365bbf30-7a93-438d-92b2-a3471f1e9249",
"metadata": {},
"outputs": [],
"source": [
"test_record = Record(\n",
" random.choice(range(-179, 180)) + randnum(),\n",
" random.choice(range(-89, 90)) + randnum(),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "69022ad1-5ec8-4a09-836c-273ef452451f",
"metadata": {},
"outputs": [
......@@ -128,19 +359,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"203 μs ± 4.56 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
"130 μs ± 847 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"test_record = Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum())\n",
"kt.query(test_record)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 10,
"id": "28031966-c7d0-4201-a467-37590118e851",
"metadata": {},
"outputs": [
......@@ -148,19 +378,45 @@
"name": "stdout",
"output_type": "stream",
"text": [
"8.87 ms ± 188 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"8.34 ms ± 83.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"test_record = Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum())\n",
"np.argmin([test_record.distance(p) for p in records])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 11,
"id": "09e0f923-ca49-47bf-8643-e0b3a6d0467c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8.28 ms ± 105 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"nearest_ship(lon=test_record.lon, lat=test_record.lat, df=df)"
]
},
{
"cell_type": "markdown",
"id": "f0359950-942d-45ea-8676-b22c8ce9e296",
"metadata": {},
"source": [
"## Verify that results are correct"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0d10b2ba-57b2-475c-9d01-135363423990",
"metadata": {},
"outputs": [
......@@ -168,8 +424,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 17.4 s, sys: 147 ms, total: 17.6 s\n",
"Wall time: 17.6 s\n"
"CPU times: user 16.9 s, sys: 144 ms, total: 17 s\n",
"Wall time: 17 s\n"
]
}
],
......@@ -177,18 +433,28 @@
"%%time\n",
"n_samples = 1000\n",
"tol = 1e-8\n",
"test_records = [Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum()) for _ in range(n_samples)]\n",
"test_records = [\n",
" Record(\n",
" random.choice(range(-179, 180)) + randnum(),\n",
" random.choice(range(-89, 90)) + randnum()\n",
" ) for _ in range(n_samples)\n",
"]\n",
"kd_res = [kt.query(r) for r in test_records]\n",
"kd_recs = [_[0][0] for _ in kd_res]\n",
"kd_dists = [_[1] for _ in kd_res]\n",
"tr_recs = [records[np.argmin([r.distance(p) for p in records])] for r in test_records]\n",
"tr_recs = [\n",
" records[np.argmin([r.distance(p) for p in records])]\n",
" for r in test_records\n",
"]\n",
"tr_dists = [min([r.distance(p) for p in records]) for r in test_records]\n",
"assert all([abs(k - t) < tol for k, t in zip(kd_dists, tr_dists)]), \"NOT MATCHING?\""
"\n",
"if not all([abs(k - t) < tol for k, t in zip(kd_dists, tr_dists)]):\n",
" raise ValueError(\"NOT MATCHING?\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 13,
"id": "a6aa6926-7fd5-4fff-bd20-7bc0305b948d",
"metadata": {},
"outputs": [
......@@ -214,7 +480,7 @@
"└──────────┴──────────┴─────────┴────────┴────────┴─────────┴────────┴────────┘"
]
},
"execution_count": 10,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
......@@ -230,14 +496,14 @@
"tr_lats = [r.lat for r in tr_recs]\n",
"\n",
"df = pl.DataFrame({\n",
" \"test_lon\": test_lons, \n",
" \"test_lon\": test_lons,\n",
" \"test_lat\": test_lats,\n",
" \"kd_dist\": kd_dists,\n",
" \"kd_lon\": kd_lons,\n",
" \"kd_lat\": kd_lats,\n",
" \"tr_dist\": tr_dists,\n",
" \"tr_lon\": tr_lons,\n",
" \"tr_lat\": tr_lats, \n",
" \"tr_lat\": tr_lats,\n",
"}).filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n",
"df"
]
......@@ -245,7 +511,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "GeoSpatialTools",
"display_name": "geospatialtools",
"language": "python",
"name": "geospatialtools"
},
......@@ -259,7 +525,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.11.11"
}
},
"nbformat": 4,
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment