## OctTree!

Testing the time to look-up nearby records with the PyCOADS OctTree implementation.

In [1]:
import os
import gzip
os.environ["POLARS_MAX_THREADS"] = "4"

from datetime import datetime, timedelta
from random import choice
from string import ascii_letters, digits
import random
import inspect

import polars as pl
import numpy as np

from GeoSpatialTools.octtree import OctTree, SpaceTimeRecord as Record, SpaceTimeRectangle as Rectangle

In [2]:
def generate_uid(n: int) -> str:
 chars = ascii_letters + digits
 return "".join(random.choice(chars) for _ in range(n))

In [3]:
N = 16_000
lons = pl.int_range(-180, 180, eager=True)
lats = pl.int_range(-90, 90, eager=True)
dates = pl.datetime_range(datetime(1900, 1, 1, 0), datetime(1900, 1, 31, 23), interval="1h", eager=True)

lons_use = lons.sample(N, with_replacement=True).alias("lon")
lats_use = lats.sample(N, with_replacement=True).alias("lat")
dates_use = dates.sample(N, with_replacement=True).alias("datetime")
uids = pl.Series("uid", [generate_uid(8) for _ in range(N)])

df = pl.DataFrame([lons_use, lats_use, dates_use, uids]).unique()

## Add extra rows

For testing larger datasets. Uncomment to use.

In [4]:
# _df = df.clone()
# for i in range(100):
# df2 = pl.DataFrame([
# _df["lon"].shuffle(),
# _df["lat"].shuffle(),
# _df["datetime"].shuffle(),
# _df["uid"].shuffle(),
# ]).with_columns(pl.concat_str([pl.col("uid"), pl.lit(f"{i:03d}")]).alias("uid"))
# df = df.vstack(df2)
# df.shape
# df

## Intialise the OctTree Object

In [5]:
otree = OctTree(Rectangle(-180, 180, -90, 90, datetime(1900, 1, 1, 0), datetime(1900, 1, 31, 23)), capacity = 10, max_depth = 25)

In [6]:
%%time
for r in df.rows():
 otree.insert(Record(*r))

CPU times: user 106 ms, sys: 3.98 ms, total: 110 ms
Wall time: 109 ms


In [7]:
s = str(otree)
print("\n".join(s.split("\n")[:100]))

OctTree:
- boundary: SpaceTimeRectangle(west=-180, east=180, south=-90, north=90, start=datetime.datetime(1900, 1, 1, 0, 0), end=datetime.datetime(1900, 1, 31, 23, 0))
- capacity: 10
- depth: 0
- max_depth: 25
- contents:
- number of elements: 10
 * SpaceTimeRecord(x = 92, y = 15, datetime = 1900-01-17 08:00:00, uid = HRF401hH)
 * SpaceTimeRecord(x = -35, y = 37, datetime = 1900-01-04 08:00:00, uid = CXZaSOdh)
 * SpaceTimeRecord(x = 84, y = -7, datetime = 1900-01-07 16:00:00, uid = 2aEjxGwG)
 * SpaceTimeRecord(x = 68, y = 73, datetime = 1900-01-18 17:00:00, uid = Ah7lanWB)
 * SpaceTimeRecord(x = -179, y = 40, datetime = 1900-01-01 11:00:00, uid = HGxSJzf4)
 * SpaceTimeRecord(x = -73, y = 23, datetime = 1900-01-09 12:00:00, uid = qHQ8opO9)
 * SpaceTimeRecord(x = 117, y = -23, datetime = 1900-01-31 06:00:00, uid = ctvs56Fq)
 * SpaceTimeRecord(x = 109, y = 55, datetime = 1900-01-13 14:00:00, uid = C2xXIglD)
 * SpaceTimeRecord(x = 104, y = -10, datetime = 1900-01-06 16:00:00, uid = WEpQKIO

## Time Execution

Testing the identification of nearby points against the original full search

In [8]:
dts = pl.datetime_range(datetime(1900, 1, 1), datetime(1900, 2, 1), interval="1h", eager=True, closed="left")
N = dts.len()
lons = 180 - 360 * np.random.rand(N)
lats = 90 - 180 * np.random.rand(N)
test_df = pl.DataFrame({"lon": lons, "lat": lats, "datetime": dts})
test_recs = [Record(*r) for r in test_df.rows()]
dt = timedelta(days = 1)
dist = 350

In [9]:
%%timeit
otree.nearby_points(random.choice(test_recs), dist=dist, t_dist=dt)

207 μs ± 6.25 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
def check_cols(
 df: pl.DataFrame | pl.LazyFrame,
 cols: list[str],
 var_name: str = "dataframe",
) -> None:
 """
 Check that a dataframe contains a list of columns. Raises an error if not.

 Parameters
 ----------
 df : polars Frame
 Dataframe to check
 cols : list[str]
 Requried columns
 var_name : str
 Name of the Frame - used for displaying in any error.
 """
 calling_func = inspect.stack()[1][3]
 if isinstance(df, pl.DataFrame):
 have_cols = df.columns
 elif isinstance(df, pl.LazyFrame):
 have_cols = df.collect_schema().names()
 else:
 raise TypeError("Input Frame is not a polars Frame")

 cols_in_frame = intersect(cols, have_cols)
 missing = [c for c in cols if c not in cols_in_frame]

 if len(missing) > 0:
 err_str = f"({calling_func}) - {var_name} missing required columns. "
 err_str += f'Require: {", ".join(cols)}. '
 err_str += f'Missing: {", ".join(missing)}.'
 logging.error(err_str)
 raise ValueError(err_str)

 return


def haversine_df(
 df: pl.DataFrame | pl.LazyFrame,
 date_var: str = "datetime",
 R: float = 6371,
 reverse: bool = False,
 out_colname: str = "dist",
 lon_col: str = "lon",
 lat_col: str = "lat",
 lon2_col: str | None = None,
 lat2_col: str | None = None,
 sorted: bool = False,
 rev_prefix: str = "rev_",
) -> pl.DataFrame | pl.LazyFrame:
 """
 Compute haversine distance on earth surface between lon-lat positions.

 If only 'lon_col' and 'lat_col' are specified then this computes the
 distance between consecutive points. If a second set of positions is
 included via the optional 'lon2_col' and 'lat2_col' arguments then the
 distances between the columns are computed.

 Parameters
 ----------
 df : polars.DataFrame
 The data, containing required columns:
 * lon_col
 * lat_col
 * date_var
 date_var : str
 Name of the datetime column on which to sort the positions
 R : float
 Radius of earth in km
 reverse : bool
 Compute distances in reverse
 out_colname : str
 Name of the output column to store distances. Prefixed with 'rev_' if
 reverse is True
 lon_col : str
 Name of the longitude column
 lat_col : str
 Name of the latitude column
 lon2_col : str
 Name of the 2nd longitude column if present
 lat2_col : str
 Name of the 2nd latitude column if present
 sorted : bool
 Compute distances assuming that the frame is already sorted
 rev_prefix : str
 Prefix to use for colnames if reverse is True

 Returns
 -------
 polars.DataFrame
 With additional column specifying distances between consecutive points
 in the same units as 'R'. With colname defined by 'out_colname'.
 """
 required_cols = [lon_col, lat_col]

 if lon2_col is not None and lat2_col is not None:
 required_cols += [lon2_col, lat2_col]
 check_cols(df, required_cols, "df")
 return (
 df.with_columns(
 [
 pl.col(lat_col).radians().alias("_lat0"),
 pl.col(lat2_col).radians().alias("_lat1"),
 (pl.col(lon_col) - pl.col(lon2_col))
 .radians()
 .alias("_dlon"),
 (pl.col(lat_col) - pl.col(lat2_col))
 .radians()
 .alias("_dlat"),
 ]
 )
 .with_columns(
 (
 (pl.col("_dlat") / 2).sin().pow(2)
 + pl.col("_lat0").cos()
 * pl.col("_lat1").cos()
 * (pl.col("_dlon") / 2).sin().pow(2)
 ).alias("_a")
 )
 .with_columns(
 (2 * R * (pl.col("_a").sqrt().arcsin()))
 .round(2)
 .alias(out_colname)
 )
 .drop(["_lat0", "_lat1", "_dlon", "_dlat", "_a"])
 )

 if lon2_col is not None or lat2_col is not None:
 logging.warning(
 "(haversine_df) 2nd position incorrectly specified. "
 + "Calculating consecutive distances."
 )

 required_cols += [date_var]
 check_cols(df, required_cols, "df")
 if reverse:
 out_colname = rev_prefix + out_colname
 if not sorted:
 df = df.sort(date_var, descending=reverse)
 return (
 df.with_columns(
 [
 pl.col(lat_col).radians().alias("_lat0"),
 pl.col(lat_col).shift(n=-1).radians().alias("_lat1"),
 (pl.col(lon_col).shift(n=-1) - pl.col(lon_col))
 .radians()
 .alias("_dlon"),
 (pl.col(lat_col).shift(n=-1) - pl.col(lat_col))
 .radians()
 .alias("_dlat"),
 ]
 )
 .with_columns(
 (
 (pl.col("_dlat") / 2).sin().pow(2)
 + pl.col("_lat0").cos()
 * pl.col("_lat1").cos()
 * (pl.col("_dlon") / 2).sin().pow(2)
 ).alias("_a")
 )
 .with_columns(
 (2 * R * (pl.col("_a").sqrt().arcsin()))
 .round(2)
 .fill_null(strategy="forward")
 .alias(out_colname)
 )
 .drop(["_lat0", "_lat1", "_dlon", "_dlat", "_a"])
 )

def intersect(a, b) -> set:
 return set(a) & set(b)

def nearby_ships(
 lon: float,
 lat: float,
 pool: pl.DataFrame,
 max_dist: float,
 lon_col: str = "lon",
 lat_col: str = "lat",
 dt: datetime | None = None,
 date_col: str | None = None,
 dt_gap: timedelta | None = None,
 filter_datetime: bool = False,
) -> pl.DataFrame:
 """
 Find observations nearby to a position in space (and optionally time).

 Get a frame of all records that are within a maximum distance of the
 provided point.

 If filter_datetime is True, then only records from the same datetime will
 be returned. If a specific filter is desired this should be performed
 before calling this function and set filter_datetime to False.

 Parameters
 ----------
 lon : float
 The longitude of the position.
 lat : float
 The latitude of the position.
 pool : polars.DataFrame
 The pool of records to search. Can be pre-filtered and filter_datetime
 set to False.
 max_dist : float
 Will return records that have distance to the point <= this value.
 lon_col : str
 Name of the longitude column in the pool DataFrame
 lat_col : str
 Name of the latitude column in the pool DataFrame
 dt : datetime | None
 Datetime of the record. Must be set if filter_datetime is True.
 date_col : str | None
 Name of the datetime column in the pool. Must be set if filter_datetime
 is True.
 dt_gap : timedelta | None
 Allowed time-gap for records. Records that fall between
 dt - dt_gap and dt + dt_gap will be returned. If not set then only
 records at dt will be returned. Applies if filter_datetime is True.
 filter_datetime : bool
 Only return records at the same datetime record as the input value. If
 assessing multiple points with different datetimes, hence calling this
 function frequently it will be more efficient to partition the pool
 first, then set this value to False and only input the subset of data.

 Returns
 -------
 polars.DataFrame
 Containing only records from the pool within max_dist of the input
 point, optionally at the same datetime if filter_datetime is True.
 """
 required_cols = [lon_col, lat_col]
 check_cols(pool, required_cols, "pool")

 if filter_datetime:
 if not dt or not date_col:
 raise ValueError(
 "'dt' and 'date_col' must be provided if 'filter_datetime' "
 + "is True"
 )
 if date_col not in pool.columns:
 raise ValueError(f"'date_col' value {date_col} not found in pool.")
 if not dt_gap:
 pool = pool.filter(pl.col(date_col).eq(dt))
 else:
 pool = pool.filter(
 pl.col(date_col).is_between(
 dt - dt_gap, dt + dt_gap, closed="both"
 )
 )

 return (
 pool.with_columns(
 [pl.lit(lon).alias("_lon"), pl.lit(lat).alias("_lat")]
 )
 .pipe(
 haversine_df,
 lon_col=lon_col,
 lat_col=lat_col,
 out_colname="_dist",
 lon2_col="_lon",
 lat2_col="_lat",
 )
 .filter(pl.col("_dist").le(max_dist))
 .drop(["_dist", "_lon", "_lat"])
 )


In [11]:
%%timeit
rec = random.choice(test_recs)
nearby_ships(lon=rec.lon, lat=rec.lat, dt=rec.datetime, max_dist=dist, dt_gap=dt, date_col="datetime", pool=df, filter_datetime=True)

5.36 ms ± 164 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Verify

Check that records are the same

In [12]:
%%time
dist = 250
for _ in range(250):
 rec = Record(*random.choice(df.rows()))
 orig = nearby_ships(lon=rec.lon, lat=rec.lat, dt=rec.datetime, max_dist=dist, dt_gap=dt, date_col="datetime", pool=df, filter_datetime=True)
 tree = otree.nearby_points(rec, dist=dist, t_dist=dt)
 if orig.height > 0:
 if not tree:
 print(rec)
 print("NO TREE!")
 print(f"{orig = }")
 else:
 tree = pl.from_records([(r.lon, r.lat, r.datetime, r.uid) for r in tree], orient="row").rename({"column_0": "lon", "column_1": "lat", "column_2": "datetime", "column_3": "uid"})
 if tree.height != orig.height:
 print("Tree and Orig Heights Do Not Match")
 print(f"{orig = }")
 print(f"{tree = }")
 else:
 # tree = tree.with_columns(pl.col("uid").str.slice(0, 6))
 if not tree.sort("uid").equals(orig.sort("uid")):
 print("Tree and Orig Do Not Match")
 print(f"{orig = }")
 print(f"{tree = }")

CPU times: user 2.52 s, sys: 253 ms, total: 2.78 s
Wall time: 2.66 s


## Check -180/180 boundary

In [13]:
out = otree.nearby_points(Record(179.5, -43.1, datetime(1900, 1, 14, 13)), dist=200, t_dist=timedelta(days=3))
for o in out:
 print(o)