{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c70ce35d-6112-4c12-9387-9c788c84a8e9",
   "metadata": {},
   "source": [
    "## OctTree!\n",
    "\n",
    "Testing the time to look-up nearby records with the PyCOADS OctTree implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c0956916-f50a-444d-a5b6-f06d3fb9b44d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import gzip\n",
    "os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n",
    "\n",
    "from datetime import datetime, timedelta\n",
    "from random import choice\n",
    "from string import ascii_letters, digits\n",
    "import random\n",
    "import inspect\n",
    "\n",
    "import polars as pl\n",
    "import numpy as np\n",
    "\n",
    "from GeoSpatialTools.octtree import OctTree, SpaceTimeRecord as Record, SpaceTimeRectangle as Rectangle"
   ]
  },
  {
   "cell_type": "raw",
   "id": "99295bad-0db3-444b-8d38-acc7875cc0f0",
   "metadata": {},
   "source": [
    "## Generate Data\n",
    "\n",
    "16,000 rows of data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d8f1e5e1-513c-4bdf-a9f9-cef9562a7cb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_uid(n: int) -> str:\n",
    "    chars = ascii_letters + digits\n",
    "    return \"\".join(random.choice(chars) for _ in range(n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "986d9cc5-e610-449a-9ee7-e281b7558ca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 16_000\n",
    "lons = pl.int_range(-180, 180, eager=True)\n",
    "lats = pl.int_range(-90, 90, eager=True)\n",
    "dates = pl.datetime_range(datetime(1900, 1, 1, 0), datetime(1900, 1, 31, 23), interval=\"1h\", eager=True)\n",
    "\n",
    "lons_use = lons.sample(N, with_replacement=True).alias(\"lon\")\n",
    "lats_use = lats.sample(N, with_replacement=True).alias(\"lat\")\n",
    "dates_use = dates.sample(N, with_replacement=True).alias(\"datetime\")\n",
    "uids = pl.Series(\"uid\", [generate_uid(8) for _ in range(N)])\n",
    "\n",
    "df = pl.DataFrame([lons_use, lats_use, dates_use, uids]).unique()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "237096f1-093e-49f0-9a9a-2bec5231726f",
   "metadata": {},
   "source": [
    "## Add extra rows\n",
    "\n",
    "For testing larger datasets. Uncomment to use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0b8fd425-8a90-4f76-91b7-60df48aa98e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# _df = df.clone()\n",
    "# for i in range(100):\n",
    "#     df2 = pl.DataFrame([\n",
    "#         _df[\"lon\"].shuffle(),\n",
    "#         _df[\"lat\"].shuffle(),\n",
    "#         _df[\"datetime\"].shuffle(),\n",
    "#         _df[\"uid\"].shuffle(),\n",
    "#     ]).with_columns(pl.concat_str([pl.col(\"uid\"), pl.lit(f\"{i:03d}\")]).alias(\"uid\"))\n",
    "#     df = df.vstack(df2)\n",
    "# df.shape\n",
    "# df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7bd16e0-96a6-426b-b00a-7c3b8a2aaddd",
   "metadata": {},
   "source": [
    "## Intialise the OctTree Object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "af06a976-ff52-49e0-a886-91bcbe540ffe",
   "metadata": {},
   "outputs": [],
   "source": [
    "otree = OctTree(Rectangle(0, 0, datetime(1900, 1, 16), 360, 180, timedelta(days=32)), capacity = 10, max_depth = 25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2ba99b37-787c-4862-8075-a7596208c60e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 186 ms, sys: 191 ms, total: 377 ms\n",
      "Wall time: 118 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "for r in df.rows():\n",
    "    otree.insert(Record(*r))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "59d38446-f7d2-4eec-bba3-c39bd7279623",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OctTree:\n",
      "- boundary: SpaceTimeRectangle(x = 0, y = 0, w = 360, h = 180, t = 1900-01-16 00:00:00, dt = 32 days, 0:00:00)\n",
      "- capacity: 10\n",
      "- depth: 0\n",
      "- max_depth: 25\n",
      "- contents:\n",
      "- number of elements: 10\n",
      "  * Record(x = 43, y = -68, datetime = 1900-01-08 13:00:00, uid = OBiqSYcn)\n",
      "  * Record(x = 97, y = -47, datetime = 1900-01-02 14:00:00, uid = w589k3Oe)\n",
      "  * Record(x = -68, y = 44, datetime = 1900-01-30 11:00:00, uid = XAaA7McU)\n",
      "  * Record(x = -170, y = 77, datetime = 1900-01-19 09:00:00, uid = x6eLi65N)\n",
      "  * Record(x = -2, y = 7, datetime = 1900-01-12 09:00:00, uid = CjB2Pglt)\n",
      "  * Record(x = -175, y = 65, datetime = 1900-01-15 01:00:00, uid = bTB9DkDI)\n",
      "  * Record(x = 8, y = 83, datetime = 1900-01-04 10:00:00, uid = aYCKIBl9)\n",
      "  * Record(x = 20, y = 60, datetime = 1900-01-24 16:00:00, uid = 8GsD19WF)\n",
      "  * Record(x = 161, y = 40, datetime = 1900-01-24 20:00:00, uid = FIfAABuC)\n",
      "  * Record(x = -69, y = -9, datetime = 1900-01-11 11:00:00, uid = uTcS5D4e)\n",
      "- with children:\n",
      "    OctTree:\n",
      "    - boundary: SpaceTimeRectangle(x = -90.0, y = 45.0, w = 180.0, h = 90.0, t = 1900-01-08 00:00:00, dt = 16 days, 0:00:00)\n",
      "    - capacity: 10\n",
      "    - depth: 1\n",
      "    - max_depth: 25\n",
      "    - contents:\n",
      "    - number of elements: 10\n",
      "      * Record(x = -156, y = 57, datetime = 1900-01-08 10:00:00, uid = aFheRU2n)\n",
      "      * Record(x = -100, y = 61, datetime = 1900-01-15 09:00:00, uid = Sa1iavle)\n",
      "      * Record(x = -168, y = 88, datetime = 1900-01-03 07:00:00, uid = IlYKGW0N)\n",
      "      * Record(x = -80, y = 50, datetime = 1900-01-05 09:00:00, uid = Rg3GHM4d)\n",
      "      * Record(x = -92, y = 39, datetime = 1900-01-15 06:00:00, uid = u804YMFB)\n",
      "      * Record(x = -119, y = 60, datetime = 1900-01-12 22:00:00, uid = vdEPjkib)\n",
      "      * Record(x = -160, y = 79, datetime = 1900-01-06 08:00:00, uid = QmrPEL6h)\n",
      "      * Record(x = -95, y = 21, datetime = 1900-01-09 04:00:00, uid = hfjTKSCH)\n",
      "      * Record(x = -93, y = 61, datetime = 1900-01-09 20:00:00, uid = SzIrja9S)\n",
      "      * Record(x = -149, y = 34, datetime = 1900-01-05 05:00:00, uid = b02MxQjV)\n",
      "    - with children:\n",
      "        OctTree:\n",
      "        - boundary: SpaceTimeRectangle(x = -135.0, y = 67.5, w = 90.0, h = 45.0, t = 1900-01-04 00:00:00, dt = 8 days, 0:00:00)\n",
      "        - capacity: 10\n",
      "        - depth: 2\n",
      "        - max_depth: 25\n",
      "        - contents:\n",
      "        - number of elements: 10\n",
      "          * Record(x = -134, y = 79, datetime = 1900-01-05 14:00:00, uid = 7Q0FKGMk)\n",
      "          * Record(x = -90, y = 53, datetime = 1900-01-05 03:00:00, uid = LLx7iz2v)\n",
      "          * Record(x = -176, y = 50, datetime = 1900-01-06 20:00:00, uid = x6K5DlTl)\n",
      "          * Record(x = -141, y = 52, datetime = 1900-01-02 15:00:00, uid = xTpGPaEy)\n",
      "          * Record(x = -116, y = 68, datetime = 1900-01-05 16:00:00, uid = eECSkpdU)\n",
      "          * Record(x = -138, y = 63, datetime = 1900-01-05 02:00:00, uid = Ftf9uhH3)\n",
      "          * Record(x = -173, y = 71, datetime = 1900-01-03 03:00:00, uid = mu3vwHM5)\n",
      "          * Record(x = -148, y = 49, datetime = 1900-01-05 15:00:00, uid = 8DFDI3CJ)\n",
      "          * Record(x = -157, y = 63, datetime = 1900-01-06 19:00:00, uid = mVqLntgh)\n",
      "          * Record(x = -154, y = 45, datetime = 1900-01-07 11:00:00, uid = 1UoA1NNC)\n",
      "        - with children:\n",
      "            OctTree:\n",
      "            - boundary: SpaceTimeRectangle(x = -157.5, y = 78.75, w = 45.0, h = 22.5, t = 1900-01-02 00:00:00, dt = 4 days, 0:00:00)\n",
      "            - capacity: 10\n",
      "            - depth: 3\n",
      "            - max_depth: 25\n",
      "            - contents:\n",
      "            - number of elements: 10\n",
      "              * Record(x = -147, y = 83, datetime = 1900-01-01 18:00:00, uid = WaO5R7fy)\n",
      "              * Record(x = -136, y = 72, datetime = 1900-01-02 03:00:00, uid = OWaMqULr)\n",
      "              * Record(x = -176, y = 79, datetime = 1900-01-02 06:00:00, uid = NTjvqz2c)\n",
      "              * Record(x = -152, y = 72, datetime = 1900-01-03 18:00:00, uid = 7rtQIGtn)\n",
      "              * Record(x = -162, y = 78, datetime = 1900-01-02 04:00:00, uid = Wi9RsOIX)\n",
      "              * Record(x = -136, y = 79, datetime = 1900-01-01 11:00:00, uid = hSltzeuH)\n",
      "              * Record(x = -176, y = 89, datetime = 1900-01-02 09:00:00, uid = cOLgAely)\n",
      "              * Record(x = -141, y = 75, datetime = 1900-01-03 23:00:00, uid = gH755dC3)\n",
      "              * Record(x = -158, y = 72, datetime = 1900-01-02 23:00:00, uid = NUmMfw9K)\n",
      "              * Record(x = -168, y = 72, datetime = 1900-01-02 01:00:00, uid = ZFcsxYG4)\n",
      "            - with children:\n",
      "                OctTree:\n",
      "                - boundary: SpaceTimeRectangle(x = -168.75, y = 84.375, w = 22.5, h = 11.25, t = 1900-01-01 00:00:00, dt = 2 days, 0:00:00)\n",
      "                - capacity: 10\n",
      "                - depth: 4\n",
      "                - max_depth: 25\n",
      "                - contents:\n",
      "                - number of elements: 6\n",
      "                  * Record(x = -158, y = 86, datetime = 1900-01-01 15:00:00, uid = DOD5jT2l)\n",
      "                  * Record(x = -165, y = 88, datetime = 1900-01-01 13:00:00, uid = kdGlzz41)\n",
      "                  * Record(x = -173, y = 82, datetime = 1900-01-01 04:00:00, uid = aWBwIP4U)\n",
      "                  * Record(x = -180, y = 89, datetime = 1900-01-01 22:00:00, uid = HOxbaCm8)\n",
      "                  * Record(x = -165, y = 81, datetime = 1900-01-01 16:00:00, uid = JtRn9y9e)\n",
      "                  * Record(x = -164, y = 84, datetime = 1900-01-01 03:00:00, uid = vELpx1ij)\n",
      "                OctTree:\n",
      "                - boundary: SpaceTimeRectangle(x = -146.25, y = 84.375, w = 22.5, h = 11.25, t = 1900-01-01 00:00:00, dt = 2 days, 0:00:00)\n",
      "                - capacity: 10\n",
      "                - depth: 4\n",
      "                - max_depth: 25\n",
      "                - contents:\n",
      "                - number of elements: 1\n",
      "                  * Record(x = -157, y = 84, datetime = 1900-01-01 17:00:00, uid = 6DlgVOXg)\n",
      "                OctTree:\n",
      "                - boundary: SpaceTimeRectangle(x = -168.75, y = 73.125, w = 22.5, h = 11.25, t = 1900-01-01 00:00:00, dt = 2 days, 0:00:00)\n",
      "                - capacity: 10\n",
      "                - depth: 4\n",
      "                - max_depth: 25\n",
      "                - contents:\n",
      "                - number of elements: 2\n"
     ]
    }
   ],
   "source": [
    "s = str(otree)\n",
    "print(\"\\n\".join(s.split(\"\\n\")[:100]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b02c2ea-6566-47c2-97e0-43d8b18e0713",
   "metadata": {},
   "source": [
    "## Time Execution\n",
    "\n",
    "Testing the identification of nearby points against the original full search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "094b588c-e938-4838-9719-1defdfff74fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "dts = pl.datetime_range(datetime(1900, 1, 1), datetime(1900, 2, 1), interval=\"1h\", eager=True, closed=\"left\")\n",
    "N = dts.len()\n",
    "lons = 180 - 360 * np.random.rand(N)\n",
    "lats = 90 -  180 * np.random.rand(N)\n",
    "test_df = pl.DataFrame({\"lon\": lons, \"lat\": lats, \"datetime\": dts})\n",
    "test_recs = [Record(*r) for r in test_df.rows()]\n",
    "dt = timedelta(days = 1)\n",
    "dist = 350"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "66a48b86-d449-45d2-9837-2b3e07f5563d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "206 μs ± 3.36 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "otree.nearby_points(random.choice(test_recs), dist=dist, t_dist=dt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "972d4a16-39fd-4f80-8592-1c5d5cabf5be",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def check_cols(\n",
    "    df: pl.DataFrame | pl.LazyFrame,\n",
    "    cols: list[str],\n",
    "    var_name: str = \"dataframe\",\n",
    ") -> None:\n",
    "    \"\"\"\n",
    "    Check that a dataframe contains a list of columns. Raises an error if not.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : polars Frame\n",
    "        Dataframe to check\n",
    "    cols : list[str]\n",
    "        Requried columns\n",
    "    var_name : str\n",
    "        Name of the Frame - used for displaying in any error.\n",
    "    \"\"\"\n",
    "    calling_func = inspect.stack()[1][3]\n",
    "    if isinstance(df, pl.DataFrame):\n",
    "        have_cols = df.columns\n",
    "    elif isinstance(df, pl.LazyFrame):\n",
    "        have_cols = df.collect_schema().names()\n",
    "    else:\n",
    "        raise TypeError(\"Input Frame is not a polars Frame\")\n",
    "\n",
    "    cols_in_frame = intersect(cols, have_cols)\n",
    "    missing = [c for c in cols if c not in cols_in_frame]\n",
    "\n",
    "    if len(missing) > 0:\n",
    "        err_str = f\"({calling_func}) - {var_name} missing required columns. \"\n",
    "        err_str += f'Require: {\", \".join(cols)}. '\n",
    "        err_str += f'Missing: {\", \".join(missing)}.'\n",
    "        logging.error(err_str)\n",
    "        raise ValueError(err_str)\n",
    "\n",
    "    return\n",
    "\n",
    "\n",
    "def haversine_df(\n",
    "    df: pl.DataFrame | pl.LazyFrame,\n",
    "    date_var: str = \"datetime\",\n",
    "    R: float = 6371,\n",
    "    reverse: bool = False,\n",
    "    out_colname: str = \"dist\",\n",
    "    lon_col: str = \"lon\",\n",
    "    lat_col: str = \"lat\",\n",
    "    lon2_col: str | None = None,\n",
    "    lat2_col: str | None = None,\n",
    "    sorted: bool = False,\n",
    "    rev_prefix: str = \"rev_\",\n",
    ") -> pl.DataFrame | pl.LazyFrame:\n",
    "    \"\"\"\n",
    "    Compute haversine distance on earth surface between lon-lat positions.\n",
    "\n",
    "    If only 'lon_col' and 'lat_col' are specified then this computes the\n",
    "    distance between consecutive points. If a second set of positions is\n",
    "    included via the optional 'lon2_col' and 'lat2_col' arguments then the\n",
    "    distances between the columns are computed.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : polars.DataFrame\n",
    "        The data, containing required columns:\n",
    "            * lon_col\n",
    "            * lat_col\n",
    "            * date_var\n",
    "    date_var : str\n",
    "        Name of the datetime column on which to sort the positions\n",
    "    R : float\n",
    "        Radius of earth in km\n",
    "    reverse : bool\n",
    "        Compute distances in reverse\n",
    "    out_colname : str\n",
    "        Name of the output column to store distances. Prefixed with 'rev_' if\n",
    "        reverse is True\n",
    "    lon_col : str\n",
    "        Name of the longitude column\n",
    "    lat_col : str\n",
    "        Name of the latitude column\n",
    "    lon2_col : str\n",
    "        Name of the 2nd longitude column if present\n",
    "    lat2_col : str\n",
    "        Name of the 2nd latitude column if present\n",
    "    sorted : bool\n",
    "        Compute distances assuming that the frame is already sorted\n",
    "    rev_prefix : str\n",
    "        Prefix to use for colnames if reverse is True\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    polars.DataFrame\n",
    "        With additional column specifying distances between consecutive points\n",
    "        in the same units as 'R'. With colname defined by 'out_colname'.\n",
    "    \"\"\"\n",
    "    required_cols = [lon_col, lat_col]\n",
    "\n",
    "    if lon2_col is not None and lat2_col is not None:\n",
    "        required_cols += [lon2_col, lat2_col]\n",
    "        check_cols(df, required_cols, \"df\")\n",
    "        return (\n",
    "            df.with_columns(\n",
    "                [\n",
    "                    pl.col(lat_col).radians().alias(\"_lat0\"),\n",
    "                    pl.col(lat2_col).radians().alias(\"_lat1\"),\n",
    "                    (pl.col(lon_col) - pl.col(lon2_col))\n",
    "                    .radians()\n",
    "                    .alias(\"_dlon\"),\n",
    "                    (pl.col(lat_col) - pl.col(lat2_col))\n",
    "                    .radians()\n",
    "                    .alias(\"_dlat\"),\n",
    "                ]\n",
    "            )\n",
    "            .with_columns(\n",
    "                (\n",
    "                    (pl.col(\"_dlat\") / 2).sin().pow(2)\n",
    "                    + pl.col(\"_lat0\").cos()\n",
    "                    * pl.col(\"_lat1\").cos()\n",
    "                    * (pl.col(\"_dlon\") / 2).sin().pow(2)\n",
    "                ).alias(\"_a\")\n",
    "            )\n",
    "            .with_columns(\n",
    "                (2 * R * (pl.col(\"_a\").sqrt().arcsin()))\n",
    "                .round(2)\n",
    "                .alias(out_colname)\n",
    "            )\n",
    "            .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n",
    "        )\n",
    "\n",
    "    if lon2_col is not None or lat2_col is not None:\n",
    "        logging.warning(\n",
    "            \"(haversine_df) 2nd position incorrectly specified. \"\n",
    "            + \"Calculating consecutive distances.\"\n",
    "        )\n",
    "\n",
    "    required_cols += [date_var]\n",
    "    check_cols(df, required_cols, \"df\")\n",
    "    if reverse:\n",
    "        out_colname = rev_prefix + out_colname\n",
    "    if not sorted:\n",
    "        df = df.sort(date_var, descending=reverse)\n",
    "    return (\n",
    "        df.with_columns(\n",
    "            [\n",
    "                pl.col(lat_col).radians().alias(\"_lat0\"),\n",
    "                pl.col(lat_col).shift(n=-1).radians().alias(\"_lat1\"),\n",
    "                (pl.col(lon_col).shift(n=-1) - pl.col(lon_col))\n",
    "                .radians()\n",
    "                .alias(\"_dlon\"),\n",
    "                (pl.col(lat_col).shift(n=-1) - pl.col(lat_col))\n",
    "                .radians()\n",
    "                .alias(\"_dlat\"),\n",
    "            ]\n",
    "        )\n",
    "        .with_columns(\n",
    "            (\n",
    "                (pl.col(\"_dlat\") / 2).sin().pow(2)\n",
    "                + pl.col(\"_lat0\").cos()\n",
    "                * pl.col(\"_lat1\").cos()\n",
    "                * (pl.col(\"_dlon\") / 2).sin().pow(2)\n",
    "            ).alias(\"_a\")\n",
    "        )\n",
    "        .with_columns(\n",
    "            (2 * R * (pl.col(\"_a\").sqrt().arcsin()))\n",
    "            .round(2)\n",
    "            .fill_null(strategy=\"forward\")\n",
    "            .alias(out_colname)\n",
    "        )\n",
    "        .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n",
    "    )\n",
    "\n",
    "def intersect(a, b) -> set:\n",
    "    return set(a) & set(b)\n",
    "\n",
    "def nearby_ships(\n",
    "    lon: float,\n",
    "    lat: float,\n",
    "    pool: pl.DataFrame,\n",
    "    max_dist: float,\n",
    "    lon_col: str = \"lon\",\n",
    "    lat_col: str = \"lat\",\n",
    "    dt: datetime | None = None,\n",
    "    date_col: str | None = None,\n",
    "    dt_gap: timedelta | None = None,\n",
    "    filter_datetime: bool = False,\n",
    ") -> pl.DataFrame:\n",
    "    \"\"\"\n",
    "    Find observations nearby to a position in space (and optionally time).\n",
    "\n",
    "    Get a frame of all records that are within a maximum distance of the\n",
    "    provided point.\n",
    "\n",
    "    If filter_datetime is True, then only records from the same datetime will\n",
    "    be returned. If a specific filter is desired this should be performed\n",
    "    before calling this function and set filter_datetime to False.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    lon : float\n",
    "        The longitude of the position.\n",
    "    lat : float\n",
    "        The latitude of the position.\n",
    "    pool : polars.DataFrame\n",
    "        The pool of records to search. Can be pre-filtered and filter_datetime\n",
    "        set to False.\n",
    "    max_dist : float\n",
    "        Will return records that have distance to the point <= this value.\n",
    "    lon_col : str\n",
    "        Name of the longitude column in the pool DataFrame\n",
    "    lat_col : str\n",
    "        Name of the latitude column in the pool DataFrame\n",
    "    dt : datetime | None\n",
    "        Datetime of the record. Must be set if filter_datetime is True.\n",
    "    date_col : str | None\n",
    "        Name of the datetime column in the pool. Must be set if filter_datetime\n",
    "        is True.\n",
    "    dt_gap : timedelta | None\n",
    "        Allowed time-gap for records. Records that fall between\n",
    "        dt - dt_gap and dt + dt_gap will be returned. If not set then only\n",
    "        records at dt will be returned. Applies if filter_datetime is True.\n",
    "    filter_datetime : bool\n",
    "        Only return records at the same datetime record as the input value. If\n",
    "        assessing multiple points with different datetimes, hence calling this\n",
    "        function frequently it will be more efficient to partition the pool\n",
    "        first, then set this value to False and only input the subset of data.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    polars.DataFrame\n",
    "        Containing only records from the pool within max_dist of the input\n",
    "        point, optionally at the same datetime if filter_datetime is True.\n",
    "    \"\"\"\n",
    "    required_cols = [lon_col, lat_col]\n",
    "    check_cols(pool, required_cols, \"pool\")\n",
    "\n",
    "    if filter_datetime:\n",
    "        if not dt or not date_col:\n",
    "            raise ValueError(\n",
    "                \"'dt' and 'date_col' must be provided if 'filter_datetime' \"\n",
    "                + \"is True\"\n",
    "            )\n",
    "        if date_col not in pool.columns:\n",
    "            raise ValueError(f\"'date_col' value {date_col} not found in pool.\")\n",
    "        if not dt_gap:\n",
    "            pool = pool.filter(pl.col(date_col).eq(dt))\n",
    "        else:\n",
    "            pool = pool.filter(\n",
    "                pl.col(date_col).is_between(\n",
    "                    dt - dt_gap, dt + dt_gap, closed=\"both\"\n",
    "                )\n",
    "            )\n",
    "\n",
    "    return (\n",
    "        pool.with_columns(\n",
    "            [pl.lit(lon).alias(\"_lon\"), pl.lit(lat).alias(\"_lat\")]\n",
    "        )\n",
    "        .pipe(\n",
    "            haversine_df,\n",
    "            lon_col=lon_col,\n",
    "            lat_col=lat_col,\n",
    "            out_colname=\"_dist\",\n",
    "            lon2_col=\"_lon\",\n",
    "            lat2_col=\"_lat\",\n",
    "        )\n",
    "        .filter(pl.col(\"_dist\").le(max_dist))\n",
    "        .drop([\"_dist\", \"_lon\", \"_lat\"])\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8b9279ed-6f89-4423-8833-acd0b365eb7b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.33 ms ± 20.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "rec = random.choice(test_recs)\n",
    "nearby_ships(lon=rec.lon, lat=rec.lat, dt=rec.datetime, max_dist=dist, dt_gap=dt, date_col=\"datetime\", pool=df, filter_datetime=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d148f129-9d8c-4c46-8f01-3e9c1e93e81a",
   "metadata": {},
   "source": [
    "## Verify\n",
    "\n",
    "Check that records are the same"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "11f3d73a-fbe5-4f27-88d8-d0d687bd0eac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.52 s, sys: 237 ms, total: 2.75 s\n",
      "Wall time: 2.65 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "dist = 250\n",
    "for _ in range(250):\n",
    "    rec = Record(*random.choice(df.rows()))\n",
    "    orig = nearby_ships(lon=rec.lon, lat=rec.lat, dt=rec.datetime, max_dist=dist, dt_gap=dt, date_col=\"datetime\", pool=df, filter_datetime=True)\n",
    "    tree = otree.nearby_points(rec, dist=dist, t_dist=dt)\n",
    "    if orig.height > 0:\n",
    "        if not tree:\n",
    "            print(rec)\n",
    "            print(\"NO TREE!\")\n",
    "            print(f\"{orig = }\")\n",
    "        else:\n",
    "            tree = pl.from_records([(r.lon, r.lat, r.datetime, r.uid) for r in tree], orient=\"row\").rename({\"column_0\": \"lon\", \"column_1\": \"lat\", \"column_2\": \"datetime\", \"column_3\": \"uid\"})\n",
    "            if tree.height != orig.height:\n",
    "                print(\"Tree and Orig Heights Do Not Match\")\n",
    "                print(f\"{orig = }\")\n",
    "                print(f\"{tree = }\")\n",
    "            else:\n",
    "                # tree = tree.with_columns(pl.col(\"uid\").str.slice(0, 6))\n",
    "                if not tree.sort(\"uid\").equals(orig.sort(\"uid\")):\n",
    "                    print(\"Tree and Orig Do Not Match\")\n",
    "                    print(f\"{orig = }\")\n",
    "                    print(f\"{tree = }\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1223529e-bfae-4b83-aba7-505d05e588d3",
   "metadata": {},
   "source": [
    "## Check -180/180 boundary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4c392292-2d9f-4301-afb5-019fde069a1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = otree.nearby_points(Record(179.5, -43.1, datetime(1900, 1, 14, 13)), dist=200, t_dist=timedelta(days=3))\n",
    "for o in out:\n",
    "    print(o)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "GeoSpatialTools",
   "language": "python",
   "name": "geospatialtools"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}