diff --git a/GeoSpatialTools/octtree.py b/GeoSpatialTools/octtree.py index a82a2be623f46f99b355e85cf66817aca99683ee..029ff9ec82ab446b84869854ea3e4234dd9d6bb2 100644 --- a/GeoSpatialTools/octtree.py +++ b/GeoSpatialTools/octtree.py @@ -591,6 +591,41 @@ class OctTree: return True return False + def remove(self, point: SpaceTimeRecord) -> bool: + """ + Remove a SpaceTimeRecord from the OctTree if it is in the OctTree. + + Returns True if the SpaceTimeRecord is removed. + """ + if not self.boundary.contains(point): + return False + + if point in self.points: + self.points.remove(point) + return True + + if not self.divided: + return False + + if self.northwestback.remove(point): + return True + elif self.northeastback.remove(point): + return True + elif self.southwestback.remove(point): + return True + elif self.southeastback.remove(point): + return True + elif self.northwestfwd.remove(point): + return True + elif self.northeastfwd.remove(point): + return True + elif self.southwestfwd.remove(point): + return True + elif self.southeastfwd.remove(point): + return True + + return False + def query( self, rect: SpaceTimeRectangle, diff --git a/GeoSpatialTools/quadtree.py b/GeoSpatialTools/quadtree.py index d24b4befef9cb22659f0ece485b67ee8f8b0e564..214bc38bf9976b0f06d02c59d0cc8880c400a12c 100644 --- a/GeoSpatialTools/quadtree.py +++ b/GeoSpatialTools/quadtree.py @@ -404,6 +404,33 @@ class QuadTree: return True return False + def remove(self, point: Record) -> bool: + """ + Remove a Record from the QuadTree if it is in the QuadTree. + + Returns True if the Record is removed. + """ + if not self.boundary.contains(point): + return False + + if point in self.points: + self.points.remove(point) + return True + + if not self.divided: + return False + + if self.northwest.remove(point): + return True + elif self.northeast.remove(point): + return True + elif self.southwest.remove(point): + return True + elif self.southeast.remove(point): + return True + + return False + def query( self, rect: Rectangle, diff --git a/pyproject.toml b/pyproject.toml index a9c82cda203f56d0d674b45d224faf30925af386..9cc4b05f2af436e55c30dcb93ac222cb02e46627 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ packages = ["GeoSpatialTools"] [project] name = "GeoSpatialTools" -version = "0.7.1" +version = "0.8.0" dependencies = [ "numpy", ] diff --git a/test/test_octtree.py b/test/test_octtree.py index ab3d388f0e2ae4306c7e64022d63d1577845e282..4a04eb25f97c8e14eb387ec2bdc1fe371df74b2a 100644 --- a/test/test_octtree.py +++ b/test/test_octtree.py @@ -223,6 +223,41 @@ class TestOctTree(unittest.TestCase): ] assert res == expected + def test_remove(self): + d = datetime(2023, 3, 24, 12, 0) + dt = timedelta(days=10) + start = d - dt + end = d + dt + + boundary = Rectangle(0, 20, 0, 8, start, end) + otree = OctTree(boundary, capacity=3) + points: list[Record] = [ + Record(10, 4, d, "main"), + Record(12, 1, d + timedelta(hours=3), "main2"), + Record(3, 7, d - timedelta(days=3), "main3"), + Record(13, 2, d + timedelta(hours=17), "southeastfwd"), + Record(3, 6, d - timedelta(days=1), "northwestback"), + Record(10, 4, d, "northwestback"), + Record(18, 2, d + timedelta(days=23), "not added"), + Record(11, 7, d + timedelta(hours=2), "northeastfwd"), + ] + to_remove = points[4] + for point in points: + otree.insert(point) + + # TEST: query works before remove + q_res = otree.nearby_points( + to_remove, dist=0.1, t_dist=timedelta(minutes=5) + ) + assert len(q_res) == 1 + + # TEST: point is removed and query fails + assert otree.remove(to_remove) + q_res = otree.nearby_points( + to_remove, dist=0.1, t_dist=timedelta(minutes=5) + ) + assert len(q_res) == 0 + def test_query(self): d = datetime(2023, 3, 24, 12, 0) dt = timedelta(days=10) diff --git a/test/test_quadtree.py b/test/test_quadtree.py index 223c892aa0d3cb44e94344f9684f7153012cb507..86a643e3d4b62b364d0e4577f59b3d2999f58aeb 100644 --- a/test/test_quadtree.py +++ b/test/test_quadtree.py @@ -103,6 +103,29 @@ class TestQuadTree(unittest.TestCase): ] assert res == expected + def test_remove(self): + boundary = Rectangle(0, 20, 0, 8) + qtree = QuadTree(boundary, capacity=3) + points: list[Record] = [ + Record(10, 5), + Record(19, 1), + Record(0, 0), + Record(-2, -9.2), + Record(12.8, 2.1), + ] + to_remove = points[2] + for point in points: + qtree.insert(point) + q_res = qtree.nearby_points(to_remove, dist=0.1) + + # TEST: get the point + assert len(q_res) == 1 + + # TEST: Point is removed + assert qtree.remove(to_remove) + q_res = qtree.nearby_points(to_remove, dist=0.1) + assert len(q_res) == 0 + def test_query(self): boundary = Rectangle(0, 20, 0, 8) qtree = QuadTree(boundary, capacity=3)