Commit e606e155 authored by Joseph Siddons

Merge branch '7-add-ci-cd' into 'main'

Resolve "Add CI/CD"

Closes #7

See merge request nocsurfaceprocesses/geospatialtools!25
Pipeline #267676 canceled
[codespell]
# Ignore long base64 - e.g. images in notebooks
ignore-regex = [A-Za-z0-9+/]{100,}
skip = "./docs/_build"
count = true
quiet-level = 3
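The `ignore-regex` setting above is meant to hide long base64 runs, such as images embedded in notebooks, from the spell checker. A minimal Python sketch of what that pattern does and does not match follows. It uses only the standard library; the sample strings and the `strip_ignored` helper are made up for illustration, and codespell applies the pattern itself, so this only demonstrates the regex.

```python
import base64
import re

# Same pattern as the `ignore-regex` setting in the codespell config.
IGNORE = re.compile(r"[A-Za-z0-9+/]{100,}")

# Illustrative stand-in for an embedded image: 120 bytes encode to 160 chars.
blob = base64.b64encode(bytes(range(120))).decode("ascii")


def strip_ignored(text: str) -> str:
    """Remove every span the pattern would hide from the spell checker."""
    return IGNORE.sub("", text)


# The long base64 run is removed; the short key around it is untouched.
print(strip_ignored(f'"image/png": "{blob}"'))

# Ordinary prose is far shorter than 100 characters per run, so it is
# left alone and a typo such as "teh" would still be checked.
print(strip_ignored("teh quick brown fox"))
```

Because the pattern requires at least 100 consecutive base64 characters, ordinary identifiers and words are not affected.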
......@@ -25,7 +25,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
uv.lock
.DS_Store
# PyInstaller
# Usually these files are written by a python script from a template
......@@ -73,6 +73,7 @@ instance/
docs/_build/
docs/_static/
docs/_templates/
docs/*.tex
# PyBuilder
.pybuilder/
......
variables:
  UV_VERSION: "0.5"
  PYTHON_VERSION: "3.12"
  BASE_LAYER: "bookworm-slim"
  UV_CACHE_DIR: ".uv-cache"
  UV_SYSTEM_PYTHON: "1"
stages:
  - build
  - test
  - lint
pre-commit:
  stage: build
  image: python:3.11
  script:
    - pip install pre-commit
    - pre-commit run --all-files
  rules:
    - if: $CI_PIPELINE_SOURCE == "merge_request_event"
    - if: $CI_COMMIT_BRANCH == "main"
pytest:
  stage: test
  image: ghcr.io/astral-sh/uv:$UV_VERSION-python$PYTHON_VERSION-$BASE_LAYER
  cache:
    - key:
        files:
          - uv.lock
      paths:
        - $UV_CACHE_DIR
  script:
    - uv --version
    - uv sync --all-extras
    - uv pip list
    - uv run pytest test
  rules:
    - if: $CI_PIPELINE_SOURCE == "merge_request_event"
    - if: $CI_COMMIT_BRANCH == "main"
ruff:
  stage: lint
  interruptible: true
  image:
    name: ghcr.io/astral-sh/ruff:0.9.9-alpine
  before_script:
    - cd $CI_PROJECT_DIR
    - ruff --version
  script:
    - ruff check --output-format=gitlab > code-quality-report.json
  artifacts:
    reports:
      codequality: $CI_PROJECT_DIR/code-quality-report.json
  rules:
    - if: $CI_PIPELINE_SOURCE == "merge_request_event"
    - if: $CI_COMMIT_BRANCH == "main"
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.8.5
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
  - repo: https://github.com/codespell-project/codespell
    rev: v2.3.0
    hooks:
      - id: codespell
......@@ -6,6 +6,7 @@ Contributors to this version: Joseph Siddons (@josidd)
### Internal changes
* Added CI/CD scripts for GitLab (!25).
* Added changelog (!26).
## 0.11.2 (2025-02-27)
......
......@@ -15,3 +15,102 @@ Please write tests for your code, add these to the `test` directory. I use `pyte
## Issues
Please file issues as they arise. Describe the problem, the steps to reproduce, and provide any output.
# Contributing to `GeoSpatialTools`
To contribute to this package you will need to be a member of NOC GitLab & have access permissions
to this repository - if you're able to read this then you have access!
If you wish to contribute please make sure you are working on your own branches (not main), ideally
you should work on your own fork. If you wish to work on a particular module you could name your
branch `module-user` where `module` would be replaced by the name of the module you are working on,
and `user` would be your user name. However you can name your branch as you see fit, but it is a
good idea to name it something that relates to what you are working on. If you are working on an
issue please reference the issue number in the branch name and associated merge request. It is
generally easier to make a merge request and create a branch from the issue.
If you wish to merge to `main` please create a merge request and assign it to `josidd`,
and/or `ricorne` - either to perform the merge and/or review/approve the request. Please provide a
summary of the main changes that you have made so that there is context for us to review the
changes.
## Changelog
The changelog is `CHANGES.md`. Please add your changes to the changelog in your merge request.
## Commit Messages
We are trying to use a consistent and informative approach for writing commit messages in this
repository. We have adopted the [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/)
standard for commit messages. Whilst we won't enforce this standard upon others, we do recommend the
approach. Otherwise please ensure that your messages are descriptive and not just `changes` or
similar.
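As an illustration of the header shape that the conventional commits standard describes (`<type>[optional scope][!]: <description>`), here is a minimal Python sketch that parses the first line of a message. The `parse_header` helper and the example messages are hypothetical and are not part of this repository or its CI; the regular expression is a simplification of the specification, not a full validator.

```python
import re
from typing import Optional

# Simplified header shape from Conventional Commits 1.0.0:
#   <type>[optional scope][!]: <description>
HEADER = re.compile(
    r"^(?P<type>[A-Za-z]+)"
    r"(?:\((?P<scope>[^()\s]+)\))?"
    r"(?P<breaking>!)?"
    r": (?P<description>\S.*)$"
)


def parse_header(message: str) -> Optional[dict]:
    """Return the header parts of a commit message, or None if no match."""
    lines = message.splitlines()
    if not lines:
        return None
    match = HEADER.match(lines[0])
    return match.groupdict() if match else None


# A descriptive message with a type and a scope is accepted.
print(parse_header("feat(kdtree): add nearest-neighbour query"))

# A bare, uninformative message such as `changes` is rejected.
print(parse_header("changes"))
```

Only the first line is inspected, so a longer body or footers after a blank line do not affect the result.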
## Development Instructions
We recommend [uv](https://docs.astral.sh/uv/) for development purposes.
Clone the repository and create your development branch
```bash
git clone git@git.noc.ac.uk:noc_surface_processes/geospatialtools.git /path/to/geospatialtools
cd /path/to/geospatialtools
git checkout -b new-branch-name # if not a new branch exclude the '-b'
```
Create a virtual environment and install the dependencies
```bash
uv venv --python 3.12 # 3.12 recommended; versions >= 3.9 are supported
source .venv/bin/activate # assuming bash or zsh
```
To install the dependencies run:
```bash
uv sync
```
Or, to install all optional extras and development dependencies, run:
```bash
uv sync --extra all --dev
```
## Standards
We recommend the use of [ruff](https://docs.astral.sh/ruff/) as a linter/formatter. The
`pyproject.toml` file includes all the settings for `ruff` for `GeoSpatialTools`.
```bash
uvx ruff check
uvx ruff check --fix
uvx ruff format
```
[codespell](https://github.com/codespell-project/codespell) is also used to check spelling/bad
names.
We use [pre-commit](https://pre-commit.com/) as part of our CI/CD processing.
## Tests
If you create new functionality please write and perform unit-tests on your code. The current
implementation of `GeoSpatialTools` uses the `pytest` library.
New tests do not need to be comprehensive, but I likely won't merge if your changes fail testing,
especially the pre-existing tests. You will need to include (and reference) any data that is
needed for testing.
We have a CI/CD pipeline that will automatically run the tests as part of merge requests.
We welcome additions/improvements to the current tests. New python test files should be placed in
the `test` directory and filenames must be prefixed with `test_`.
To perform tests you will need to have the environment set-up and active. Then run:
```
uv run pytest test/test_*.py
```
from the main/top directory for the repository.
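As a rough sketch of what a new test file might look like, the example below follows the `test_` naming convention described above. The file name `test/test_example.py` and the inline `haversine` helper are hypothetical stand-ins so that the example is self-contained; real tests would import the function under test from `GeoSpatialTools` instead.

```python
# test/test_example.py (hypothetical file name)
from math import asin, cos, isclose, pi, radians, sin, sqrt

R_EARTH = 6371.0  # km, an assumed Earth radius for this sketch


def haversine(lon0, lat0, lon1, lat1, radius=R_EARTH):
    """Stand-in great-circle distance; not the package's own function."""
    dlon = radians(lon1 - lon0)
    dlat = radians(lat1 - lat0)
    a = (
        sin(dlat / 2) ** 2
        + cos(radians(lat0)) * cos(radians(lat1)) * sin(dlon / 2) ** 2
    )
    return 2 * radius * asin(sqrt(a))


def test_zero_distance():
    # A point is at distance zero from itself.
    assert haversine(47.6, -31.1, 47.6, -31.1) == 0.0


def test_symmetry():
    # Distance should not depend on the direction of travel.
    there = haversine(47.6, -31.1, -12.0, 50.5)
    back = haversine(-12.0, 50.5, 47.6, -31.1)
    assert isclose(there, back, rel_tol=1e-12)


def test_equator_to_pole():
    # A quarter of a great circle has length R * pi / 2.
    assert isclose(haversine(0, 0, 0, 90), R_EARTH * pi / 2, rel_tol=1e-9)
```

`pytest` collects any function whose name starts with `test_` in a matching file, so no extra registration is needed.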
......@@ -6,6 +6,7 @@ comparisons between GreatCircle objects.
"""
import numpy as np
from typing import Optional, Tuple
from .distance_metrics import bearing, gcd_slc
......@@ -15,7 +16,7 @@ def cartesian_to_lonlat(
y: float,
z: float,
to_radians: bool = False,
) -> tuple[float, float]:
) -> Tuple[float, float]:
"""
Get lon, and lat from cartesian coordinates.
......@@ -54,7 +55,7 @@ def polar_to_cartesian(
R: float = 6371,
to_radians: bool = True,
normalised: bool = True,
) -> tuple[float, float, float]:
) -> Tuple[float, float, float]:
"""
Convert from polar coordinates to cartesian.
......@@ -196,7 +197,7 @@ class GreatCircle:
def intersection(
self, other: object, epsilon: float = 0.01
) -> tuple[float, float] | None:
) -> Optional[Tuple[float, float]]:
"""
Determine intersection position with another GreatCircle.
......@@ -241,7 +242,7 @@ class GreatCircle:
self,
other: object,
epsilon: float = 0.01,
) -> float | None:
) -> Optional[float]:
"""
Get angle of intersection with another GreatCircle.
......
......@@ -20,7 +20,7 @@ class KDTree:
This implementation is a _balanced_ KDTree, each leaf node should have the
same number of points (or differ by 1 depending on the number of points
the KDTree is intialised with).
the KDTree is initialised with).
The KDTree partitions in each of the lon and lat dimensions alternatively
in sequence by splitting at the median of the dimension of the points
......@@ -34,7 +34,7 @@ class KDTree:
The current depth of the KDTree, you should set this to 0, it is used
internally.
max_depth : int
The maximium depth of the KDTree. The leaf nodes will have depth no
The maximum depth of the KDTree. The leaf nodes will have depth no
larger than this value. Leaf nodes will not be created if there is
only 1 point in the branch.
"""
......
"""
OctTree
-------
Constuctors for OctTree classes that can decrease the number of comparisons
Constructors for OctTree classes that can decrease the number of comparisons
for detecting nearby records for example. This is an implementation that uses
Haversine distances for comparisons between records for identification of
neighbours.
......@@ -18,14 +18,14 @@ from warnings import warn
class SpaceTimeRecord:
"""
ICOADS Record class.
SpaceTimeRecord class.
This is a simple instance of an ICOARDS record, it requires position and
temporal data. It can optionally include a UID and extra data.
This is a simple instance of a record, it requires position and temporal
data. It can optionally include a UID and extra data.
The temporal component was designed to use `datetime` values, however all
methods will work with numeric datetime information - for example a pentad,
timestamp, julian day, etc. Note that any uses within an OctTree and
time-stamp, Julian day, etc. Note that any uses within an OctTree and
SpaceTimeRectangle must also have timedelta values replaced with numeric
ranges in this case.
......
"""
QuadTree
--------
Constuctors for QuadTree classes that can decrease the number of comparisons
Constructors for QuadTree classes that can decrease the number of comparisons
for detecting nearby records for example. This is an implementation that uses
Haversine distances for comparisons between records for identification of
neighbours.
......@@ -17,9 +17,9 @@ from math import degrees, sqrt
class Record:
"""
ICOADS Record class
Record class
This is a simple instance of an ICOARDS record, it requires position data.
This is a simple instance of a record, it requires position data.
It can optionally include datetime, a UID, and extra data passed as
keyword arguments.
......
......@@ -8,6 +8,6 @@ class LatitudeError(ValueError):
class DateWarning(Warning):
"""Warnning for Datetime Value"""
"""Warning for Datetime Value"""
pass
......@@ -393,5 +393,3 @@ the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.
......@@ -81,11 +81,11 @@ N_samples = 1000
records: list[Record] = [Record(choice(lon_range), choice(lat_range)) for _ in range(N_samples)]
# Construct Tree
kt = KDTree(records)
kdtree = KDTree(records)
test_value: Record = Record(lon=47.6, lat=-31.1)
neighbours: list[Record] = []
neighbours, dist = kt.query(test_value)
neighbours, dist = kdtree.query(test_value)
```
### Points within distance (2d \& 3d)
......@@ -126,16 +126,16 @@ N_samples = 1000
# Construct Tree
boundary = Rectangle(-180, 180, -90, 90) # Full domain
qt = QuadTree(boundary)
quadtree = QuadTree(boundary)
records: list[Record] = [Record(choice(lon_range), choice(lat_range)) for _ in range(N_samples)]
for record in records:
    qt.insert(record)
    quadtree.insert(record)
test_value: Record = Record(lon=47.6, lat=-31.1)
dist: float = 340 # km
neighbours: list[Record] = qt.nearby_points(test_value, dist)
neighbours: list[Record] = quadtree.nearby_points(test_value, dist)
```
#### OctTree - 3d QuadTree
......@@ -190,16 +190,16 @@ N_samples = 1000
# Construct Tree
boundary = SpaceTimeRectangle(-180, 180, -90, 90, datetime(2009, 1, 1, 0), datetime(2009, 1, 2, 23)) # Full domain
ot = OctTree(boundary)
octtree = OctTree(boundary)
records: list[SpaceTimeRecord] = [
    SpaceTimeRecord(choice(lon_range), choice(lat_range), choice(dates)) for _ in range(N_samples)]
for record in records:
    ot.insert(record)
    octtree.insert(record)
test_value: SpaceTimeRecord = SpaceTimeRecord(lon=47.6, lat=-31.1, datetime=datetime(2009, 1, 23, 17, 41))
dist: float = 340 # km
t_dist = timedelta(hours=4)
neighbours: list[Record] = ot.nearby_points(test_value, dist, t_dist)
neighbours: list[Record] = octtree.nearby_points(test_value, dist, t_dist)
```
No preview for this file type
This diff is collapsed.
......@@ -9,8 +9,8 @@ Welcome to GeoSpatialTool's documentation!
.. toctree::
:maxdepth: 4
:caption: Contents:
introduction
introduction
getting_started
authors
users_guide
......
{
"cells": [
{
"cell_type": "markdown",
"id": "f7143f08-1d06-4e94-bbf6-ef35ddd11556",
"metadata": {},
"source": [
"# KDTree\n",
"\n",
"Testing the time to look-up nearby records with the `KDTree` implementation. Note that this implementation is actually a `2DTree` since it can only compute a valid distance comparison between longitude and latitude positions.\n",
"\n",
"The `KDTree` object is used for finding the closest neighbour to a position, in this implementation we use the Haversine distance to compare positions."
]
},
{
"cell_type": "code",
"execution_count": 1,
......@@ -8,11 +20,10 @@
"outputs": [],
"source": [
"import os\n",
"import gzip\n",
"\n",
"os.environ[\"POLARS_MAX_THREADS\"] = \"4\"\n",
"\n",
"from datetime import datetime, timedelta\n",
"from random import choice\n",
"from datetime import datetime\n",
"from string import ascii_letters, digits\n",
"import random\n",
"import inspect\n",
......@@ -20,7 +31,17 @@
"import polars as pl\n",
"import numpy as np\n",
"\n",
"from GeoSpatialTools import Record, haversine, KDTree"
"from GeoSpatialTools import Record, KDTree"
]
},
{
"cell_type": "markdown",
"id": "ec6c6e7f-8eee-47ea-a5e9-12537bb3412d",
"metadata": {},
"source": [
"## Set-up functions\n",
"\n",
"Used for generating data, or for comparisons by doing brute-force approach."
]
},
{
......@@ -31,6 +52,7 @@
"outputs": [],
"source": [
"def randnum() -> float:\n",
"    \"\"\"Get a random number between -1 and 1\"\"\"\n",
"    return 2 * (np.random.rand() - 0.5)"
]
},
......@@ -42,6 +64,7 @@
"outputs": [],
"source": [
"def generate_uid(n: int) -> str:\n",
"    \"\"\"Generates a pseudo uid by randomly selecting from characters\"\"\"\n",
"    chars = ascii_letters + digits\n",
"    return \"\".join(random.choice(chars) for _ in range(n))"
]
......@@ -49,6 +72,179 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "9e647ecd-abdc-46a0-8261-aa081fda2e1d",
"metadata": {
"jupyter": {
"source_hidden": true
},
"scrolled": true
},
"outputs": [],
"source": [
"def check_cols(\n",
"    df: pl.DataFrame | pl.LazyFrame,\n",
"    cols: list[str],\n",
"    var_name: str = \"dataframe\",\n",
") -> None:\n",
"    \"\"\"\n",
"    Check that a dataframe contains a list of columns. Raises an error if not.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df : polars Frame\n",
"        Dataframe to check\n",
"    cols : list[str]\n",
"        Required columns\n",
"    var_name : str\n",
"        Name of the Frame - used for displaying in any error.\n",
"    \"\"\"\n",
"    calling_func = inspect.stack()[1][3]\n",
"    if isinstance(df, pl.DataFrame):\n",
"        have_cols = df.columns\n",
"    elif isinstance(df, pl.LazyFrame):\n",
"        have_cols = df.collect_schema().names()\n",
"    else:\n",
"        raise TypeError(\"Input Frame is not a polars Frame\")\n",
"\n",
"    cols_in_frame = intersect(cols, have_cols)\n",
"    missing = [c for c in cols if c not in cols_in_frame]\n",
"\n",
"    if len(missing) > 0:\n",
"        err_str = f\"({calling_func}) - {var_name} missing required columns. \"\n",
"        err_str += f\"Require: {', '.join(cols)}. \"\n",
"        err_str += f\"Missing: {', '.join(missing)}.\"\n",
"        raise ValueError(err_str)\n",
"\n",
"    return\n",
"\n",
"\n",
"def haversine_df(\n",
"    df: pl.DataFrame | pl.LazyFrame,\n",
"    lon: float,\n",
"    lat: float,\n",
"    R: float = 6371,\n",
"    lon_col: str = \"lon\",\n",
"    lat_col: str = \"lat\",\n",
") -> pl.DataFrame | pl.LazyFrame:\n",
"    \"\"\"\n",
"    Compute haversine distance on earth surface between lon-lat positions\n",
"    in a polars DataFrame and a lon-lat position.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df : polars.DataFrame\n",
"        The data, containing required columns:\n",
"            * lon_col\n",
"            * lat_col\n",
"    lon : float\n",
"        The longitude of the position.\n",
"    lat : float\n",
"        The latitude of the position.\n",
"    R : float\n",
"        Radius of earth in km\n",
"    lon_col : str\n",
"        Name of the longitude column\n",
"    lat_col : str\n",
"        Name of the latitude column\n",
"\n",
"    Returns\n",
"    -------\n",
"    polars.DataFrame\n",
"        With an additional column '_dist' giving the distance between each\n",
"        row and the input position, in the same units as 'R'.\n",
"    \"\"\"\n",
"    required_cols = [lon_col, lat_col]\n",
"\n",
"    check_cols(df, required_cols, \"df\")\n",
"    return (\n",
"        df.with_columns(\n",
"            [\n",
"                pl.col(lat_col).radians().alias(\"_lat0\"),\n",
"                pl.lit(lat).radians().alias(\"_lat1\"),\n",
"                (pl.col(lon_col) - lon).radians().alias(\"_dlon\"),\n",
"                (pl.col(lat_col) - lat).radians().alias(\"_dlat\"),\n",
"            ]\n",
"        )\n",
"        .with_columns(\n",
"            (\n",
"                (pl.col(\"_dlat\") / 2).sin().pow(2)\n",
"                + pl.col(\"_lat0\").cos()\n",
"                * pl.col(\"_lat1\").cos()\n",
"                * (pl.col(\"_dlon\") / 2).sin().pow(2)\n",
"            ).alias(\"_a\")\n",
"        )\n",
"        .with_columns(\n",
"            (2 * R * (pl.col(\"_a\").sqrt().arcsin())).round(2).alias(\"_dist\")\n",
"        )\n",
"        .drop([\"_lat0\", \"_lat1\", \"_dlon\", \"_dlat\", \"_a\"])\n",
"    )\n",
"\n",
"\n",
"def intersect(a, b) -> set:\n",
"    \"\"\"Intersection of a and b, items in both a and b\"\"\"\n",
"    return set(a) & set(b)\n",
"\n",
"\n",
"def nearest_ship(\n",
"    lon: float,\n",
"    lat: float,\n",
"    df: pl.DataFrame,\n",
"    lon_col: str = \"lon\",\n",
"    lat_col: str = \"lat\",\n",
") -> pl.DataFrame:\n",
"    \"\"\"\n",
"    Find the observation nearest to a position in space.\n",
"\n",
"    Get a frame with only the records that are closest to the input point.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    lon : float\n",
"        The longitude of the position.\n",
"    lat : float\n",
"        The latitude of the position.\n",
"    df : polars.DataFrame\n",
"        The pool of records to search.\n",
"    lon_col : str\n",
"        Name of the longitude column in the pool DataFrame\n",
"    lat_col : str\n",
"        Name of the latitude column in the pool DataFrame\n",
"\n",
"    Returns\n",
"    -------\n",
"    polars.DataFrame\n",
"        Containing only the records from the pool that are nearest to the\n",
"        input point.\n",
"    \"\"\"\n",
"    required_cols = [lon_col, lat_col]\n",
"    check_cols(df, required_cols, \"df\")\n",
"\n",
"    return (\n",
"        df.pipe(\n",
"            haversine_df,\n",
"            lon=lon,\n",
"            lat=lat,\n",
"            lon_col=lon_col,\n",
"            lat_col=lat_col,\n",
"        )\n",
"        .filter(pl.col(\"_dist\").eq(pl.col(\"_dist\").min()))\n",
"        .drop([\"_dist\"])\n",
"    )"
]
},
{
"cell_type": "markdown",
"id": "287bdc1d-1ecf-4c59-af95-d2dc639c6894",
"metadata": {},
"source": [
"## Initialise random data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c60b30de-f864-477a-a09a-5f1caa4d9b9a",
"metadata": {},
"outputs": [
......@@ -58,17 +254,17 @@
"text": [
"(16000, 2)\n",
"shape: (5, 2)\n",
"┌─────┬─────┐\n",
"│ lon ┆ lat │\n",
"│ --- ┆ --- │\n",
"│ i64 ┆ i64 │\n",
"╞═════╪═════╡\n",
"│ 127 ┆ 21 │\n",
"│ -148 ┆ 36 │\n",
"│ -46 ┆ -15 │\n",
"│ 104 ┆ 89 │\n",
"│ -57 ┆ -31 │\n",
"└─────┴─────┘\n"
"┌─────┬─────┐\n",
"│ lon ┆ lat │\n",
"│ --- ┆ --- │\n",
"│ i64 ┆ i64 │\n",
"╞═════╪═════╡\n",
"│ -26 ┆ -42 │\n",
"│ 109 ┆ -33 │\n",
"│ -87 ┆ -18 │\n",
"│ -94 ┆ -81 │\n",
"│ -94 ┆ 0 │\n",
"└─────┴─────┘\n"
]
}
],
......@@ -76,7 +272,12 @@
"N = 16_000\n",
"lons = pl.int_range(-180, 180, eager=True)\n",
"lats = pl.int_range(-90, 90, eager=True)\n",
"dates = pl.datetime_range(datetime(1900, 1, 1, 0), datetime(1900, 1, 31, 23), interval=\"1h\", eager=True)\n",
"dates = pl.datetime_range(\n",
"    datetime(1900, 1, 1, 0),\n",
"    datetime(1900, 1, 31, 23),\n",
"    interval=\"1h\",\n",
"    eager=True,\n",
")\n",
"\n",
"lons_use = lons.sample(N, with_replacement=True).alias(\"lon\")\n",
"lats_use = lats.sample(N, with_replacement=True).alias(\"lat\")\n",
......@@ -90,7 +291,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "875f2a67-49fe-476f-add1-b1d76c6cd8f9",
"metadata": {},
"outputs": [],
......@@ -98,9 +299,19 @@
"records = [Record(**r) for r in df.rows(named=True)]"
]
},
{
"cell_type": "markdown",
"id": "bd83330b-ef2c-478e-9a7b-820454d198bb",
"metadata": {},
"source": [
"## Initialise the `KDTree`\n",
"\n",
"There is an overhead to constructing a `KDTree` object, so performance improvement is only for multiple comparisons."
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "1e883e5a-5086-4c29-aff2-d308874eae16",
"metadata": {},
"outputs": [
......@@ -108,8 +319,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 151 ms, sys: 360 ms, total: 511 ms\n",
"Wall time: 57.3 ms\n"
"CPU times: user 35 ms, sys: 1.5 ms, total: 36.5 ms\n",
"Wall time: 32.1 ms\n"
]
}
],
......@@ -118,9 +329,30 @@
"kt = KDTree(records)"
]
},
{
"cell_type": "markdown",
"id": "0a37ef06-2691-4e01-96a9-1c1ecd582599",
"metadata": {},
"source": [
"## Compare with brute force approach"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "365bbf30-7a93-438d-92b2-a3471f1e9249",
"metadata": {},
"outputs": [],
"source": [
"test_record = Record(\n",
"    random.choice(range(-179, 180)) + randnum(),\n",
"    random.choice(range(-89, 90)) + randnum(),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "69022ad1-5ec8-4a09-836c-273ef452451f",
"metadata": {},
"outputs": [
......@@ -128,19 +360,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"203 μs ± 4.56 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
"101 μs ± 3.52 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"test_record = Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum())\n",
"kt.query(test_record)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 10,
"id": "28031966-c7d0-4201-a467-37590118e851",
"metadata": {},
"outputs": [
......@@ -148,19 +379,45 @@
"name": "stdout",
"output_type": "stream",
"text": [
"8.87 ms ± 188 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"8.17 ms ± 38.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"test_record = Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum())\n",
"np.argmin([test_record.distance(p) for p in records])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 11,
"id": "09e0f923-ca49-47bf-8643-e0b3a6d0467c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8.22 ms ± 95.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"nearest_ship(lon=test_record.lon, lat=test_record.lat, df=df)"
]
},
{
"cell_type": "markdown",
"id": "f0359950-942d-45ea-8676-b22c8ce9e296",
"metadata": {},
"source": [
"## Verify that results are correct"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0d10b2ba-57b2-475c-9d01-135363423990",
"metadata": {},
"outputs": [
......@@ -168,8 +425,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 17.4 s, sys: 147 ms, total: 17.6 s\n",
"Wall time: 17.6 s\n"
"CPU times: user 16.3 s, sys: 74.7 ms, total: 16.4 s\n",
"Wall time: 16.4 s\n"
]
}
],
......@@ -177,18 +434,28 @@
"%%time\n",
"n_samples = 1000\n",
"tol = 1e-8\n",
"test_records = [Record(random.choice(range(-179, 180)) + randnum(), random.choice(range(-89, 90)) + randnum()) for _ in range(n_samples)]\n",
"test_records = [\n",
"    Record(\n",
"        random.choice(range(-179, 180)) + randnum(),\n",
"        random.choice(range(-89, 90)) + randnum(),\n",
"    )\n",
"    for _ in range(n_samples)\n",
"]\n",
"kd_res = [kt.query(r) for r in test_records]\n",
"kd_recs = [_[0][0] for _ in kd_res]\n",
"kd_dists = [_[1] for _ in kd_res]\n",
"tr_recs = [records[np.argmin([r.distance(p) for p in records])] for r in test_records]\n",
"tr_recs = [\n",
"    records[np.argmin([r.distance(p) for p in records])] for r in test_records\n",
"]\n",
"tr_dists = [min([r.distance(p) for p in records]) for r in test_records]\n",
"assert all([abs(k - t) < tol for k, t in zip(kd_dists, tr_dists)]), \"NOT MATCHING?\""
"\n",
"if not all([abs(k - t) < tol for k, t in zip(kd_dists, tr_dists)]):\n",
"    raise ValueError(\"NOT MATCHING?\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 13,
"id": "a6aa6926-7fd5-4fff-bd20-7bc0305b948d",
"metadata": {},
"outputs": [
......@@ -214,7 +481,7 @@
"└──────────┴──────────┴─────────┴────────┴────────┴─────────┴────────┴────────┘"
]
},
"execution_count": 10,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
......@@ -229,23 +496,25 @@
"tr_lons = [r.lon for r in tr_recs]\n",
"tr_lats = [r.lat for r in tr_recs]\n",
"\n",
"df = pl.DataFrame({\n",
"    \"test_lon\": test_lons, \n",
"    \"test_lat\": test_lats,\n",
"    \"kd_dist\": kd_dists,\n",
"    \"kd_lon\": kd_lons,\n",
"    \"kd_lat\": kd_lats,\n",
"    \"tr_dist\": tr_dists,\n",
"    \"tr_lon\": tr_lons,\n",
"    \"tr_lat\": tr_lats, \n",
"}).filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n",
"df = pl.DataFrame(\n",
"    {\n",
"        \"test_lon\": test_lons,\n",
"        \"test_lat\": test_lats,\n",
"        \"kd_dist\": kd_dists,\n",
"        \"kd_lon\": kd_lons,\n",
"        \"kd_lat\": kd_lats,\n",
"        \"tr_dist\": tr_dists,\n",
"        \"tr_lon\": tr_lons,\n",
"        \"tr_lat\": tr_lats,\n",
"    }\n",
").filter((pl.col(\"kd_dist\") - pl.col(\"tr_dist\")).abs().ge(tol))\n",
"df"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "GeoSpatialTools",
"display_name": "geospatialtools",
"language": "python",
"name": "geospatialtools"
},
......@@ -259,7 +528,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.11.11"
}
},
"nbformat": 4,
......
This diff is collapsed.
......@@ -33,22 +33,19 @@ classifiers = [
[project.optional-dependencies]
notebooks = ["ipykernel", "polars"]
test = ["pytest>=8.3.4"]
docs = [
"sphinx>=8.2.1",
"sphinx-autodoc-typehints>=3.1.0",
"sphinx-rtd-theme>=3.0.2",
]
test = ["pytest"]
docs = ["sphinx", "sphinx-autodoc-typehints", "sphinx-rtd-theme"]
all = [
"ipykernel",
"polars",
"pytest",
"sphinx>=8.2.1",
"sphinx-autodoc-typehints>=3.1.0",
"sphinx-rtd-theme>=3.0.2",
"sphinx",
"sphinx-autodoc-typehints",
"sphinx-rtd-theme",
]
[tool.ruff]
src = ["geospatialtools"]
line-length = 80
indent-width = 4
target-version = "py311"
......@@ -75,6 +72,11 @@ select = [
"W", # pycodestyle warnings
]
[tool.ruff.lint.per-file-ignores]
"docs/*.py" = ["D100", "D101", "D102", "D103"]
"test/**/*test*.py" = ["D100", "D101", "D102", "D103", "N802", "S101", "S311"]
"notebooks/*.ipynb" = ["D100", "D101", "D102", "D103", "N802", "S101", "S311"]
[tool.ruff.format]
quote-style = "double" # Like Black, use double quotes for strings.
indent-style = "space" # Like Black, indent with spaces, rather than tabs.
......
......@@ -2,7 +2,7 @@ import unittest
import random
from numpy import min, argmin
from GeoSpatialTools import haversine, KDTree, Record
from GeoSpatialTools import KDTree, Record
class TestKDTree(unittest.TestCase):
......@@ -87,7 +87,7 @@ class TestKDTree(unittest.TestCase):
kt.insert(Record(45, -21, uid="1"))
kt.insert(Record(45, -21, uid="2"))
r, d = kt.query(Record(44, -21, uid="3"))
r, _ = kt.query(Record(44, -21, uid="3"))
assert len(r) == 2
def test_wrap(self):
......