Commit d58f82b9 authored by Joseph Siddons's avatar Joseph Siddons
Browse files

fix: revert median partition change, split on index. Account for this in insert and delete methods.

......@@ -49,18 +49,18 @@ class KDTree:
points.sort(key=lambda p: getattr(p, self.variable))
split_index = n_points // 2
self.partition_value = getattr(points[split_index - 1], self.variable)
while (
split_index < n_points
and getattr(points[split_index], self.variable)
== self.partition_value
):
split_index += 1
# while (
# split_index < n_points
# and getattr(points[split_index], self.variable)
# == self.partition_value
# ):
# split_index += 1
self.split = True
# Left is <= median
# Left is points left of midpoint
self.child_left = KDTree(points[:split_index], depth + 1)
# Right is > median
# Right is points right of midpoint
self.child_right = KDTree(points[split_index:], depth + 1)
return None
......@@ -79,8 +79,25 @@ class KDTree:
if getattr(point, self.variable) < self.partition_value:
return self.child_left.insert(point)
else:
elif getattr(point, self.variable) > self.partition_value:
return self.child_right.insert(point)
else:
r, _ = self.query(point)
if point in r:
return False
self.child_left._insert(point)
return True
def _insert(self, point: Record) -> None:
"""Insert a point even if it already exists in the KDTree"""
if not self.split:
self.points.append(point)
return
if getattr(point, self.variable) <= self.partition_value:
self.child_left._insert(point)
else:
self.child_right._insert(point)
return
def delete(self, point: Record) -> bool:
"""Delete a Record from the KDTree. May unbalance the KDTree"""
......@@ -91,10 +108,13 @@ class KDTree:
except ValueError:
return False
if getattr(point, self.variable) < self.partition_value:
return self.child_left.delete(point)
else:
return self.child_right.delete(point)
if getattr(point, self.variable) <= self.partition_value:
if self.child_left.delete(point):
return True
if getattr(point, self.variable) >= self.partition_value:
if self.child_right.delete(point):
return True
return False
def query(self, point) -> tuple[list[Record], float]:
"""Find the nearest Record within the KDTree to a query Record"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment