# Source code for qupulse.utils.tree

"""This module contains a tree implementation."""

from typing import Iterable, Union, List, Generator, Tuple, TypeVar, Optional, Sequence
from collections import deque, namedtuple
import weakref

from qupulse.utils.types import SequenceProxy


# Public API for ``from ... import *``: only Node is exported; is_tree_circular has to be imported by name.
__all__ = ['Node']


def make_empty_weak_reference() -> weakref.ref:
    """Create a weak reference whose referent is already gone.

    Calling the returned reference yields ``None``; ``Node`` uses such a reference as the
    parent of a root node so that ``parent`` can always be called uniformly.
    """
    def _throwaway():
        return None

    dead_reference = weakref.ref(_throwaway)
    # Dropping the only strong reference lets the function be collected, which leaves the
    # weak reference dead (immediately so under CPython's reference counting).
    del _throwaway
    return dead_reference


# Type variable bound to Node; annotates methods that accept/return nodes of the same (sub)class as ``self``.
_NodeType = TypeVar('_NodeType', bound='Node')


class Node:
    """Mutable tree node with a weakly referenced parent.

    A node owns a list of child nodes but keeps only a weak reference to its parent, so a
    child does not keep its parent alive. Every child also caches its position inside the
    parent's children list (:attr:`parent_index`); the mutating methods of this class keep
    that cached position in sync.
    """

    # When True, assert_tree_integrity actually runs its recursive consistency checks.
    debug = False

    __slots__ = ('__parent', '__children', '__parent_index', '__weakref__')

    def __init__(self: _NodeType, parent: Union[_NodeType, None] = None, children: Optional[Iterable] = None):
        """
        :param parent: parent node, stored as a weak reference; ``None`` creates a root.
            This does not add the new node to ``parent``'s children.
        :param children: optional iterable of child specifications, each converted by
            :meth:`parse_child`
        """
        self.__parent = make_empty_weak_reference() if parent is None else weakref.ref(parent)
        self.__children = [] if children is None else [self.parse_child(child) for child in children]
        # A node's own position is only known once a parent inserts it.
        self.__parent_index = None
        for index, child in enumerate(self.__children):
            child.__parent_index = index

    def parse_child(self: _NodeType, child) -> _NodeType:
        """Convert ``child`` into a child node of ``self``.

        :param child: either a ``dict`` of keyword arguments for ``type(self)`` (a new node
            with ``parent=self`` is constructed) or a node of exactly ``type(self)`` whose
            parent reference is redirected to ``self``
        :return: the child node. Its ``parent_index`` is not set here; callers do that.
        :raises TypeError: for any other child type
        """
        if isinstance(child, dict):
            return type(self)(parent=self, **child)
        elif type(child) is type(self):
            child.__parent = weakref.ref(self)
            return child
        else:
            raise TypeError('Invalid child type', type(child))

    def is_leaf(self) -> bool:
        """True if this node has no children."""
        return len(self.__children) == 0

    def depth(self) -> int:
        """Number of edges on the longest downward path to a leaf (0 for a leaf)."""
        if self.is_leaf():
            return 0
        return 1 + max(child.depth() for child in self.__children)

    def is_balanced(self) -> bool:
        """True if all children have the same depth and are themselves balanced."""
        if self.is_leaf():
            return True
        # Computed once instead of once per child.
        expected_depth = self.__children[0].depth()
        return all(child.depth() == expected_depth and child.is_balanced() for child in self.__children)

    def __iter__(self: _NodeType) -> Iterable[_NodeType]:
        return iter(self.__children)

    def __reversed__(self: _NodeType) -> Iterable[_NodeType]:
        return reversed(self.__children)

    def __setitem__(self: _NodeType, idx: Union[int, slice], value: Union[_NodeType, Iterable[_NodeType]]):
        """Replace a single child (integer ``idx``) or a slice of children.

        Assigned children are re-parented to this node via :meth:`parse_child`, and every
        child whose position may have changed gets its ``parent_index`` refreshed. Children
        removed by the assignment are not detached; their parent reference is left as is.

        :raises TypeError: if a single node is assigned to a slice or a child has an invalid type
        """
        if isinstance(idx, slice):
            if isinstance(value, Node):
                raise TypeError('can only assign an iterable (Loop does not count)')
            value = tuple(self.parse_child(child) for child in value)

            indices = range(*idx.indices(len(self.__children)))
            self.__children[idx] = value

            if len(value) != len(indices):
                # Lists only allow length-changing assignment for step-1 slices, so every
                # child from the slice start to the (new) end may have moved.
                for index in range(indices.start, len(self.__children)):
                    self.__children[index].__parent_index = index
            else:
                # Same length: only the assigned positions changed. Iterating `indices`
                # directly also covers negative and non-unit steps.
                for index in indices:
                    self.__children[index].__parent_index = index
        else:
            value = self.parse_child(value)
            self.__children[idx] = value
            # Normalise negative indices (e.g. -1) to the actual non-negative position.
            value.__parent_index = range(len(self.__children))[idx]

    def __getitem__(self: _NodeType, *args, **kwargs) -> Union[_NodeType, List[_NodeType]]:
        return self.__children.__getitem__(*args, **kwargs)

    def __len__(self) -> int:
        return len(self.__children)

    def get_depth_first_iterator(self: _NodeType) -> Generator[_NodeType, None, None]:
        """Iterate the subtree in post-order: all children (left to right) before their parent."""
        stack = [(self, self.__children)]
        while stack:
            node, children = stack.pop()
            if children:
                # Re-push the node with no children so it is yielded after its subtree.
                stack.append((node, None))
                stack.extend((child, child.__children) for child in reversed(children))
            else:
                yield node

    def get_breadth_first_iterator(self: _NodeType) -> Generator[_NodeType, None, None]:
        """Iterate the subtree level by level, starting with ``self``."""
        queue = deque([self])
        while queue:
            elem = queue.popleft()
            queue.extend(elem)
            yield elem

    def assert_tree_integrity(self) -> None:
        """Check parent/child links of this subtree; does nothing unless ``debug`` is True.

        :raises AssertionError: on a missing or inconsistent parent/child reference
        """
        if not self.debug:
            return
        for child in self.__children:
            if child.parent is not self:
                raise AssertionError('Child is missing parent reference')
            child.assert_tree_integrity()
        parent = self.parent
        # Explicit None check: a Node with zero children is falsy because of __len__.
        if parent is not None:
            if self.__parent_index not in range(len(parent)):
                raise AssertionError('Out of range parent index')
            if parent[self.__parent_index] is not self:
                if any(sibling is self for sibling in parent.__children):
                    raise AssertionError('Wrong parent index')
                raise AssertionError('Parent is missing child reference')

    @property
    def children(self: _NodeType) -> Sequence[_NodeType]:
        """
        :return: the internal children list wrapped in a SequenceProxy (not a copy)
        """
        return SequenceProxy(self.__children)

    @property
    def parent(self: _NodeType) -> Union[None, _NodeType]:
        """The parent node, or ``None`` for a root (or if the parent was garbage collected)."""
        return self.__parent()

    @property
    def parent_index(self) -> Optional[int]:
        """Position of this node in its parent's children, or ``None`` if never inserted."""
        return self.__parent_index

    def get_root(self: _NodeType) -> _NodeType:
        """Follow parent references up to the node that has no parent."""
        parent = self.parent
        if parent is not None:
            return parent.get_root()
        return self

    def get_location(self) -> Tuple[int, ...]:
        """Tuple of child indices leading from the root to this node (empty for the root)."""
        self.assert_tree_integrity()
        parent = self.parent
        if parent is not None:
            return (*parent.get_location(), self.__parent_index)
        return tuple()

    def locate(self: _NodeType, location: Tuple[int, ...]) -> _NodeType:
        """Inverse of :meth:`get_location`: descend through the given child indices."""
        if location:
            return self.__children[location[0]].locate(location[1:])
        return self

    def _reverse_children(self):
        """Reverse children in-place and refresh their cached parent indices."""
        self.__children.reverse()
        for index, child in enumerate(self.__children):
            child.__parent_index = index
def is_tree_circular(root: 'Node') -> Optional[Tuple[List[int], int]]:
    """Check whether descending through children from ``root`` ever reaches an ancestor again.

    Iterative depth-first search that remembers, for every pending node, the ids of the nodes
    on the path from ``root`` to it. A node that is merely reachable along several paths
    (shared but not cyclic) is not reported as circular.

    :param root: start of the traversal; iterating a node must yield its children
        (e.g. :class:`Node`)
    :return: ``None`` if no cycle exists. Otherwise ``(path, child_id)``, where ``path`` lists
        the ``id()`` of each node from ``root`` to the node whose child closes the cycle, and
        ``child_id`` is the ``id()`` of that child (which is contained in ``path``).
    """
    # Each entry: (node, tuple of ids of its ancestors from root, excluding the node itself).
    # Every branch gets its own path; sharing one mutable stack would mix sibling branches.
    nodes_to_visit = deque([(root, ())])
    while nodes_to_visit:
        node, ancestors = nodes_to_visit.pop()
        path = ancestors + (id(node),)
        on_path = set(path)
        for child in node:
            if id(child) in on_path:
                return list(path), id(child)
            nodes_to_visit.append((child, path))
    return None