diff --git a/rangetree.py b/rangetree.py index 48e14d1..92ebba8 100644 --- a/rangetree.py +++ b/rangetree.py @@ -220,3 +220,43 @@ def traverse_all(self) -> Iterator[TreeNode[K, V]]: yield from self.left.traverse_leaves() yield self yield from self.right.traverse_leaves() + + +@dataclass(frozen=True) +class LayeredTreeNode(TreeNode, Generic[K, V]): + """A 2D layered range tree.""" + keys: List[K] = None + dim1: int = None # dimension main tree is built on + dim2: int = None # dimension auxiliary arrays are built on + + def __post_init__(self, _key: Optional[K]): + assert self.dim1 is not None + assert self.dim2 is not None + super(LayeredTreeNode, self).__post_init__(_key) + + if _key is not None: + object.__setattr__(self, 'keys', [_key]) + else: + assert not self.is_leaf + # TODO: insert pointers here + + @classmethod + def create_internal( + cls, + left: LayeredTreeNode, + right: LayeredTreeNode, + value=None + ) -> LayeredTreeNode[K]: + return cls( + is_leaf=False, + left=left, + right=right, + value=value, + size=left.size + right.size, + dim1=left.dim1, + dim2=left.dim2 + ) + + # TODO: ideally, most of the fractional cascading magic would happen + # in __post_init__, etc.---we (hopefully) shouldn't have to change the + # create_* functions much.