Source code for qupulse._program._loop

from typing import Union, Dict, Iterable, Tuple, cast, List, Optional, Generator, Mapping
from collections import defaultdict
from enum import Enum
import warnings

import numpy as np


from qupulse._program.waveforms import Waveform
from qupulse._program.volatile import VolatileRepetitionCount, VolatileProperty

from qupulse.utils import is_integer
from qupulse.utils.types import TimeType, MeasurementWindow
from qupulse.utils.tree import Node, is_tree_circular

from qupulse._program.waveforms import SequenceWaveform, RepetitionWaveform

# Names exported by ``from qupulse._program._loop import *``. Other public names defined in this module (e.g.
# ``to_waveform``, ``VolatileModificationWarning``, ``DroppedMeasurementWarning``) have to be imported explicitly.
__all__ = ['Loop', 'make_compatible', 'MakeCompatibleWarning']


class Loop(Node):
    """This class represents an initialized (sub-)program as a tree.

    Each Loop of a valid program has a repetition count and either a waveform or a sequence of loops as children. A
    Loop can have associated measurements which are also repeated: they are stored as (name, begin, length) tuples
    relative to the start of one iteration of the loop body and are repeated together with it
    (see :meth:`get_measurement_windows`).
    """
    MAX_REPR_SIZE = 2000
    __slots__ = ('_waveform', '_measurements', '_repetition_definition', '_cached_body_duration')

    def __init__(self,
                 parent: Union['Loop', None] = None,
                 children: Iterable['Loop'] = (),
                 waveform: Optional[Waveform] = None,
                 measurements: Optional[List[MeasurementWindow]] = None,
                 repetition_count: Union[int, VolatileRepetitionCount] = 1):
        """Initialize a new loop

        Args:
            parent: Forwarded to Node.__init__
            children: Forwarded to Node.__init__
            waveform: "Payload"
            measurements: Associated measurements as (name, begin, length) tuples relative to the start of one
                iteration of the body
            repetition_count: The children / waveform are repeated this often
        """
        super().__init__(parent=parent, children=children)

        self._waveform = waveform
        self._measurements = measurements
        self._repetition_definition = repetition_count
        # lazily filled by `body_duration`; updated or discarded by `_invalidate_duration`
        self._cached_body_duration = None

        assert isinstance(repetition_count, VolatileRepetitionCount) or is_integer(repetition_count)
        assert isinstance(waveform, (type(None), Waveform))

    def __eq__(self, other: object) -> bool:
        """Structural equality: same type, repetition definition, waveform, measurements and (recursively) children.

        ``None`` and an empty measurement list are considered equal.
        """
        if type(self) == type(other):
            return (self._repetition_definition == other._repetition_definition and
                    self.waveform == other.waveform and
                    (self._measurements or None) == (other._measurements or None) and
                    len(self) == len(other) and
                    all(self_child == other_child for self_child, other_child in zip(self, other)))
        else:
            return NotImplemented

    def append_child(self, loop: Optional['Loop'] = None, **kwargs) -> None:
        """Append a child to this loop. Either an existing Loop object or a new one created from kwargs

        Args:
            loop: loop to append
            **kwargs: Child is constructed with these kwargs

        Raises:
            ValueError: if called with loop and kwargs
        """
        if loop is not None:
            if kwargs:
                raise ValueError("Cannot pass a Loop object and Loop constructor arguments at the same time in "
                                 "append_child")
            arg = (loop,)
        else:
            arg = (kwargs,)
        # bypass our own __setitem__ so the cached duration is updated incrementally instead of being discarded
        super().__setitem__(slice(len(self), len(self)), arg)
        self._invalidate_duration(body_duration_increment=self[-1].duration)

    def _invalidate_duration(self, body_duration_increment=None):
        """Update the cached body duration of this loop and of all its ancestors.

        Args:
            body_duration_increment: If given, a cached body duration is increased by this amount (the ancestors' by
                the amount multiplied with the respective repetition count). If None, the caches are discarded and
                recomputed on demand.
        """
        if self._cached_body_duration is not None:
            if body_duration_increment is not None:
                self._cached_body_duration += body_duration_increment
            else:
                self._cached_body_duration = None
        if self.parent:
            if body_duration_increment is not None:
                self.parent._invalidate_duration(body_duration_increment=body_duration_increment*self.repetition_count)
            else:
                self.parent._invalidate_duration()

    def _invalidate_parent_duration(self) -> None:
        """Discard the cached durations of all ancestors, e.g. after the repetition of this loop changed.

        The body duration of this loop does not depend on its own repetition count, so its own cache is kept.
        """
        parent = self.parent
        if parent is not None:
            parent._invalidate_duration()

    def add_measurements(self, measurements: Iterable[MeasurementWindow]):
        """Add measurements offset by the current body duration i.e. to the END of the current loop

        Args:
            measurements: Measurements to add
        """
        body_duration = float(self.body_duration)
        if body_duration != 0:
            measurements = ((mw_name, begin + body_duration, length)
                            for mw_name, begin, length in measurements)

        if self._measurements is None:
            self._measurements = list(measurements)
        else:
            self._measurements.extend(measurements)

    @property
    def waveform(self) -> Waveform:
        return self._waveform

    @waveform.setter
    def waveform(self, val) -> None:
        self._waveform = val
        self._invalidate_duration()

    @property
    def body_duration(self) -> TimeType:
        """Duration of a single iteration of this loop (cached).

        For a leaf, this is the duration of its waveform (zero if there is none). Otherwise, it is the sum of the
        children's durations.
        """
        if self._cached_body_duration is None:
            if self.is_leaf():
                if self.waveform:
                    self._cached_body_duration = self.waveform.duration
                else:
                    self._cached_body_duration = TimeType.from_fraction(0, 1)
            else:
                self._cached_body_duration = sum(child.duration for child in self)
        return self._cached_body_duration

    @property
    def duration(self) -> TimeType:
        """Total duration, i.e. body duration times repetition count."""
        return self.body_duration * TimeType.from_fraction(self.repetition_count, 1)

    @property
    def volatile_repetition(self) -> Optional[VolatileProperty]:
        """The ``volatile_property`` of the repetition definition, or None for a plain integer repetition."""
        return getattr(self._repetition_definition, 'volatile_property', None)

    @property
    def repetition_definition(self) -> Union[int, VolatileRepetitionCount]:
        return self._repetition_definition

    @repetition_definition.setter
    def repetition_definition(self, new_definition: Union[int, VolatileRepetitionCount]):
        self._repetition_definition = new_definition
        # the duration of this loop (not its body duration) may have changed -> ancestors have to recompute
        self._invalidate_parent_duration()

    @property
    def repetition_count(self) -> int:
        return int(self._repetition_definition)

    @repetition_count.setter
    def repetition_count(self, val: int) -> None:
        assert isinstance(val, (int, float))
        new_repetition = int(val)
        if abs(new_repetition - val) > 1e-10:
            raise ValueError('Repetition count was not an integer')
        self._repetition_definition = new_repetition
        # the duration of this loop (not its body duration) may have changed -> ancestors have to recompute
        self._invalidate_parent_duration()

    def unroll(self) -> None:
        """Replace this loop in its parent by ``repetition_count`` consecutive copies of its children.

        Raises:
            RuntimeError: if this loop is a leaf or has no parent
        """
        if self.is_leaf():
            raise RuntimeError('Leaves cannot be unrolled')
        parent = self.parent
        if parent is None:
            raise RuntimeError('Loops without parent cannot be unrolled')
        if self.volatile_repetition:
            warnings.warn("Unrolling a Loop with volatile repetition count", VolatileModificationWarning)

        # NOTE(review): measurements attached to this loop itself are not transferred to the inserted copies and are
        # therefore lost -- confirm that callers never unroll loops that carry own measurements.
        i = self.parent_index
        parent[i:i+1] = (child.copy_tree_structure(new_parent=parent)
                         for _ in range(self.repetition_count)
                         for child in self)
        parent.assert_tree_integrity()

    def __setitem__(self, idx, value):
        """Set children like ``Node.__setitem__`` and discard the cached durations."""
        super().__setitem__(idx, value)
        self._invalidate_duration()

    def unroll_children(self) -> None:
        """Replace the children by ``repetition_count`` copies of them and set the repetition count to one.

        Own measurements are replicated into every former iteration, so :meth:`get_measurement_windows` yields the
        same windows as before.
        """
        if self.volatile_repetition:
            warnings.warn("Unrolling a Loop with volatile repetition count", VolatileModificationWarning)
        repetition_count = self.repetition_count

        if self._measurements and repetition_count != 1:
            # measurements are relative to the start of one iteration: shift a copy into every former iteration
            body_duration = float(self.body_duration)
            self._measurements = [(mw_name, begin + iteration * body_duration, length)
                                  for iteration in range(repetition_count)
                                  for mw_name, begin, length in self._measurements]

        old_children = self.children
        self[:] = (child.copy_tree_structure() for _ in range(repetition_count) for child in old_children)
        self.repetition_count = 1
        self.assert_tree_integrity()

    def encapsulate(self) -> None:
        """Add a nesting level by moving self to its children."""
        self[:] = [Loop(children=self,
                        repetition_count=self._repetition_definition,
                        waveform=self._waveform,
                        measurements=self._measurements)]
        self.repetition_count = 1
        self._waveform = None
        self._measurements = None
        self.assert_tree_integrity()

    def _get_repr(self, first_prefix, other_prefixes) -> Generator[str, None, None]:
        """Yield the lines of the tree representation; children are prefixed with an arrow and indentation."""
        if self.is_leaf():
            yield '%sEXEC %r %d times' % (first_prefix, self._waveform, self.repetition_count)
        else:
            yield '%sLOOP %d times:' % (first_prefix, self.repetition_count)

            for elem in self:
                yield from cast(Loop, elem)._get_repr(other_prefixes + ' ->', other_prefixes + ' ')

    def __repr__(self) -> str:
        is_circular = is_tree_circular(self)
        if is_circular:
            return '{}: Circ {}'.format(id(self), is_circular)

        str_len = 0
        repr_list = []
        for sub_repr in self._get_repr('', ''):
            str_len += len(sub_repr)

            # truncate very large trees to keep the repr readable
            if self.MAX_REPR_SIZE and str_len > self.MAX_REPR_SIZE:
                repr_list.append('...')
                break
            else:
                repr_list.append(sub_repr)
        return '\n'.join(repr_list)

    def copy_tree_structure(self, new_parent: Union['Loop', bool]=False) -> 'Loop':
        """Return a copy of the whole tree. Waveforms are shared; measurement lists are copied.

        Args:
            new_parent: Parent of the copy. ``False`` (default) gives the copy the same parent as this loop.
        """
        return type(self)(parent=self.parent if new_parent is False else new_parent,
                          waveform=self._waveform,
                          repetition_count=self._repetition_definition,
                          measurements=None if self._measurements is None else list(self._measurements),
                          children=(child.copy_tree_structure() for child in self))

    def _get_measurement_windows(self) -> Mapping[str, np.ndarray]:
        """Private implementation of get_measurement_windows with a slightly different data format for easier tiling.

        Returns:
             A dictionary (measurement_name -> array) with begin == array[:, 0] and length == array[:, 1]
        """
        temp_meas_windows = defaultdict(list)
        if self._measurements:
            for (mw_name, begin, length) in self._measurements:
                temp_meas_windows[mw_name].append((begin, length))

            for mw_name, begin_length_list in temp_meas_windows.items():
                temp_meas_windows[mw_name] = [np.asarray(begin_length_list, dtype=float)]

        # calculate duration together with meas windows in the same iteration
        if self.is_leaf():
            body_duration = float(self.body_duration)
        else:
            offset = TimeType(0)
            for child in self:
                for mw_name, begins_length_array in child._get_measurement_windows().items():
                    begins_length_array[:, 0] += float(offset)
                    temp_meas_windows[mw_name].append(begins_length_array)
                offset += child.duration

            body_duration = float(offset)

        # this gives us regular dict behaviour of the returned object
        temp_meas_windows.default_factory = None

        # repeat and add repetition based offset
        for mw_name, begin_length_list in temp_meas_windows.items():
            temp_begin_length_array = np.concatenate(begin_length_list)

            begin_length_array = np.tile(temp_begin_length_array, (self.repetition_count, 1))

            # reshape returns a view, so the in-place addition below modifies begin_length_array
            shaped_begin_length_array = np.reshape(begin_length_array, (self.repetition_count, -1, 2))

            shaped_begin_length_array[:, :, 0] += (np.arange(self.repetition_count) * body_duration)[:, np.newaxis]

            temp_meas_windows[mw_name] = begin_length_array

        # the cast is here because static type analysis struggles to detect that we replace _all_ values by ndarray in
        # the previous loop
        return cast(Mapping[str, np.ndarray], temp_meas_windows)

    def get_measurement_windows(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """Iterates over all children and collect the begin and length arrays of each measurement window.

        Returns:
            A dictionary (measurement_name -> (begin, length)) with begin and length being :class:`numpy.ndarray`
        """
        return {mw_name: (begin_length_list[:, 0], begin_length_list[:, 1])
                for mw_name, begin_length_list in self._get_measurement_windows().items()}

    def split_one_child(self, child_index=None) -> None:
        """Take the last child that has a repetition count larger one, decrease its repetition count and insert a copy
        with repetition count one after it.

        Without ``child_index``, children with a non-volatile repetition count are preferred. The last child with a
        volatile repetition count is only chosen if there is no other candidate.

        Args:
            child_index: Index of the child to split. If None, a child is chosen as described above.

        Raises:
            ValueError: if the child at ``child_index`` has a repetition count smaller than 2
            RuntimeError: if no ``child_index`` was given and no child has a repetition count larger one
        """
        if child_index is not None:
            if self[child_index].repetition_count < 2:
                raise ValueError('Cannot split child {} as the repetition count is not larger 1'.format(child_index))

        else:
            # we cannot reverse enumerate
            n_child = len(self) - 1
            for reverse_idx, child in enumerate(reversed(self)):
                if child.repetition_count > 1:
                    forward_idx = n_child - reverse_idx
                    if not child.volatile_repetition:
                        child_index = forward_idx
                        break
                    elif child_index is None:
                        child_index = forward_idx
            else:
                if child_index is None:
                    raise RuntimeError('There is no child with repetition count > 1')

        if self[child_index].volatile_repetition:
            warnings.warn("Splitting a child with volatile repetition count", VolatileModificationWarning)

        new_child = self[child_index].copy_tree_structure()
        new_child.repetition_count = 1

        self[child_index].repetition_count -= 1

        self[child_index+1:child_index+1] = (new_child,)
        self.assert_tree_integrity()

    def flatten_and_balance(self, depth: int) -> None:
        """Modifies the program so all tree branches have the same depth.

        Args:
            depth: Target depth of the program
        """
        i = 0
        while i < len(self):
            # only used by type checker
            sub_program = cast(Loop, self[i])

            if sub_program.depth() < depth - 1:
                # increase nesting because the subprogram is not deep enough
                sub_program.encapsulate()

            elif not sub_program.is_balanced():
                # balance the sub program. We revisit it in the next iteration (no change of i )
                # because it might modify self. While writing this comment I am not sure this is true. 14.01.2020 Simon
                sub_program.flatten_and_balance(depth - 1)

            elif sub_program.depth() == depth - 1:
                # subprogram is balanced with the correct depth
                i += 1

            elif sub_program._has_single_child_that_can_be_merged():
                # subprogram is balanced but too deep and has no measurements -> we can "lift" the sub-sub-program
                # TODO: There was a len(sub_sub_program) == 1 check here that I cannot explain
                sub_program._merge_single_child()

            elif not sub_program.is_leaf():
                # subprogram is balanced but too deep
                sub_program.unroll()

            else:
                # we land in this case if the function gets called with depth == 0 and the current subprogram is a leaf
                i += 1

    def _has_single_child_that_can_be_merged(self) -> bool:
        """True if this loop has exactly one child that `_merge_single_child` may lift into this loop.

        If this loop has own measurements, the child must have a fixed (non-volatile) repetition count of one.
        """
        if len(self) == 1:
            child = cast(Loop, self[0])
            return not self._measurements or (child.repetition_count == 1 and not child.volatile_repetition)
        else:
            return False

    def _merge_single_child(self):
        """Lift the single child to current level. Requires _has_single_child_that_can_be_merged to be true"""
        assert len(self) == 1, "bug: _merge_single_child called on loop with len != 1"
        child = cast(Loop, self[0])

        # if the child has a fixed repetition count of 1 the measurements can be merged
        mergable_measurements = child.repetition_count == 1 and not child.volatile_repetition

        assert not self._measurements or mergable_measurements, "bug: _merge_single_child called on loop with measurements"
        assert not self._waveform, "bug: _merge_single_child called on loop with children and waveform"

        measurements = child._measurements
        if self._measurements:
            if measurements:
                measurements.extend(self._measurements)
            else:
                measurements = self._measurements

        if not self.volatile_repetition and not child.volatile_repetition:
            # simple integer multiplication
            repetition_definition = self.repetition_count * child.repetition_count
        elif not self.volatile_repetition:
            repetition_definition = child._repetition_definition * self.repetition_count
        elif not child.volatile_repetition:
            repetition_definition = self._repetition_definition * child.repetition_count
        else:
            # create a new expression that depends on both
            expression = 'parent_repetition_count * child_repetition_count'
            repetition_definition = VolatileRepetitionCount.operation(
                expression=expression,
                parent_repetition_count=self._repetition_definition,
                child_repetition_count=child._repetition_definition)

        self[:] = iter(child)
        self._waveform = child._waveform
        self._repetition_definition = repetition_definition
        self._measurements = measurements
        self._invalidate_duration()
        return True

    def cleanup(self, actions=('remove_empty_loops', 'merge_single_child')):
        """Apply the specified actions to cleanup the Loop.

        remove_empty_loops: Remove loops with no children and no waveform (a DroppedMeasurementWarning is issued if
            such a loop carried measurements)
        merge_single_child: Lift a single child into this loop if `_has_single_child_that_can_be_merged` allows it
            (see `_merge_single_child`)

        Warnings:
            DroppedMeasurementWarning: Likely a bug in qupulse. TODO: investigate whether there are usecases
        """
        if 'remove_empty_loops' in actions:
            new_children = []
            for child in self:
                child = cast(Loop, child)
                if child.is_leaf():
                    if child.waveform is None:
                        if child._measurements:
                            warnings.warn("Dropping measurement since there is no waveform attached",
                                          category=DroppedMeasurementWarning)
                    else:
                        new_children.append(child)

                else:
                    child.cleanup(actions)
                    if child.waveform or not child.is_leaf():
                        new_children.append(child)

                    elif child._measurements:
                        warnings.warn("Dropping measurement since there is no waveform in children",
                                      category=DroppedMeasurementWarning)

            if len(self) != len(new_children):
                self[:] = new_children

        else:
            # only do the recursive call
            for child in self:
                child.cleanup(actions)

        if 'merge_single_child' in actions and self._has_single_child_that_can_be_merged():
            self._merge_single_child()

    def get_duration_structure(self) -> Tuple[int, Union[TimeType, tuple]]:
        """Return ``(repetition_count, structure)``.

        ``structure`` is the body duration for a leaf, or the tuple of the children's duration structures otherwise.
        """
        if self.is_leaf():
            # equals waveform.duration for leaves with a waveform and zero for leaves without one
            return self.repetition_count, self.body_duration
        else:
            return self.repetition_count, tuple(child.get_duration_structure() for child in self)
class ChannelSplit(Exception): def __init__(self, channel_sets): self.channel_sets = channel_sets def to_waveform(program: Loop) -> Waveform: if program.is_leaf(): if program.repetition_count == 1: return program.waveform else: return RepetitionWaveform(program.waveform, program.repetition_count) else: if len(program) == 1: sequenced_waveform = to_waveform(cast(Loop, program[0])) else: sequenced_waveform = SequenceWaveform([to_waveform(cast(Loop, sub_program)) for sub_program in program]) if program.repetition_count > 1: return RepetitionWaveform(sequenced_waveform, program.repetition_count) else: return sequenced_waveform class _CompatibilityLevel(Enum): compatible = 0 action_required = 1 incompatible_too_short = 2 incompatible_fraction = 3 incompatible_quantum = 4 def is_incompatible(self) -> bool: return self in (self.incompatible_fraction, self.incompatible_quantum, self.incompatible_too_short) def _is_compatible(program: Loop, min_len: int, quantum: int, sample_rate: TimeType) -> _CompatibilityLevel: """ check whether program loop is compatible with awg requirements possible reasons for incompatibility: program shorter than minimum length program duration not an integer program duration not a multiple of quantum """ program_duration_in_samples = program.duration * sample_rate if program_duration_in_samples.denominator != 1: return _CompatibilityLevel.incompatible_fraction if program_duration_in_samples < min_len: return _CompatibilityLevel.incompatible_too_short if program_duration_in_samples % quantum > 0: return _CompatibilityLevel.incompatible_quantum if program.is_leaf(): waveform_duration_in_samples = program.body_duration * sample_rate if waveform_duration_in_samples < min_len or (waveform_duration_in_samples / quantum).denominator != 1: if program.volatile_repetition: warnings.warn("_is_compatible requires an action which drops volatility.", category=VolatileModificationWarning) return _CompatibilityLevel.action_required else: return 
_CompatibilityLevel.compatible else: if all(_is_compatible(cast(Loop, sub_program), min_len, quantum, sample_rate) == _CompatibilityLevel.compatible for sub_program in program): return _CompatibilityLevel.compatible else: if program.volatile_repetition: warnings.warn("_is_compatible requires an action which drops volatility.", category=VolatileModificationWarning) return _CompatibilityLevel.action_required def _make_compatible(program: Loop, min_len: int, quantum: int, sample_rate: TimeType) -> None: if program.is_leaf(): program.waveform = to_waveform(program.copy_tree_structure()) program.repetition_count = 1 else: comp_levels = [_is_compatible(cast(Loop, sub_program), min_len, quantum, sample_rate) for sub_program in program] if any(comp_level.is_incompatible() for comp_level in comp_levels): single_run = program.duration * sample_rate / program.repetition_count if (single_run / quantum).denominator == 1 and single_run >= min_len: # it is enough to concatenate all children new_repetition_definition = program.repetition_definition program.repetition_count = 1 else: # we need to concatenate all children and unroll new_repetition_definition = 1 program.waveform = to_waveform(program.copy_tree_structure()) program.repetition_definition = new_repetition_definition program[:] = [] return else: for sub_program, comp_level in zip(program, comp_levels): if comp_level == _CompatibilityLevel.action_required: _make_compatible(sub_program, min_len, quantum, sample_rate)
def make_compatible(program: Loop, minimal_waveform_length: int, waveform_quantum: int, sample_rate: TimeType):
    """Check ``program`` against the AWG waveform constraints and modify it in place if that is necessary and possible.

    Args:
        program: Program tree to check and, if required, modify in place
        minimal_waveform_length: Minimal number of samples of a waveform
        waveform_quantum: Waveform lengths in samples have to be a multiple of this
        sample_rate: Used to convert durations into sample counts

    Raises:
        ValueError: if the program as a whole has a non-integer sample count, is too short, or is not a multiple of
            the quantum

    Warns:
        MakeCompatibleWarning: if waveforms have to be concatenated
    """
    level = _is_compatible(program, min_len=minimal_waveform_length, quantum=waveform_quantum,
                           sample_rate=sample_rate)

    if level == _CompatibilityLevel.action_required:
        warnings.warn("qupulse will now concatenate waveforms to make the pulse/program compatible with the chosen AWG."
                      " This might take some time. If you need this pulse more often it makes sense to write it in a "
                      "way which is more AWG friendly.", MakeCompatibleWarning)
        _make_compatible(program, min_len=minimal_waveform_length, quantum=waveform_quantum,
                         sample_rate=sample_rate)
        return

    if level != _CompatibilityLevel.compatible:
        duration_in_samples = program.duration * sample_rate
        if level == _CompatibilityLevel.incompatible_fraction:
            raise ValueError('The program duration in samples {} is not an integer'.format(duration_in_samples))
        if level == _CompatibilityLevel.incompatible_too_short:
            raise ValueError('The program is too short to be a valid waveform. \n'
                             ' program duration in samples: {} \n'
                             ' minimal length: {}'.format(duration_in_samples, minimal_waveform_length))
        if level == _CompatibilityLevel.incompatible_quantum:
            raise ValueError('The program duration in samples {} '
                             'is not a multiple of quantum {}'.format(duration_in_samples, waveform_quantum))

    # every other level has been handled above
    assert level == _CompatibilityLevel.compatible
class MakeCompatibleWarning(ResourceWarning):
    """Emitted by :func:`make_compatible` before waveforms are concatenated to satisfy the AWG constraints."""
class VolatileModificationWarning(RuntimeWarning):
    """This warning is emitted if the volatile part of a program gets modified. This might imply that the volatile
    parameter cannot be changed anymore."""


class DroppedMeasurementWarning(RuntimeWarning):
    """This warning is emitted if a measurement was dropped because there was no waveform attached."""