{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "bdfa1141-8ae0-499b-8355-927759af69d1", "metadata": {}, "outputs": [], "source": [ "import os\n", "import gzip\n", "# polars sizes its thread pool from POLARS_MAX_THREADS when it initialises,\n", "# so the variable must be set before polars is imported below.\n", "os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n", "\n", "from datetime import datetime, timedelta\n", "from random import choice\n", "from string import ascii_letters, digits\n", "import random\n", "import inspect\n", "\n", "import polars as pl\n", "import numpy as np\n", "\n", "from GeoSpatialTools import Record, haversine, KDTree\n", "\n", "# Seed every RNG the notebook draws from (stdlib `random`, NumPy's global RNG and,\n", "# via SEED, polars sampling) so Restart & Run All reproduces the same data.\n", "SEED = 42\n", "random.seed(SEED)\n", "np.random.seed(SEED)" ] }, { "cell_type": "code", "execution_count": 2, "id": "8711862a-6295-43eb-ac51-333fda638ef4", "metadata": {}, "outputs": [], "source": [ "def randnum() -> float:\n", "    \"\"\"Return a uniform random float in [-1, 1), drawn from NumPy's global RNG.\n", "\n", "    Used below to jitter integer grid coordinates into arbitrary query points.\n", "    \"\"\"\n", "    return 2 * (np.random.rand() - 0.5)" ] }, { "cell_type": "code", "execution_count": 3, "id": "72164093-fac1-4dfc-803b-6522cc9a4d62", "metadata": {}, "outputs": [], "source": [ "def generate_uid(n: int) -> str:\n", "    \"\"\"Return a random string of `n` ASCII letters/digits, drawn with the stdlib `random` module.\"\"\"\n", "    chars = ascii_letters + digits\n", "    return \"\".join(random.choice(chars) for _ in range(n))" ] }, { "cell_type": "code", "execution_count": 4, "id": "c60b30de-f864-477a-a09a-5f1caa4d9b9a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(16000, 2)\n", "shape: (5, 2)\n", "┌──────┬─────┐\n", "│ lon ┆ lat │\n", "│ --- ┆ --- │\n", "│ i64 ┆ i64 │\n", "╞══════╪═════╡\n", "│ 127 ┆ 21 │\n", "│ -148 ┆ 36 │\n", "│ -46 ┆ -15 │\n", "│ 104 ┆ 89 │\n", "│ -57 ┆ -31 │\n", "└──────┴─────┘\n" ] } ], "source": [ "N = 16_000\n", "lons = pl.int_range(-180, 180, eager=True)\n", "lats = pl.int_range(-90, 90, eager=True)\n", "dates = pl.datetime_range(datetime(1900, 1, 1, 0), datetime(1900, 1, 31, 23), interval=\"1h\", eager=True)\n", "\n", "# Distinct seeds: reusing one seed would draw both index sequences from the same\n", "# random stream and correlate the sampled lon with the sampled lat.\n", "lons_use = lons.sample(N, with_replacement=True, seed=SEED).alias(\"lon\")\n", "lats_use = lats.sample(N, with_replacement=True, seed=SEED + 1).alias(\"lat\")\n", "# dates_use = dates.sample(N, with_replacement=True).alias(\"datetime\")\n", "# uids = pl.Series(\"uid\", [generate_uid(8) for _ in range(N)])\n", "\n", "df = pl.DataFrame([lons_use, 
lats_use])\n", "print(df.shape, df.head(), sep=\"\\n\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "875f2a67-49fe-476f-add1-b1d76c6cd8f9", "metadata": {}, "outputs": [], "source": [ "# Turn every (lon, lat) row of the frame into a GeoSpatialTools Record.\n", "records = [Record(lon=row[\"lon\"], lat=row[\"lat\"]) for row in df.rows(named=True)]" ] }, { "cell_type": "code", "execution_count": 6, "id": "1e883e5a-5086-4c29-aff2-d308874eae16", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 151 ms, sys: 360 ms, total: 511 ms\n", "Wall time: 57.3 ms\n" ] } ], "source": [ "%%time\n", "# One-off cost of building the KD-tree over all records.\n", "kt = KDTree(records)" ] }, { "cell_type": "code", "execution_count": 7, "id": "69022ad1-5ec8-4a09-836c-273ef452451f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "203 μs ± 4.56 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%%timeit\n", "# Query the KD-tree at a random grid point jittered by up to +/-1 degree.\n", "query_lon = random.choice(range(-179, 180)) + randnum()\n", "query_lat = random.choice(range(-89, 90)) + randnum()\n", "kt.query(Record(query_lon, query_lat))" ] }, { "cell_type": "code", "execution_count": 8, "id": "28031966-c7d0-4201-a467-37590118e851", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8.87 ms ± 188 μs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" ] } ], "source": [ "%%timeit\n", "test_record = Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum())\n", "np.argmin([test_record.distance(p) for p in records])" ] }, { "cell_type": "code", "execution_count": 9, "id": "0d10b2ba-57b2-475c-9d01-135363423990", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 17.4 s, sys: 147 ms, total: 17.6 s\n", "Wall time: 17.6 s\n" ] } ], "source": [ "%%time\n", "n_samples = 1000\n", "tol = 1e-8\n", "test_records = [Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum()) for _ in range(n_samples)]\n", "kd_res = [kt.query(r) for r in test_records]\n", "kd_recs = [res[0][0] for res in kd_res]\n", "kd_dists = [res[1] for res in kd_res]\n", "\n", "\n", "def brute_force_nearest(query):\n", "    \"\"\"Return (nearest record, its distance) by scanning every record.\n", "\n", "    The distance list is computed once and reused for both the argmin and the\n", "    minimum, so each query scans `records` a single time.\n", "    \"\"\"\n", "    dists = [query.distance(p) for p in records]\n", "    nearest = int(np.argmin(dists))  # first minimum, matching min(dists)\n", "    return records[nearest], dists[nearest]\n", "\n", "\n", "tr_res = [brute_force_nearest(r) for r in test_records]\n", "tr_recs = [rec for rec, _ in tr_res]\n", "tr_dists = [dist for _, dist in tr_res]\n", "# `not (... < tol)` rather than `>= tol` so a NaN distance is reported as a mismatch.\n", "mismatches = [i for i, (k, t) in enumerate(zip(kd_dists, tr_dists)) if not abs(k - t) < tol]\n", "assert not mismatches, f\"NOT MATCHING? {len(mismatches)} of {n_samples} queries differ (first indices: {mismatches[:5]})\"" ] }, { "cell_type": "code", "execution_count": 10, "id": "a6aa6926-7fd5-4fff-bd20-7bc0305b948d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
test_lon | test_lat | kd_dist | kd_lon | kd_lat | tr_dist | tr_lon | tr_lat |
---|---|---|---|---|---|---|---|
f64 | f64 | f64 | i64 | i64 | f64 | i64 | i64 |