{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bdfa1141-8ae0-499b-8355-927759af69d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import gzip\n",
    "os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n",
    "\n",
    "from datetime import datetime, timedelta\n",
    "from random import choice\n",
    "from string import ascii_letters, digits\n",
    "import random\n",
    "import inspect\n",
    "\n",
    "import polars as pl\n",
    "import numpy as np\n",
    "\n",
    "from GeoSpatialTools import Record, haversine, KDTree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8711862a-6295-43eb-ac51-333fda638ef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def randnum() -> float:\n",
    "    \"\"\"Return a uniform random offset in [-1, 1) drawn from NumPy's global generator.\"\"\"\n",
    "    centred = np.random.rand() - 0.5  # [0, 1) -> [-0.5, 0.5)\n",
    "    return 2 * centred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "72164093-fac1-4dfc-803b-6522cc9a4d62",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_uid(n: int) -> str:\n",
    "    \"\"\"Return a random ``n``-character identifier built from ASCII letters and digits.\n",
    "\n",
    "    Draws one ``random.choice`` per character from the stdlib ``random`` module's global state.\n",
    "    \"\"\"\n",
    "    alphabet = ascii_letters + digits\n",
    "    picked = []\n",
    "    for _ in range(n):\n",
    "        picked.append(random.choice(alphabet))\n",
    "    return \"\".join(picked)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c60b30de-f864-477a-a09a-5f1caa4d9b9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(16000, 2)\n",
      "shape: (5, 2)\n",
      "┌──────┬─────┐\n",
      "│ lon  ┆ lat │\n",
      "│ ---  ┆ --- │\n",
      "│ i64  ┆ i64 │\n",
      "╞══════╪═════╡\n",
      "│ 127  ┆ 21  │\n",
      "│ -148 ┆ 36  │\n",
      "│ -46  ┆ -15 │\n",
      "│ 104  ┆ 89  │\n",
      "│ -57  ┆ -31 │\n",
      "└──────┴─────┘\n"
     ]
    }
   ],
   "source": [
    "N = 16_000\n",
    "SEED = 42  # fixed seed so the sampled point set is reproducible across Restart & Run All\n",
    "\n",
    "# Integer grid: lon in [-180, 180), lat in [-90, 90) (int_range excludes the end value).\n",
    "lons = pl.int_range(-180, 180, eager=True)\n",
    "lats = pl.int_range(-90, 90, eager=True)\n",
    "\n",
    "# Sample N points with replacement (duplicates allowed). lon and lat get different seeds so\n",
    "# the two columns are not drawn from the same random stream.\n",
    "lons_use = lons.sample(N, with_replacement=True, seed=SEED).alias(\"lon\")\n",
    "lats_use = lats.sample(N, with_replacement=True, seed=SEED + 1).alias(\"lat\")\n",
    "\n",
    "df = pl.DataFrame([lons_use, lats_use])\n",
    "print(df.shape)\n",
    "print(df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "875f2a67-49fe-476f-add1-b1d76c6cd8f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One Record per sampled row; each row dict's keys (the \"lon\"/\"lat\" column names) become keyword arguments.\n",
    "records = [Record(**row) for row in df.iter_rows(named=True)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1e883e5a-5086-4c29-aff2-d308874eae16",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 151 ms, sys: 360 ms, total: 511 ms\n",
      "Wall time: 57.3 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Build the KD-tree once over all sampled records; later cells query it via kt.query.\n",
    "kt = KDTree(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "69022ad1-5ec8-4a09-836c-273ef452451f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "203 μs ± 4.56 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# Each loop draws a fresh query point (integer grid cell plus a randnum() offset) and runs one\n",
    "# KD-tree query, so the timing includes the point generation.\n",
    "query_lon = random.choice(range(-179, 180)) + randnum()\n",
    "query_lat = random.choice(range(-89, 90)) + randnum()\n",
    "kt.query(Record(query_lon, query_lat))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "28031966-c7d0-4201-a467-37590118e851",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.87 ms ± 188 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# Brute-force baseline: same random query-point generation, then distances to every record and argmin.\n",
    "query_lon = random.choice(range(-179, 180)) + randnum()\n",
    "query_lat = random.choice(range(-89, 90)) + randnum()\n",
    "query = Record(query_lon, query_lat)\n",
    "np.argmin([query.distance(rec) for rec in records])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0d10b2ba-57b2-475c-9d01-135363423990",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 17.4 s, sys: 147 ms, total: 17.6 s\n",
      "Wall time: 17.6 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "n_samples = 1000\n",
    "tol = 1e-8\n",
    "# Reseed both generators used below: the %%timeit cells above consume a variable number of\n",
    "# draws, so without this the test points (and any failure) would change from run to run.\n",
    "test_seed = 12345\n",
    "random.seed(test_seed)\n",
    "np.random.seed(test_seed)\n",
    "test_records = [Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum()) for _ in range(n_samples)]\n",
    "\n",
    "# kt.query(r) results are indexed as [0][0] -> nearest record and [1] -> its distance.\n",
    "kd_res = [kt.query(r) for r in test_records]\n",
    "kd_recs = [res[0][0] for res in kd_res]\n",
    "kd_dists = [res[1] for res in kd_res]\n",
    "\n",
    "# Brute-force reference: scan every record once per query and take both the nearest record and\n",
    "# its distance from that single pass (separate argmin() and min() calls would scan twice).\n",
    "tr_recs = []\n",
    "tr_dists = []\n",
    "for r in test_records:\n",
    "    dists = [r.distance(p) for p in records]\n",
    "    i_min = int(np.argmin(dists))\n",
    "    tr_recs.append(records[i_min])\n",
    "    tr_dists.append(dists[i_min])\n",
    "\n",
    "# `not (... < tol)` so that a NaN distance is also reported as a mismatch.\n",
    "mismatched = [i for i, (k, t) in enumerate(zip(kd_dists, tr_dists)) if not abs(k - t) < tol]\n",
    "assert not mismatched, f\"NOT MATCHING? {len(mismatched)} of {n_samples} queries differ (first at index {mismatched[0]})\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a6aa6926-7fd5-4fff-bd20-7bc0305b948d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (0, 8)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>test_lon</th><th>test_lat</th><th>kd_dist</th><th>kd_lon</th><th>kd_lat</th><th>tr_dist</th><th>tr_lon</th><th>tr_lat</th></tr><tr><td>f64</td><td>f64</td><td>f64</td><td>i64</td><td>i64</td><td>f64</td><td>i64</td><td>i64</td></tr></thead><tbody></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (0, 8)\n",
       "┌──────────┬──────────┬─────────┬────────┬────────┬─────────┬────────┬────────┐\n",
       "│ test_lon ┆ test_lat ┆ kd_dist ┆ kd_lon ┆ kd_lat ┆ tr_dist ┆ tr_lon ┆ tr_lat │\n",
       "│ ---      ┆ ---      ┆ ---     ┆ ---    ┆ ---    ┆ ---     ┆ ---    ┆ ---    │\n",
       "│ f64      ┆ f64      ┆ f64     ┆ i64    ┆ i64    ┆ f64     ┆ i64    ┆ i64    │\n",
       "╞══════════╪══════════╪═════════╪════════╪════════╪═════════╪════════╪════════╡\n",
       "└──────────┴──────────┴─────────┴────────┴────────┴─────────┴────────┴────────┘"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tabulate every test query whose KD-tree and brute-force distances differ by at least `tol`;\n",
    "# an empty frame means no such disagreement was found.\n",
    "# Stored under a new name so the sampled-points frame `df` (which `records` is built from) is not overwritten.\n",
    "mismatches = pl.DataFrame({\n",
    "    \"test_lon\": [r.lon for r in test_records],\n",
    "    \"test_lat\": [r.lat for r in test_records],\n",
    "    \"kd_dist\": kd_dists,\n",
    "    \"kd_lon\": [r.lon for r in kd_recs],\n",
    "    \"kd_lat\": [r.lat for r in kd_recs],\n",
    "    \"tr_dist\": tr_dists,\n",
    "    \"tr_lon\": [r.lon for r in tr_recs],\n",
    "    \"tr_lat\": [r.lat for r in tr_recs],\n",
    "}).filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n",
    "mismatches"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "GeoSpatialTools",
   "language": "python",
   "name": "geospatialtools"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}