In [1]:
import os
import gzip
# Cap the Polars thread pool at 4 threads. The variable is deliberately set
# before `import polars` further down.
# NOTE(review): Polars is understood to read POLARS_MAX_THREADS only once, at
# import time — confirm before moving this line.
os.environ["POLARS_MAX_THREADS"] = "4"

from datetime import datetime, timedelta
from random import choice
from string import ascii_letters, digits
import random
import inspect

import polars as pl
import numpy as np

from GeoSpatialTools import Record, haversine, KDTree

In [2]:
def randnum() -> float:
    """Draw one pseudo-random float from the half-open interval [-1, 1).

    A single sample from NumPy's global generator (uniform on [0, 1)) is
    re-centred on zero and then stretched to twice the width.
    """
    centred = np.random.rand() - 0.5
    return 2 * centred

In [3]:
def generate_uid(n: int) -> str:
    """Build a random identifier made of ``n`` ASCII letters and digits.

    Every character is picked independently with ``random.choice``, so
    repeated characters are possible. ``n <= 0`` yields the empty string.
    """
    alphabet = ascii_letters + digits
    picked = []
    for _ in range(n):
        picked.append(random.choice(alphabet))
    return "".join(picked)

In [4]:
N = 16_000
SEED = 42  # fixed so the sampled points are identical on every Restart & Run All

# Integer-degree candidate values. `int_range` excludes its upper bound, so
# longitudes span -180..179 and latitudes span -90..89.
lons = pl.int_range(-180, 180, eager=True)
lats = pl.int_range(-90, 90, eager=True)
# Hourly timestamps from 1900-01-01 00:00 to 1900-01-31 23:00. Only the
# disabled `datetime` column below would consume these.
dates = pl.datetime_range(datetime(1900, 1, 1, 0), datetime(1900, 1, 31, 23), interval="1h", eager=True)

# Draw N values per axis with replacement. The two draws use different seeds
# so that longitude and latitude do not come from the same random stream.
lons_use = lons.sample(N, with_replacement=True, seed=SEED).alias("lon")
lats_use = lats.sample(N, with_replacement=True, seed=SEED + 1).alias("lat")
# dates_use = dates.sample(N, with_replacement=True).alias("datetime")
# uids = pl.Series("uid", [generate_uid(8) for _ in range(N)])

# One row per sampled point, with columns `lon` and `lat`.
df = pl.DataFrame([lons_use, lats_use])
print(df.shape)
print(df.head())

(16000, 2)
shape: (5, 2)
┌──────┬─────┐
│ lon  ┆ lat │
│ ---  ┆ --- │
│ i64  ┆ i64 │
╞══════╪═════╡
│ 127  ┆ 21  │
│ -148 ┆ 36  │
│ -46  ┆ -15 │
│ 104  ┆ 89  │
│ -57  ┆ -31 │
└──────┴─────┘


In [5]:
# Turn every row of `df` into a Record; each row's column names are passed
# to the Record constructor as keyword arguments.
records = []
for row in df.rows(named=True):
    records.append(Record(**row))

In [6]:
%%time
# Construct the KDTree from the full list of records; `%%time` reports how
# long this one-off construction takes.
kt = KDTree(records)

CPU times: user 151 ms, sys: 360 ms, total: 511 ms
Wall time: 57.3 ms


In [7]:
%%timeit
test_record = Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum())
kt.query(test_record)

203 μs ± 4.56 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
%%timeit
test_record = Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum())
np.argmin([test_record.distance(p) for p in records])

8.87 ms ± 188 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
%%time
n_samples = 1000
tol = 1e-8
test_records = [Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum()) for _ in range(n_samples)]
kd_res = [kt.query(r) for r in test_records]
kd_recs = [_[0][0] for _ in kd_res]
kd_dists = [_[1] for _ in kd_res]
tr_recs = [records[np.argmin([r.distance(p) for p in records])] for r in test_records]
tr_dists = [min([r.distance(p) for p in records]) for r in test_records]
assert all([abs(k - t) < tol for k, t in zip(kd_dists, tr_dists)]), "NOT MATCHING?"

CPU times: user 17.4 s, sys: 147 ms, total: 17.6 s
Wall time: 17.6 s


In [10]:
# Coordinates of every query point, and of the record selected for it by the
# KDTree (kd_*) and by the exhaustive scan (tr_*).
test_lons = [r.lon for r in test_records]
test_lats = [r.lat for r in test_records]

kd_lons = [r.lon for r in kd_recs]
kd_lats = [r.lat for r in kd_recs]

tr_lons = [r.lon for r in tr_recs]
tr_lats = [r.lat for r in tr_recs]

# Keep only the queries where the two methods disagree on the distance by at
# least `tol`; an empty frame means they agree everywhere.
# The result is bound to its own name rather than `df`, so the frame of
# sampled points that `records` is built from is not overwritten and that
# earlier cell can still be re-run.
mismatches = pl.DataFrame({
    "test_lon": test_lons,
    "test_lat": test_lats,
    "kd_dist": kd_dists,
    "kd_lon": kd_lons,
    "kd_lat": kd_lats,
    "tr_dist": tr_dists,
    "tr_lon": tr_lons,
    "tr_lat": tr_lats,
}).filter((pl.col("kd_dist") - pl.col("tr_dist")).abs().ge(tol))
mismatches

test_lon,test_lat,kd_dist,kd_lon,kd_lat,tr_dist,tr_lon,tr_lat
f64,f64,f64,i64,i64,f64,i64,i64
