Source code for qupulse.program.waveforms

"""This module contains all waveform classes

Classes:
    - Waveform: An instantiated pulse which can be sampled to a raw voltage value array.
"""

import itertools
import operator
import warnings
from abc import ABCMeta, abstractmethod
from numbers import Real
from typing import (
    AbstractSet, Any, FrozenSet, Iterable, Mapping, NamedTuple, Sequence, Set,
    Tuple, Union, cast, Optional, List, Hashable)
from weakref import WeakValueDictionary, ref

import numpy as np

from qupulse import ChannelID
from qupulse.program.transformation import Transformation
from qupulse.utils import checked_int_cast, isclose
from qupulse.utils.types import TimeType, time_from_float
from qupulse.utils.performance import is_monotonic
from qupulse.comparable import Comparable
from qupulse.expressions import ExpressionScalar
from qupulse.pulses.interpolation import InterpolationStrategy
from qupulse.utils import checked_int_cast, isclose
from qupulse.utils.types import TimeType, time_from_float, FrozenDict
from qupulse.program.transformation import Transformation
from qupulse.utils import pairwise

class ConstantFunctionPulseTemplateWarning(UserWarning):
    """Warning category signalling that a constant waveform is constructed from a FunctionPulseTemplate."""
__all__ = [
    "Waveform",
    "TableWaveform",
    "TableWaveformEntry",
    "FunctionWaveform",
    "SequenceWaveform",
    "MultiChannelWaveform",
    "RepetitionWaveform",
    "TransformingWaveform",
    "ArithmeticWaveform",
    "ConstantFunctionPulseTemplateWarning",
    "ConstantWaveform",
]

# error margin in pulse template to waveform conversion
PULSE_TO_WAVEFORM_ERROR = None

# these are private because there probably will be changes here
# usage: pre_allocated = _ALLOCATION_FUNCTION(sample_times, **_ALLOCATION_FUNCTION_KWARGS)
_ALLOCATION_FUNCTION = np.full_like
_ALLOCATION_FUNCTION_KWARGS = {'fill_value': np.nan, 'dtype': float}


def _to_time_type(duration: Real) -> TimeType:
    """Return ``duration`` as a :class:`TimeType`.

    Values that already are ``TimeType`` instances are passed through unchanged. Everything else is converted
    to ``float`` and handed to ``time_from_float`` with the module level ``PULSE_TO_WAVEFORM_ERROR`` as
    ``absolute_error`` (looked up at call time).
    """
    if not isinstance(duration, TimeType):
        duration = time_from_float(float(duration), absolute_error=PULSE_TO_WAVEFORM_ERROR)
    return duration
class Waveform(Comparable, metaclass=ABCMeta):
    """Represents an instantiated PulseTemplate which can be sampled to retrieve arrays of voltage
    values for the hardware."""

    # Shared cache of sampled arrays keyed by the hash of their raw bytes. Values are only weakly referenced.
    __sampled_cache = WeakValueDictionary()

    __slots__ = ('_duration',)

    def __init__(self, duration: TimeType):
        """Store the duration of the waveform.

        Args:
            duration: The duration of the waveform in time units.
        """
        self._duration = duration

    @property
    def duration(self) -> TimeType:
        """The duration of the waveform in time units."""
        return self._duration

    @abstractmethod
    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Sample the waveform at given sample times.

        The unsafe means that there are no sanity checks performed. The provided sample times are assumed to be
        monotonously increasing and lie in the range of [0, waveform.duration]

        Args:
            channel: The channel to sample.
            sample_times: Times at which this Waveform will be sampled.
            output_array: Has to be either None or an array of the same size and type as sample_times. If
                not None, the sampled values will be written here and this array will be returned
        Result:
            The sampled values of this Waveform at the provided sample times. Has the same number of
            elements as sample_times.
        """

    def get_sampled(self,
                    channel: ChannelID,
                    sample_times: np.ndarray,
                    output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """A wrapper to the unsafe_sample method which caches the result. This method enforces the constrains
        unsafe_sample expects and caches the result to save memory.

        Args:
            channel: The channel to sample. Must be one of ``defined_channels``.
            sample_times: Times at which this Waveform will be sampled.
            output_array: Has to be either None or an array of the same size and type as sample_times. If an array is
                given, the sampled values will be written into the given array and it will be returned. Otherwise, a
                new array will be created and cached to save memory.

        Raises:
            ValueError: If the sample times are not monotonously increasing, lie outside of [0, duration] or if
                ``output_array`` has a different length than ``sample_times``.
            KeyError: If ``channel`` is not defined in this waveform.

        Result:
            The sampled values of this Waveform at the provided sample times. Is `output_array` if provided
        """
        if len(sample_times) == 0:
            if output_array is None:
                return np.zeros_like(sample_times, dtype=float)
            elif len(output_array) == len(sample_times):
                return output_array
            else:
                raise ValueError('Output array length and sample time length are different')

        if not is_monotonic(sample_times):
            raise ValueError('The sample times are not monotonously increasing')
        if sample_times[0] < 0 or sample_times[-1] > float(self.duration):
            raise ValueError(f'The sample times [{sample_times[0]}, ..., {sample_times[-1]}] are not in the range'
                             f' [0, duration={float(self.duration)}]')
        if channel not in self.defined_channels:
            raise KeyError('Channel not defined in this waveform: {}'.format(channel))

        # The length contract holds for constant and non-constant waveforms alike.
        if output_array is not None and len(output_array) != len(sample_times):
            raise ValueError('Output array length and sample time length are different')

        constant_value = self.constant_value(channel)
        if constant_value is not None:
            # no sampling required for constant channels
            if output_array is None:
                output_array = np.full_like(sample_times, fill_value=constant_value, dtype=float)
            else:
                output_array[:] = constant_value
            return output_array

        if output_array is not None:
            # use the user provided memory
            return self.unsafe_sample(channel=channel, sample_times=sample_times, output_array=output_array)

        # cache the result to save memory
        result = self.unsafe_sample(channel, sample_times)
        result.flags.writeable = False
        raw = result.tobytes()
        key = hash(raw)

        cached = self.__sampled_cache.get(key)
        if cached is None:
            self.__sampled_cache[key] = result
            return result

        # The key is only a hash: make sure the cached array really holds the same data before handing it out.
        if cached.dtype == result.dtype and cached.shape == result.shape and cached.tobytes() == raw:
            return cached
        return result

    @property
    @abstractmethod
    def defined_channels(self) -> AbstractSet[ChannelID]:
        """The channels this waveform should played on. Use :meth:`get_subset_for_channels` to get a waveform
        for a subset of these."""

    @abstractmethod
    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        """Unsafe version of :meth:`get_subset_for_channels`, i.e. `channels` is not validated."""

    def get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        """Get a waveform that only describes the channels contained in `channels`.

        Args:
            channels: A channel set the return value should confine to.

        Raises:
            KeyError: If `channels` is not a subset of the waveform's defined channels.

        Returns:
            A waveform with waveform.defined_channels == `channels`
        """
        if not channels <= self.defined_channels:
            raise KeyError('Channels not defined on waveform: {}'.format(channels))
        if channels == self.defined_channels:
            return self
        return self.unsafe_get_subset_for_channels(channels=channels)

    def is_constant(self) -> bool:
        """Convenience function to check if all channels are constant. The result is equal to
        `all(waveform.constant_value(ch) is not None for ch in waveform.defined_channels)` but might be more
        performant.

        Returns:
            True if all channels have constant values.
        """
        return self.constant_value_dict() is not None

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        """Return a mapping from each defined channel to its constant value.

        Returns:
            None if :meth:`constant_value` returns None for at least one defined channel. The mapping otherwise.
        """
        result = {ch: self.constant_value(ch) for ch in self.defined_channels}
        if None in result.values():
            return None
        else:
            return result

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Checks if the requested channel has a constant value and returns it if so.

        Guarantee that this assertion passes for every t in waveform duration:
        >>> assert waveform.constant_value(channel) is None or waveform.constant_value(channel) == waveform.get_sampled(channel, t)

        Args:
            channel: The channel to check

        Returns:
            None if there is no guarantee that the channel is constant. The value otherwise.
        """
        return None

    def __neg__(self):
        return FunctorWaveform.from_functor(self, {ch: np.negative for ch in self.defined_channels})

    def __pos__(self):
        return self

    def _sort_key_for_channels(self) -> Sequence[Tuple[str, int]]:
        """Makes reproducible sorting by defined channels possible"""
        # string channels sort as (name, 0), all other channels as ('', channel)
        return sorted((ch, 0) if isinstance(ch, str) else ('', ch) for ch in self.defined_channels)

    def reversed(self) -> 'Waveform':
        """Returns a reversed version of this waveform."""
        # We don't check for constness here because const waveforms are supposed to override this method
        return ReversedWaveform(self)
class TableWaveformEntry(NamedTuple('TableWaveformEntry', [('t', Real),
                                                           ('v', float),
                                                           ('interp', InterpolationStrategy)])):
    """A single table entry consisting of a time ``t``, a value ``v`` and an interpolation ``interp``."""

    def __init__(self, t: float, v: float, interp: InterpolationStrategy):
        # The tuple itself is already built at this point; only the interpolation is validated here.
        if callable(interp):
            return
        raise TypeError(f'{interp} is neither callable nor of type InterpolationStrategy')

    def __repr__(self):
        cls_name = type(self).__name__
        return f'{cls_name}(t={self.t!r}, v={self.v!r}, interp={self.interp!r})'
class TableWaveform(Waveform):
    """Waveform obtained from instantiating a TablePulseTemplate."""

    EntryInInit = Union[TableWaveformEntry, Tuple[float, float, InterpolationStrategy]]

    __slots__ = ('_table', '_channel_id')

    def __init__(self,
                 channel: ChannelID,
                 waveform_table: Tuple[TableWaveformEntry, ...]) -> None:
        """Create a new TableWaveform instance.

        Passing anything but a tuple is deprecated. Such input is run through :meth:`_validate_input` and
        converted to a tuple of entries.

        Args:
            channel: The channel this waveform is defined on.
            waveform_table: A tuple of instantiated and validated table entries
        """
        if not isinstance(waveform_table, tuple):
            warnings.warn("Please use a tuple of TableWaveformEntry to construct TableWaveform directly",
                          category=DeprecationWarning)
            raw_entries = list(waveform_table)
            validated = self._validate_input(raw_entries)
            if isinstance(validated, tuple):
                # _validate_input signals a constant waveform with a (duration, value) tuple which is not a
                # table. Represent it by two entries of equal value so the instance stays usable.
                total_duration, value = validated
                waveform_table = (TableWaveformEntry(0.0, value, raw_entries[0][2]),
                                  TableWaveformEntry(total_duration, value, raw_entries[-1][2]))
            else:
                # a tuple keeps the compare key hashable
                waveform_table = tuple(validated)

        super().__init__(duration=_to_time_type(waveform_table[-1].t))
        self._table = waveform_table
        self._channel_id = channel

    @staticmethod
    def _validate_input(input_waveform_table: Sequence[EntryInInit]) -> Union[Tuple[Real, Real],
                                                                              List[TableWaveformEntry]]:
        """ Checks that:
         - the time is increasing,
         - there are at least two entries

        Optimizations:
          - removes subsequent entries with same time or voltage values.
          - checks if the complete waveform is constant. Returns a (duration, value) tuple if this is the case

        Raises:
            ValueError:
              - there are less than two entries
              - the entries are not ordered in time
              - Any time is negative
              - The total length is zero

        Returns:
            A list of de-duplicated table entries
            OR
            A (duration, value) tuple if the waveform is constant
        """
        # we use an iterator here to avoid duplicate work and be maximally efficient for short tables
        # We never use StopIteration to abort iteration. It always signifies an error.
        input_iter = iter(input_waveform_table)
        try:
            first_t, first_v, first_interp = next(input_iter)
        except StopIteration:
            raise ValueError("Waveform table mut not be empty")

        if first_t != 0.0:
            raise ValueError('First time entry is not zero.')

        previous_t = 0.0
        previous_v = first_v
        output_waveform_table = [TableWaveformEntry(0.0, first_v, first_interp)]

        try:
            t, v, interp = next(input_iter)
        except StopIteration:
            raise ValueError("Waveform table has less than two entries.")
        if t < 0:
            raise ValueError('Negative time values are not allowed.')

        # constant_v is not None <=> the waveform is constant until up to the current entry
        constant_v = interp.constant_value((previous_t, previous_v), (t, v))

        for next_t, next_v, next_interp in input_iter:
            if next_t < t:
                if next_t < 0:
                    raise ValueError('Negative time values are not allowed.')
                else:
                    raise ValueError('Times are not increasing.')

            # The segment (t, v) -> (next_t, next_v) is sampled with the interpolation of its end entry (see
            # unsafe_sample), so that strategy has to decide whether the segment is constant.
            if constant_v is not None and next_interp.constant_value((t, v), (next_t, next_v)) != constant_v:
                constant_v = None

            if (previous_t != t or t != next_t) and (previous_v != v or v != next_v):
                # the time and the value differ both either from the next or the previous
                # otherwise we skip the entry
                previous_t = t
                previous_v = v
                output_waveform_table.append(TableWaveformEntry(t, v, interp))

            t, v, interp = next_t, next_v, next_interp

        # Until now, we only checked that the time does not decrease. We require an increase because duration == 0
        # waveforms are ill-formed. t is now the time of the last entry.
        if t == 0:
            raise ValueError('Last time entry is zero.')

        if constant_v is not None:
            # the waveform is constant
            return t, constant_v
        else:
            # we must still add the last entry to the table
            output_waveform_table.append(TableWaveformEntry(t, v, interp))
            return output_waveform_table

    def is_constant(self) -> bool:
        # only correct if `from_table` is used
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        # only correct if `from_table` is used
        return None

    @classmethod
    def from_table(cls, channel: ChannelID, table: Sequence[EntryInInit]) -> Union['TableWaveform', 'ConstantWaveform']:
        """Validate `table` and return a ConstantWaveform if it is constant and a TableWaveform otherwise."""
        table = cls._validate_input(table)
        if isinstance(table, tuple):
            duration, amplitude = table
            return ConstantWaveform(duration=duration, amplitude=amplitude, channel=channel)
        else:
            return TableWaveform(channel, tuple(table))

    @property
    def compare_key(self) -> Any:
        return self._channel_id, self._table

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Sample the table by applying each entry's interpolation to the samples between it and its predecessor.

        Args:
            channel: Not inspected; the waveform defines a single channel.
            sample_times: Times at which this Waveform will be sampled.
            output_array: If not None, the sampled values are written here and this array is returned.
        """
        if output_array is None:
            output_array = _ALLOCATION_FUNCTION(sample_times, **_ALLOCATION_FUNCTION_KWARGS)

        if PULSE_TO_WAVEFORM_ERROR:
            # we need to replace the last entry's t with self.duration
            *entries, last = self._table
            entries.append(TableWaveformEntry(float(self.duration), last.v, last.interp))
        else:
            entries = self._table

        for entry1, entry2 in pairwise(entries):
            indices = slice(sample_times.searchsorted(entry1.t, 'left'),
                            sample_times.searchsorted(entry2.t, 'right'))
            output_array[indices] = \
                entry2.interp((float(entry1.t), entry1.v),
                              (float(entry2.t), entry2.v),
                              sample_times[indices])
        return output_array

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return {self._channel_id}

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        return self

    def __repr__(self):
        return f'{type(self).__name__}(channel={self._channel_id!r}, waveform_table={self._table!r})'
class ConstantWaveform(Waveform):
    """Single channel waveform that has the same amplitude for its complete duration."""

    # TODO: remove
    _is_constant_waveform = True

    __slots__ = ('_amplitude', '_channel')

    def __init__(self, duration: Real, amplitude: Any, channel: ChannelID):
        """ Create a qupulse waveform corresponding to a ConstantPulseTemplate """
        super().__init__(duration=_to_time_type(duration))
        self._amplitude = amplitude
        self._channel = channel

    @classmethod
    def from_mapping(cls, duration: Real, constant_values: Mapping[ChannelID, float]) -> Union['ConstantWaveform',
                                                                                                'MultiChannelWaveform']:
        """Construct a ConstantWaveform or a MultiChannelWaveform of ConstantWaveforms with given duration and values"""
        assert constant_values
        duration = _to_time_type(duration)
        if len(constant_values) == 1:
            # exactly one channel: no multi channel wrapper needed
            (channel, amplitude), = constant_values.items()
            return cls(duration, amplitude=amplitude, channel=channel)

        per_channel = [cls(duration, amplitude=amplitude, channel=channel)
                       for channel, amplitude in constant_values.items()]
        return MultiChannelWaveform(per_channel)

    def is_constant(self) -> bool:
        return True

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        assert channel == self._channel
        return self._amplitude

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        return {self._channel: self._amplitude}

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        """The single channel this waveform is defined on, wrapped in a set."""
        return {self._channel}

    @property
    def compare_key(self) -> Tuple[Any, ...]:
        return self._duration, self._amplitude, self._channel

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Fill `output_array` (or a newly allocated float array shaped like `sample_times`) with the amplitude."""
        if output_array is not None:
            output_array[:] = self._amplitude
            return output_array
        return np.full_like(sample_times, fill_value=self._amplitude, dtype=float)

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform:
        """Returns this waveform itself; `channels` is not inspected."""
        return self

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}(duration={self.duration!r}, "\
               f"amplitude={self._amplitude!r}, channel={self._channel!r})"

    def reversed(self) -> 'Waveform':
        # reversing a constant waveform does not change it
        return self
class FunctionWaveform(Waveform):
    """Waveform obtained from instantiating a FunctionPulseTemplate."""

    __slots__ = ('_expression', '_channel_id')

    def __init__(self, expression: ExpressionScalar,
                 duration: float,
                 channel: ChannelID) -> None:
        """Creates a new FunctionWaveform instance.

        Args:
            expression: The function represented by this FunctionWaveform
                as a mathematical expression where 't' denotes the time variable. It must not have other variables
            duration: The duration of the waveform
            channel: The channel this waveform is played on

        Raises:
            ValueError: If the expression depends on variables other than 't'.
        """
        foreign_variables = set(expression.variables) - {'t'}
        if foreign_variables:
            raise ValueError('FunctionWaveforms may not depend on anything but "t"')
        if not expression.variables:
            warnings.warn("Constant FunctionWaveform is not recommended as the constant propagation will be suboptimal",
                          category=ConstantFunctionPulseTemplateWarning)
        super().__init__(duration=_to_time_type(duration))
        self._expression = expression
        self._channel_id = channel

    @classmethod
    def from_expression(cls, expression: ExpressionScalar, duration: float, channel: ChannelID) -> Union['FunctionWaveform', ConstantWaveform]:
        """Return a FunctionWaveform if the expression has variables and a ConstantWaveform of its value otherwise."""
        if not expression.variables:
            return ConstantWaveform(amplitude=expression.evaluate_numeric(), duration=duration, channel=channel)
        return cls(expression, duration, channel)

    def is_constant(self) -> bool:
        # only correct if `from_expression` is used
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        # only correct if `from_expression` is used
        return None

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return {self._channel_id}

    @property
    def compare_key(self) -> Any:
        return self._channel_id, self._expression, self._duration

    @property
    def duration(self) -> TimeType:
        return self._duration

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Evaluate the expression with ``t=sample_times``.

        Args:
            channel: Not inspected; the waveform defines a single channel.
            sample_times: Times at which the expression is evaluated.
            output_array: If not None, the evaluated values are written here and this array is returned.
        """
        values = self._expression.evaluate_numeric(t=sample_times)
        if output_array is not None:
            output_array[:] = values
            return output_array

        if not self._expression.variables:
            # expression without variables: spread the single value over all sample times
            return np.full_like(sample_times, fill_value=float(values), dtype=float)
        return values.astype(float)

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform:
        return self

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}(duration={self.duration!r}, "\
               f"expression={self._expression!r}, channel={self._channel_id!r})"
class SequenceWaveform(Waveform):
    """This class allows putting multiple PulseTemplate together in one waveform on the hardware."""

    __slots__ = ('_sequenced_waveforms', )

    def __init__(self, sub_waveforms: Iterable[Waveform]):
        """Use Waveform.from_sequence for optimal construction

        Args:
            sub_waveforms: The waveforms to play one after another. All waveforms must have the same defined
                channels

        Raises:
            ValueError: If no waveforms are given or the defined channels of the waveforms differ.
        """
        empty_message = "SequenceWaveform cannot be constructed without channel waveforms."
        if not sub_waveforms:
            raise ValueError(empty_message)

        # do not fail on iterators although we do not allow them as an argument
        sequenced_waveforms = tuple(sub_waveforms)
        if not sequenced_waveforms:
            # an exhausted or empty iterator is truthy and only detected after materializing it
            raise ValueError(empty_message)

        super().__init__(duration=sum(waveform.duration for waveform in sequenced_waveforms))
        self._sequenced_waveforms = sequenced_waveforms

        defined_channels = sequenced_waveforms[0].defined_channels
        mismatching = [waveform.defined_channels
                       for waveform in itertools.islice(sequenced_waveforms, 1, None)
                       if not waveform.defined_channels == defined_channels]
        if mismatching:
            raise ValueError(
                "SequenceWaveform cannot be constructed from waveforms of different "
                f"defined channels. Expected {defined_channels} but got {mismatching}"
            )

    @classmethod
    def from_sequence(cls, waveforms: Sequence['Waveform']) -> 'Waveform':
        """Returns a waveform the represents the given sequence of waveforms. Applies some optimizations.

        A single waveform is returned as is, nested sequences are flattened and a sequence in which all
        waveforms share the same constant values is turned into a constant waveform.
        """
        assert waveforms, "Sequence must not be empty"
        if len(waveforms) == 1:
            return waveforms[0]

        flattened = []
        constant_values = waveforms[0].constant_value_dict()
        for wf in waveforms:
            if constant_values and constant_values != wf.constant_value_dict():
                constant_values = None
            if isinstance(wf, cls):
                flattened.extend(wf.sequenced_waveforms)
            else:
                flattened.append(wf)
        if constant_values is None:
            return cls(sub_waveforms=flattened)
        else:
            duration = sum(wf.duration for wf in flattened)
            return ConstantWaveform.from_mapping(duration, constant_values)

    def is_constant(self) -> bool:
        # only correct if from_sequence is used for construction
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        # only correct if from_sequence is used for construction
        return None

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Return the value shared by all sub waveforms on `channel` or None if there is no such value."""
        v = None
        for wf in self._sequenced_waveforms:
            wf_cv = wf.constant_value(channel)
            if wf_cv is None:
                return None
            elif wf_cv == v:
                continue
            elif v is None:
                v = wf_cv
            else:
                assert v != wf_cv
                return None
        return v

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return self._sequenced_waveforms[0].defined_channels

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Let every sub waveform sample its share of `sample_times` into the matching part of the output.

        Args:
            channel: The channel to sample.
            sample_times: Times at which this Waveform will be sampled.
            output_array: If not None, the sampled values are written here and this array is returned.
        """
        if output_array is None:
            output_array = _ALLOCATION_FUNCTION(sample_times, **_ALLOCATION_FUNCTION_KWARGS)
        time = 0
        for subwaveform in self._sequenced_waveforms:
            # before you change anything here, make sure to understand the difference between basic and advanced
            # indexing in numpy and their copy/reference behaviour
            end = time + subwaveform.duration

            indices = slice(*sample_times.searchsorted((float(time), float(end)), 'left'))
            subwaveform.unsafe_sample(channel=channel,
                                      sample_times=sample_times[indices]-np.float64(time),
                                      output_array=output_array[indices])
            time = end
        return output_array

    @property
    def compare_key(self) -> Tuple[Waveform]:
        return self._sequenced_waveforms

    @property
    def duration(self) -> TimeType:
        return self._duration

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        return SequenceWaveform.from_sequence([
            sub_waveform.unsafe_get_subset_for_channels(channels & sub_waveform.defined_channels)
            for sub_waveform in self._sequenced_waveforms if sub_waveform.defined_channels & channels])

    @property
    def sequenced_waveforms(self) -> Sequence[Waveform]:
        return self._sequenced_waveforms

    def __repr__(self):
        return f"{type(self).__name__}({self._sequenced_waveforms})"
class MultiChannelWaveform(Waveform):
    """A MultiChannelWaveform is a Waveform object that allows combining arbitrary Waveform objects
    to into a single waveform defined for several channels.

    The channels defined by the MultiChannelWaveform are the union of the channels defined by the Waveform
    objects it consists of.

    The following constraints must hold:
     - The durations of all Waveform objects must be equal (checked with ``isclose``).
     - No channel may be defined by more than one of the Waveform objects.
    """

    __slots__ = ('_sub_waveforms', '_defined_channels')

    def __init__(self, sub_waveforms: List[Waveform]) -> None:
        """Create a new MultiChannelWaveform instance.
        Use `MultiChannelWaveform.from_parallel` for optimal construction.

        Args:
            sub_waveforms: The list of sub waveforms of this
                MultiChannelWaveform. List might get sorted!
        Raises:
            ValueError, if no sub waveforms are given
            ValueError, if a channel is defined by several sub waveforms
            ValueError, if subwaveforms have inconsistent durations
        """
        empty_message = "MultiChannelWaveform cannot be constructed without channel waveforms."
        if not sub_waveforms:
            raise ValueError(empty_message)

        # sort the waveforms with their defined channels to make compare key reproducible
        if not isinstance(sub_waveforms, list):
            sub_waveforms = list(sub_waveforms)
            if not sub_waveforms:
                # an empty iterator is truthy and only detected after materializing it
                raise ValueError(empty_message)
        sub_waveforms.sort(key=lambda wf: wf._sort_key_for_channels())

        super().__init__(duration=sub_waveforms[0].duration)
        self._sub_waveforms = tuple(sub_waveforms)

        defined_channels = set()
        for waveform in self._sub_waveforms:
            if waveform.defined_channels & defined_channels:
                raise ValueError('Channel may not be defined in multiple waveforms',
                                 waveform.defined_channels & defined_channels)
            defined_channels |= waveform.defined_channels
        self._defined_channels = frozenset(defined_channels)

        if not all(isclose(waveform.duration, self.duration) for waveform in self._sub_waveforms[1:]):
            # meaningful error message: group the channels by (close) durations
            durations = {}

            for waveform in self._sub_waveforms:
                for duration, channels in durations.items():
                    if isclose(waveform.duration, duration):
                        channels.update(waveform.defined_channels)
                        break
                else:
                    durations[waveform.duration] = set(waveform.defined_channels)

            raise ValueError(
                "MultiChannelWaveform cannot be constructed from channel waveforms of different durations.",
                durations
            )

    @staticmethod
    def from_parallel(waveforms: Sequence[Waveform]) -> Waveform:
        """Combine `waveforms` to one waveform. A single waveform is returned as is and nested
        MultiChannelWaveforms are flattened."""
        assert waveforms, "ARgument must not be empty"
        if len(waveforms) == 1:
            return waveforms[0]

        # we do not look at constant values here because there is no benefit. We would need to construct a new
        # MultiChannelWaveform anyways

        # avoid unnecessary multi channel nesting
        flattened = []
        for waveform in waveforms:
            if isinstance(waveform, MultiChannelWaveform):
                flattened.extend(waveform._sub_waveforms)
            else:
                flattened.append(waveform)

        return MultiChannelWaveform(flattened)

    def is_constant(self) -> bool:
        return all(wf.is_constant() for wf in self._sub_waveforms)

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        return self[channel].constant_value(channel)

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        """Merge the constant values of all sub waveforms. None if any sub waveform reports None."""
        d = {}
        for wf in self._sub_waveforms:
            wf_d = wf.constant_value_dict()
            if wf_d is None:
                return None
            else:
                d.update(wf_d)
        return d

    @property
    def duration(self) -> TimeType:
        return self._sub_waveforms[0].duration

    def __getitem__(self, key: ChannelID) -> Waveform:
        """Return the sub waveform that defines channel `key`. Raises KeyError for unknown channels."""
        for waveform in self._sub_waveforms:
            if key in waveform.defined_channels:
                return waveform
        raise KeyError('Unknown channel ID: {}'.format(key), key)

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return self._defined_channels

    @property
    def compare_key(self) -> Any:
        # sort with channels
        return self._sub_waveforms

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        # delegate to the sub waveform that defines the channel
        return self[channel].unsafe_sample(channel, sample_times, output_array)

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        relevant_sub_waveforms = [swf for swf in self._sub_waveforms if swf.defined_channels & channels]
        if len(relevant_sub_waveforms) == 1:
            return relevant_sub_waveforms[0].get_subset_for_channels(channels)
        elif len(relevant_sub_waveforms) > 1:
            return MultiChannelWaveform.from_parallel(
                [sub_waveform.get_subset_for_channels(channels & sub_waveform.defined_channels)
                 for sub_waveform in relevant_sub_waveforms])
        else:
            raise KeyError('Unknown channels: {}'.format(channels))

    def __repr__(self):
        return f"{type(self).__name__}({self._sub_waveforms!r})"
class RepetitionWaveform(Waveform):
    """Waveform that plays a body waveform `repetition_count` times in a row."""

    __slots__ = ('_body', '_repetition_count')

    def __init__(self, body: Waveform, repetition_count: int):
        """Args:
            body: The waveform to repeat.
            repetition_count: How often the body is repeated. Cast with ``checked_int_cast``.

        Raises:
            ValueError: If the repetition count is smaller than one or not an integer.
        """
        count = checked_int_cast(repetition_count)
        if count < 1 or not isinstance(count, int):
            raise ValueError('Repetition count must be an integer >0')

        super().__init__(duration=body.duration * count)
        self._body = body
        self._repetition_count = count

    @classmethod
    def from_repetition_count(cls, body: Waveform, repetition_count: int) -> Waveform:
        """Return a constant waveform of the total duration if the body is constant and a
        RepetitionWaveform otherwise."""
        body_constants = body.constant_value_dict()
        if body_constants is not None:
            return ConstantWaveform.from_mapping(body.duration * repetition_count, body_constants)
        return RepetitionWaveform(body, repetition_count)

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return self._body.defined_channels

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Sample the body once per repetition into the matching part of the output.

        Args:
            channel: The channel to sample.
            sample_times: Times at which this Waveform will be sampled.
            output_array: If not None, the sampled values are written here and this array is returned.
        """
        if output_array is None:
            output_array = _ALLOCATION_FUNCTION(sample_times, **_ALLOCATION_FUNCTION_KWARGS)

        single_duration = self._body.duration
        begin = 0
        for _ in range(self._repetition_count):
            stop = begin + single_duration
            # samples with begin <= t < stop belong to the current repetition
            window = slice(*sample_times.searchsorted((float(begin), float(stop)), 'left'))
            self._body.unsafe_sample(channel=channel,
                                     sample_times=sample_times[window] - float(begin),
                                     output_array=output_array[window])
            begin = stop
        return output_array

    @property
    def compare_key(self) -> Tuple[Any, int]:
        return self._body.compare_key, self._repetition_count

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform:
        body_subset = self._body.unsafe_get_subset_for_channels(channels)
        return RepetitionWaveform.from_repetition_count(body=body_subset,
                                                        repetition_count=self._repetition_count)

    def is_constant(self) -> bool:
        return self._body.is_constant()

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        return self._body.constant_value(channel)

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        return self._body.constant_value_dict()

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}(body={self._body!r}, repetition_count={self._repetition_count!r})"
class TransformingWaveform(Waveform):
    """Waveform that applies a transformation to the sampled channels of an inner waveform."""

    __slots__ = ('_inner_waveform', '_transformation', '_cached_data', '_cached_times')

    def __init__(self, inner_waveform: Waveform, transformation: Transformation):
        """Args:
            inner_waveform: The waveform whose channels are transformed. It determines the duration.
            transformation: The transformation applied to the inner waveform's channels.
        """
        super().__init__(duration=inner_waveform.duration)
        self._inner_waveform = inner_waveform
        self._transformation = transformation

        # cache data of inner channels based identified and invalidated by the sample times
        self._cached_data = None
        # behaves like a dead weak reference until the first sample call
        self._cached_times = lambda: None

    @classmethod
    def from_transformation(cls, inner_waveform: Waveform, transformation: Transformation) -> Waveform:
        """Return a constant waveform if the inner waveform is constant and the transformation is constant
        invariant. Return a TransformingWaveform otherwise."""
        inner_constants = inner_waveform.constant_value_dict()
        if inner_constants is None or not transformation.is_constant_invariant():
            return cls(inner_waveform, transformation)

        transformed = transformation(0., inner_constants)
        outer_constants = {key: float(value) for key, value in transformed.items()}
        return ConstantWaveform.from_mapping(inner_waveform.duration, outer_constants)

    def is_constant(self) -> bool:
        # only true if `from_transformation` was used
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        # only true if `from_transformation` was used
        return None

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Return the transformed constant value of `channel` or None if the transformation is not constant
        invariant or one of the required inner channels is not constant."""
        if not self._transformation.is_constant_invariant():
            return None
        required_channels = self._transformation.get_input_channels({channel})
        inner_values = {ch: self._inner_waveform.constant_value(ch) for ch in required_channels}
        for value in inner_values.values():
            if value is None:
                return None
        return self._transformation(0., inner_values)[channel]

    @property
    def inner_waveform(self) -> Waveform:
        return self._inner_waveform

    @property
    def transformation(self) -> Transformation:
        return self._transformation

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return self.transformation.get_output_channels(self.inner_waveform.defined_channels)

    @property
    def compare_key(self) -> Tuple[Waveform, Transformation]:
        return self.inner_waveform, self.transformation

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'SubsetWaveform':
        return SubsetWaveform(self, channel_subset=channels)

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Sample the required inner channels, transform them and return the data of `channel`.

        The transformed data is cached as long as the identical `sample_times` object is passed.

        Args:
            channel: The (outer) channel to sample.
            sample_times: Times at which this Waveform will be sampled.
            output_array: If not None, the values are copied here and this array is returned. The cached
                array itself is returned otherwise.
        """
        if self._cached_times() is not sample_times:
            # different sample times object: drop everything cached so far
            self._cached_data = dict()
            self._cached_times = ref(sample_times)

        if channel not in self._cached_data:
            required_channels = self.transformation.get_input_channels({channel})
            inner_samples = {}
            for inner_channel in required_channels:
                inner_samples[inner_channel] = self.inner_waveform.unsafe_sample(inner_channel, sample_times)

            transformed = self.transformation(sample_times, inner_samples)
            self._cached_data.update(transformed)

        channel_data = self._cached_data[channel]
        if output_array is None:
            return channel_data
        output_array[:] = channel_data
        return output_array
class SubsetWaveform(Waveform):
    """Waveform that exposes only a subset of the channels of an inner waveform.

    Duration and sampling are taken from the inner waveform; only the set of defined channels differs.
    """

    __slots__ = ('_inner_waveform', '_channel_subset')

    def __init__(self, inner_waveform: Waveform, channel_subset: Set[ChannelID]):
        """Store the inner waveform and a frozen copy of ``channel_subset``."""
        super().__init__(duration=inner_waveform.duration)
        self._inner_waveform = inner_waveform
        self._channel_subset = frozenset(channel_subset)

    @property
    def inner_waveform(self) -> Waveform:
        """The waveform the channels are taken from."""
        return self._inner_waveform

    @property
    def defined_channels(self) -> FrozenSet[ChannelID]:
        """The channel subset passed on construction."""
        return self._channel_subset

    @property
    def compare_key(self) -> Tuple[frozenset, Waveform]:
        """Pair of defined channels and inner waveform."""
        return self.defined_channels, self.inner_waveform

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform:
        """Ask the inner waveform directly for the subset instead of nesting another ``SubsetWaveform``."""
        inner = self.inner_waveform
        return inner.get_subset_for_channels(channels)

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Forward the sampling request unchanged to the inner waveform."""
        inner = self.inner_waveform
        return inner.unsafe_sample(channel, sample_times, output_array)

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        """Return the inner constant values restricted to the channel subset, or ``None`` if the inner
        waveform reports none."""
        inner_values = self._inner_waveform.constant_value_dict()
        if inner_values is None:
            return None
        return {ch: inner_values[ch] for ch in self._channel_subset}

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Return the inner waveform's constant value for ``channel``.

        Raises:
            KeyError: If ``channel`` is not part of the channel subset.
        """
        if channel in self._channel_subset:
            return self._inner_waveform.constant_value(channel)
        raise KeyError(channel)
class ArithmeticWaveform(Waveform):
    """Channel wise sum or difference of two waveforms.

    Channels only present in one waveform have the operation's neutral element on the other: a channel that is
    only defined by ``rhs`` yields ``+rhs`` or ``-rhs`` and a channel that is only defined by ``lhs`` is passed
    through unchanged.
    """

    # binary operators for sampled arrays and for scalar constant values
    numpy_operator_map = {'+': np.add,
                          '-': np.subtract}
    operator_map = {'+': operator.add,
                    '-': operator.sub}

    # unary operators used for channels that only exist in the right hand side
    rhs_only_map = {'+': operator.pos,
                    '-': operator.neg}
    numpy_rhs_only_map = {'+': np.positive,
                          '-': np.negative}

    __slots__ = ('_lhs', '_rhs', '_arithmetic_operator')

    def __init__(self,
                 lhs: Waveform,
                 arithmetic_operator: str,
                 rhs: Waveform):
        """Combine ``lhs`` and ``rhs`` with ``arithmetic_operator``.

        Args:
            lhs: Left hand side waveform. Its duration becomes the duration of this waveform.
            arithmetic_operator: ``'+'`` or ``'-'``.
            rhs: Right hand side waveform. Its duration must be close to the duration of ``lhs``.

        Raises:
            AssertionError: If the durations differ or the operator is unknown.
        """
        super().__init__(duration=lhs.duration)
        self._lhs = lhs
        self._rhs = rhs
        self._arithmetic_operator = arithmetic_operator

        assert np.isclose(float(self._lhs.duration), float(self._rhs.duration))
        assert arithmetic_operator in self.operator_map

    @classmethod
    def from_operator(cls, lhs: Waveform, arithmetic_operator: str, rhs: Waveform):
        """Create the simplest waveform that represents ``lhs <arithmetic_operator> rhs``.

        If both operands report constant values for all their channels the result is computed directly and
        returned as a ``ConstantWaveform``. Otherwise a new instance of ``cls`` is returned.
        """
        # one could optimize rhs_cv to being only created if lhs_cv is not None but this makes the code harder to read
        lhs_cv = lhs.constant_value_dict()
        rhs_cv = rhs.constant_value_dict()
        if lhs_cv is None or rhs_cv is None:
            return cls(lhs, arithmetic_operator, rhs)

        else:
            constant_values = dict(lhs_cv)
            op = cls.operator_map[arithmetic_operator]
            rhs_op = cls.rhs_only_map[arithmetic_operator]

            for ch, rhs_val in rhs_cv.items():
                if ch in constant_values:
                    constant_values[ch] = op(constant_values[ch], rhs_val)
                else:
                    # channel only exists on the right hand side: apply the unary operator
                    constant_values[ch] = rhs_op(rhs_val)

            duration = lhs.duration
            assert isclose(duration, rhs.duration)

            return ConstantWaveform.from_mapping(duration, constant_values)

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Return the constant value of ``channel`` or ``None`` if an operand is not constant there."""
        if channel not in self._rhs.defined_channels:
            return self._lhs.constant_value(channel)
        rhs = self._rhs.constant_value(channel)
        if rhs is None:
            return None

        if channel in self._lhs.defined_channels:
            lhs = self._lhs.constant_value(channel)
            if lhs is None:
                return None

            return self.operator_map[self._arithmetic_operator](lhs, rhs)
        else:
            return self.rhs_only_map[self._arithmetic_operator](rhs)

    def is_constant(self) -> bool:
        """Always ``False``."""
        # only correct if from_operator is used
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        """Always ``None``."""
        # only correct if from_operator is used
        return None

    @property
    def lhs(self) -> Waveform:
        """Left hand side operand."""
        return self._lhs

    @property
    def rhs(self) -> Waveform:
        """Right hand side operand."""
        return self._rhs

    @property
    def arithmetic_operator(self) -> str:
        """The operator string passed on construction."""
        return self._arithmetic_operator

    @property
    def duration(self) -> TimeType:
        """Duration of the left hand side operand."""
        return self._lhs.duration

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        """Union of the channels defined by both operands."""
        return self._lhs.defined_channels | self._rhs.defined_channels

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Sample ``channel`` of the combined waveform at ``sample_times``.

        Args:
            channel: Channel to sample.
            sample_times: Times at which to sample.
            output_array: Optional array that receives the result.

        Returns:
            The sampled values. When both operands define the channel and no ``output_array`` is given, the
            result is a newly allocated array so that arrays owned by the operands are never modified.
        """
        if channel in self._lhs.defined_channels:
            lhs = self._lhs.unsafe_sample(channel=channel, sample_times=sample_times, output_array=output_array)
        else:
            lhs = None

        if channel in self._rhs.defined_channels:
            # output_array is already occupied by lhs if lhs exists, so rhs only gets it when there is no lhs
            rhs = self._rhs.unsafe_sample(channel=channel, sample_times=sample_times,
                                          output_array=None if lhs is not None else output_array)
        else:
            rhs = None

        if rhs is not None and lhs is not None:
            arithmetic_operator = self.numpy_operator_map[self._arithmetic_operator]
            # Do NOT fall back to ``out=lhs`` when no output array was requested: an operand may hand out an
            # internally cached array (TransformingWaveform does) and writing the result into it would corrupt
            # that cache. With ``out=None`` numpy allocates a fresh result array.
            return arithmetic_operator(lhs, rhs, out=output_array)

        else:
            if lhs is None:
                assert rhs is not None, "channel %r not in defined channels (internal bug)" % channel
                return self.numpy_rhs_only_map[self._arithmetic_operator](rhs, out=output_array)
            else:
                return lhs

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform:
        """Return a ``SubsetWaveform`` that restricts this waveform to ``channels``."""
        # TODO: optimization possible
        return SubsetWaveform(self, channels)

    @property
    def compare_key(self) -> Tuple[str, Waveform, Waveform]:
        """Triple of operator string, left operand and right operand."""
        return self._arithmetic_operator, self._lhs, self._rhs
class FunctorWaveform(Waveform):
    """Apply a channel wise functor that works inplace to all results.

    Each functor is called in two ways by this class: with a single constant value (``functor(value)``) when
    constant values are propagated, and with a sampled array plus the keyword ``out``
    (``functor(array, out=array)``) when sampling.
    """

    # TODO: Use Protocol to enforce that it accepts second argument has the keyword out
    Functor = callable

    __slots__ = ('_inner_waveform', '_functor')

    def __init__(self, inner_waveform: Waveform, functor: Mapping[ChannelID, Functor]):
        """Wrap ``inner_waveform`` so that ``functor[channel]`` is applied to the samples of each channel.

        Args:
            inner_waveform: Waveform that provides the samples and the duration.
            functor: One functor per channel defined by ``inner_waveform``.

        Raises:
            AssertionError: If the keys of ``functor`` differ from the defined channels of ``inner_waveform``.
        """
        super(FunctorWaveform, self).__init__(duration=inner_waveform.duration)
        self._inner_waveform = inner_waveform
        self._functor = dict(functor.items())

        assert set(functor.keys()) == inner_waveform.defined_channels, ("There is no default identity mapping (yet)."
                                                                        "File an issue on github if you need it.")

    @classmethod
    def from_functor(cls, inner_waveform: Waveform, functor: Mapping[ChannelID, Functor]):
        """Create the simplest waveform that represents ``functor`` applied to ``inner_waveform``.

        If the inner waveform reports constant values for all channels the functors are applied to these
        values and a ``ConstantWaveform`` is returned. Otherwise a ``FunctorWaveform`` is returned.
        """
        constant_values = inner_waveform.constant_value_dict()
        if constant_values is None:
            return FunctorWaveform(inner_waveform, functor)

        funced_constant_values = {ch: functor[ch](val) for ch, val in constant_values.items()}
        return ConstantWaveform.from_mapping(inner_waveform.duration, funced_constant_values)

    def is_constant(self) -> bool:
        """Always ``False``."""
        # only correct if `from_functor` was used
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        """Always ``None``."""
        # only correct if `from_functor` was used
        return None

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Return the functor applied to the inner constant value of ``channel`` or ``None`` if there is none."""
        inner = self._inner_waveform.constant_value(channel)
        if inner is None:
            return None
        else:
            return self._functor[channel](inner)

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        """Channels defined by the inner waveform."""
        return self._inner_waveform.defined_channels

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Sample the inner waveform and apply the functor of ``channel`` to the result.

        The functor writes into the array returned by the inner waveform (``out=`` is that same array).
        """
        inner_output = self._inner_waveform.unsafe_sample(channel, sample_times, output_array)
        return self._functor[channel](inner_output, out=inner_output)

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform:
        """Restrict the inner waveform and the functor mapping to ``channels``."""
        return FunctorWaveform.from_functor(
            self._inner_waveform.unsafe_get_subset_for_channels(channels),
            {ch: self._functor[ch] for ch in channels})

    @property
    def compare_key(self) -> Tuple[Waveform, FrozenSet]:
        """Pair of inner waveform and the frozen set of ``(channel, functor)`` items."""
        return self._inner_waveform, frozenset(self._functor.items())


class ReversedWaveform(Waveform):
    """Reverses the inner waveform in time."""

    __slots__ = ('_inner',)

    def __init__(self, inner: Waveform):
        """Wrap ``inner``; the duration is taken from it."""
        super().__init__(duration=inner.duration)
        self._inner = inner

    @classmethod
    def from_to_reverse(cls, inner: Waveform) -> Waveform:
        """Return the time reversal of ``inner``.

        If ``inner`` reports constant values for all channels it is returned unchanged, otherwise it is
        wrapped in a new instance of ``cls``.
        """
        # Compare against None explicitly: ``constant_value_dict`` returns Optional[Mapping] and the emptiness
        # of the mapping says nothing about whether the waveform is constant.
        if inner.constant_value_dict() is not None:
            return inner
        else:
            return cls(inner)

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Sample ``channel`` of the time reversed waveform at ``sample_times``.

        The inner waveform is sampled at ``duration - sample_times`` in reversed order and the result is
        reversed again.

        Args:
            channel: Channel to sample.
            sample_times: Times at which to sample.
            output_array: Optional array that receives the result.

        Returns:
            ``output_array`` if one was given, otherwise a reversed view of the array returned by the inner
            waveform.
        """
        inner_sample_times = (float(self.duration) - sample_times)[::-1]
        if output_array is None:
            return self._inner.unsafe_sample(channel, inner_sample_times, None)[::-1]

        # Let the inner waveform write through a reversed view so the data lands in output_array in the
        # correct order without an extra copy.
        reversed_view = output_array[::-1]
        sampled = self._inner.unsafe_sample(channel, inner_sample_times, output_array=reversed_view)
        # Identity check on purpose: a membership test on a tuple of arrays falls back to ``==`` which is
        # element wise for ndarrays. If the inner waveform returned any other array object we copy its data;
        # this is at worst a redundant copy.
        if sampled is not reversed_view:
            output_array[:] = sampled[::-1]
        return output_array

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        """Channels defined by the inner waveform."""
        return self._inner.defined_channels

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        """Return the reversal of the inner waveform restricted to ``channels``."""
        return ReversedWaveform.from_to_reverse(self._inner.unsafe_get_subset_for_channels(channels))

    @property
    def compare_key(self) -> Hashable:
        """The ``compare_key`` of the inner waveform."""
        return self._inner.compare_key

    def reversed(self) -> 'Waveform':
        """Reversing twice yields the inner waveform itself."""
        return self._inner