Source code for qupulse._program.waveforms

"""This module contains all waveform classes

Classes:
    - Waveform: An instantiated pulse which can be sampled to a raw voltage value array.
"""

import itertools
from abc import ABCMeta, abstractmethod
from weakref import WeakValueDictionary, ref
from typing import Union, Set, Sequence, NamedTuple, Tuple, Any, Iterable, FrozenSet, Optional

import numpy as np

from qupulse import ChannelID
from qupulse.utils import checked_int_cast, isclose
from qupulse.utils.types import TimeType, time_from_float
from qupulse.comparable import Comparable
from qupulse.expressions import ExpressionScalar
from qupulse.pulses.interpolation import InterpolationStrategy
from qupulse._program.transformation import Transformation


# Names exported by ``from qupulse._program.waveforms import *``.
# NOTE: SubsetWaveform is defined in this module but is not listed here.
__all__ = ["Waveform", "TableWaveform", "TableWaveformEntry", "FunctionWaveform", "SequenceWaveform",
           "MultiChannelWaveform", "RepetitionWaveform", "TransformingWaveform"]


class Waveform(Comparable, metaclass=ABCMeta):
    """Represents an instantiated PulseTemplate which can be sampled to retrieve arrays of voltage values for the
    hardware."""

    # Cache of read-only sampled arrays shared by all waveforms. The values are held weakly, so an entry disappears
    # as soon as nobody references the sampled array any more.
    __sampled_cache = WeakValueDictionary()

    @property
    @abstractmethod
    def duration(self) -> TimeType:
        """The duration of the waveform in time units."""

    @abstractmethod
    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None]=None) -> np.ndarray:
        """Sample the waveform at given sample times.

        The unsafe means that there are no sanity checks performed. The provided sample times are assumed to be
        monotonously increasing and lie in the range of [0, waveform.duration]

        Args:
            channel: The channel to sample.
            sample_times: Times at which this Waveform will be sampled.
            output_array: Has to be either None or an array of the same size and type as sample_times. If
                not None, the sampled values will be written here and this array will be returned
        Result:
            The sampled values of this Waveform at the provided sample times. Has the same number of
            elements as sample_times.
        """

    def get_sampled(self,
                    channel: ChannelID,
                    sample_times: np.ndarray,
                    output_array: Union[np.ndarray, None]=None) -> np.ndarray:
        """A wrapper to the unsafe_sample method which caches the result.

        This method enforces the constraints unsafe_sample expects and caches the result to save memory.

        Args:
            channel: The channel to sample. Must be one of `defined_channels`.
            sample_times: Times at which this Waveform will be sampled.
            output_array: Has to be either None or an array of the same size and type as sample_times. If an array
                is given, the sampled values will be written into the given array and it will be returned.
                Otherwise, a new array will be created and cached to save memory.

        Returns:
            The sampled values of this Waveform at the provided sample times. If no output array was given the
            returned array is marked read-only because it may be shared through the cache.

        Raises:
            ValueError: If the sample times are not strictly increasing, lie outside of [0, duration] or if the
                length of `output_array` differs from the length of `sample_times`.
            KeyError: If `channel` is not defined on this waveform.
        """
        if len(sample_times) == 0:
            if output_array is None:
                return np.zeros_like(sample_times)
            elif len(output_array) == len(sample_times):
                return output_array
            else:
                raise ValueError('Output array length and sample time length are different')

        if np.any(sample_times[:-1] >= sample_times[1:]):
            raise ValueError('The sample times are not monotonously increasing')
        if sample_times[0] < 0 or sample_times[-1] > self.duration:
            raise ValueError('The sample times are not in the range [0, duration]')
        if channel not in self.defined_channels:
            raise KeyError('Channel not defined in this waveform: {}'.format(channel))

        if output_array is None:
            # cache the result to save memory
            result = self.unsafe_sample(channel, sample_times)
            result.flags.writeable = False

            # The raw bytes alone do not identify an array: e.g. all-zero arrays of different dtype or shape can
            # have identical bytes. Therefore dtype and shape are part of the key.
            raw = bytes(result)
            key = hash((result.dtype.str, result.shape, raw))

            cached = self.__sampled_cache.get(key)
            # A matching hash is no proof of equality -> verify the hit before handing out the cached array.
            if (cached is not None
                    and cached.dtype == result.dtype
                    and cached.shape == result.shape
                    and bytes(cached) == raw):
                return cached

            self.__sampled_cache[key] = result
            return result
        else:
            if len(output_array) != len(sample_times):
                raise ValueError('Output array length and sample time length are different')
            # use the user provided memory
            return self.unsafe_sample(channel=channel,
                                      sample_times=sample_times,
                                      output_array=output_array)

    @property
    @abstractmethod
    def defined_channels(self) -> Set[ChannelID]:
        """The channels this waveform should be played on. Use :meth:`get_subset_for_channels` to get a waveform
        for a subset of these."""

    @abstractmethod
    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'Waveform':
        """Unsafe version of :meth:`get_subset_for_channels`: `channels` is not validated here."""

    def get_subset_for_channels(self, channels: Set[ChannelID]) -> 'Waveform':
        """Get a waveform that only describes the channels contained in `channels`.

        Args:
            channels: A channel set the return value should confine to.

        Raises:
            KeyError: If `channels` is not a subset of the waveform's defined channels.

        Returns:
            A waveform with waveform.defined_channels == `channels`
        """
        if not channels <= self.defined_channels:
            raise KeyError('Channels not defined on waveform: {}'.format(channels))
        if channels == self.defined_channels:
            return self
        return self.unsafe_get_subset_for_channels(channels=channels)
class TableWaveformEntry(NamedTuple('TableWaveformEntry', [('t', float),
                                                           ('v', float),
                                                           ('interp', InterpolationStrategy)])):
    """One row of a waveform table: time `t`, voltage `v` and the interpolation strategy `interp`."""

    def __init__(self, t: float, v: float, interp: InterpolationStrategy):
        # The tuple fields are already populated when this runs; only the strategy needs to be validated.
        if callable(interp):
            return
        raise TypeError('{} is neither callable nor of type InterpolationStrategy'.format(interp))
class TableWaveform(Waveform):
    """Waveform obtained from instantiating a TablePulseTemplate."""

    # Element type accepted in the table handed to __init__: a ready-made entry or a plain 3-tuple.
    EntryInInit = Union[TableWaveformEntry, Tuple[float, float, InterpolationStrategy]]

    def __init__(self,
                 channel: ChannelID,
                 waveform_table: Sequence[EntryInInit]) -> None:
        """Create a new TableWaveform instance.

        Args:
            channel: The single channel this waveform is defined on.
            waveform_table: A sequence of instantiated table entries of the form
                (time as float, voltage as float, interpolation strategy).

        Raises:
            ValueError: If the table is rejected by :meth:`_validate_input`.
        """
        super().__init__()

        self._table = self._validate_input(waveform_table)
        self._channel_id = channel

    @staticmethod
    def _validate_input(input_waveform_table: Sequence[EntryInInit]) -> Tuple[TableWaveformEntry, ...]:
        """Validate a raw table and convert it to a tuple of :class:`TableWaveformEntry`.

        Checks that:
         - there are at least two entries,
         - the first time is zero and the last time is not zero,
         - the times are not decreasing.

        Redundant entries are dropped: an entry is kept only if its time differs from the previously kept time or
        the next time, and its voltage differs from the previously kept voltage or the next voltage.

        Args:
            input_waveform_table: The table to validate.

        Returns:
            The validated and reduced table.

        Raises:
            ValueError: If one of the checks listed above fails.
        """
        if len(input_waveform_table) < 2:
            raise ValueError("Waveform table has less than two entries.")

        if input_waveform_table[0][0] != 0:
            raise ValueError('First time entry is not zero.')

        if input_waveform_table[-1][0] == 0:
            raise ValueError('Last time entry is zero.')

        output_waveform_table = []

        previous_t = 0
        previous_v = None
        # Pair every entry with its successor; the last entry gets an artificial successor at t=inf so it is
        # never treated as redundant.
        for (t, v, interp), (next_t, next_v, _) in itertools.zip_longest(input_waveform_table,
                                                                         input_waveform_table[1:],
                                                                         fillvalue=(float('inf'), None, None)):
            if next_t < t:
                if next_t < 0:
                    raise ValueError('Negative time values are not allowed.')
                else:
                    raise ValueError('Times are not increasing.')

            if (previous_t != t or t != next_t) and (previous_v != v or v != next_v):
                previous_t = t
                previous_v = v
                output_waveform_table.append(TableWaveformEntry(t, v, interp))

        return tuple(output_waveform_table)

    @property
    def compare_key(self) -> Any:
        return self._channel_id, self._table

    @property
    def duration(self) -> TimeType:
        """Time of the last table entry."""
        return time_from_float(self._table[-1].t)

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None]=None) -> np.ndarray:
        """Sample the table by interpolating between each pair of neighbouring entries.

        `channel` is not inspected. See :meth:`Waveform.unsafe_sample` for the meaning of the other arguments.
        """
        if output_array is None:
            output_array = np.empty_like(sample_times)

        for entry1, entry2 in zip(self._table[:-1], self._table[1:]):
            # All sample times in the closed interval [entry1.t, entry2.t]
            indices = slice(np.searchsorted(sample_times, entry1.t, 'left'),
                            np.searchsorted(sample_times, entry2.t, 'right'))
            # The interpolation strategy of the later entry determines how the segment is filled.
            output_array[indices] = \
                entry2.interp((entry1.t, entry1.v),
                              (entry2.t, entry2.v),
                              sample_times[indices])
        return output_array

    @property
    def defined_channels(self) -> Set[ChannelID]:
        return {self._channel_id}

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'Waveform':
        # Only one channel is defined, so there is no smaller waveform to return.
        return self
class FunctionWaveform(Waveform):
    """Waveform obtained from instantiating a FunctionPulseTemplate."""

    def __init__(self, expression: ExpressionScalar, duration: float, channel: ChannelID) -> None:
        """Creates a new FunctionWaveform instance.

        Args:
            expression: The function represented by this FunctionWaveform
                as a mathematical expression where 't' denotes the time variable. It must not have other variables
            duration: The duration of the waveform
            channel: The channel this waveform is played on

        Raises:
            ValueError: If the expression depends on any variable other than 't'.
        """
        super().__init__()

        foreign_variables = set(expression.variables) - {'t'}
        if foreign_variables:
            raise ValueError('FunctionWaveforms may not depend on anything but "t"')

        self._channel_id = channel
        self._duration = time_from_float(duration)
        self._expression = expression

    @property
    def defined_channels(self) -> Set[ChannelID]:
        return {self._channel_id}

    @property
    def compare_key(self) -> Any:
        return (self._channel_id, self._expression, self._duration)

    @property
    def duration(self) -> TimeType:
        return self._duration

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Evaluate the expression with `t` bound to `sample_times`; `channel` is not inspected."""
        sampled = self._expression.evaluate_numeric(t=sample_times)
        if output_array is None:
            output_array = np.empty(len(sample_times))
        output_array[:] = sampled
        return output_array

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform:
        # Only one channel is defined, so there is no smaller waveform to return.
        return self
class SequenceWaveform(Waveform):
    """This class allows putting multiple PulseTemplate together in one waveform on the hardware."""

    def __init__(self, sub_waveforms: Iterable[Waveform]):
        """Create a waveform that plays `sub_waveforms` one after another.

        Nested SequenceWaveforms are flattened, i.e. their sub-waveforms are taken over directly.

        Args:
            sub_waveforms: The waveforms to concatenate. All waveforms must have the same defined channels.

        Raises:
            ValueError: If no sub-waveform is given or if the defined channels of the sub-waveforms differ.
        """
        super().__init__()

        def flattened_sub_waveforms(to_flatten) -> Iterable[Waveform]:
            for sub_waveform in to_flatten:
                if isinstance(sub_waveform, SequenceWaveform):
                    yield from sub_waveform._sequenced_waveforms
                else:
                    yield sub_waveform

        # Materialize before testing for emptiness: generators are always truthy, so a truth test on the argument
        # itself cannot detect an empty one.
        sequenced_waveforms = tuple(flattened_sub_waveforms(sub_waveforms or ()))
        if not sequenced_waveforms:
            raise ValueError(
                "SequenceWaveform cannot be constructed without channel waveforms."
            )

        self._sequenced_waveforms = sequenced_waveforms
        self._duration = sum(waveform.duration for waveform in self._sequenced_waveforms)
        if not all(waveform.defined_channels == self.defined_channels for waveform in self._sequenced_waveforms[1:]):
            raise ValueError(
                "SequenceWaveform cannot be constructed from waveforms of different "
                "defined channels."
            )

    @property
    def defined_channels(self) -> Set[ChannelID]:
        """The defined channels of the first sub-waveform (all sub-waveforms define the same channels)."""
        return self._sequenced_waveforms[0].defined_channels

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None]=None) -> np.ndarray:
        """Sample each sub-waveform into the slice of `output_array` that belongs to its time window.

        See :meth:`Waveform.unsafe_sample` for the meaning of the arguments.
        """
        if output_array is None:
            output_array = np.empty_like(sample_times)
        time = 0
        last_index = len(self._sequenced_waveforms) - 1
        for index, subwaveform in enumerate(self._sequenced_waveforms):
            # before you change anything here, make sure to understand the difference between basic and advanced
            # indexing in numpy and their copy/reference behaviour
            end = time + subwaveform.duration

            # Every window is half open [time, end) so a boundary sample belongs to the following sub-waveform.
            # The last window is closed because a sample at t == duration is valid and has no following
            # sub-waveform; with 'left' it would never be written.
            end_side = 'right' if index == last_index else 'left'
            indices = slice(np.searchsorted(sample_times, float(time), 'left'),
                            np.searchsorted(sample_times, float(end), end_side))
            # Basic slicing yields a view, so the sub-waveform writes directly into output_array.
            subwaveform.unsafe_sample(channel=channel,
                                      sample_times=sample_times[indices]-np.float64(time),
                                      output_array=output_array[indices])
            time = end
        return output_array

    @property
    def compare_key(self) -> Tuple[Waveform]:
        return self._sequenced_waveforms

    @property
    def duration(self) -> TimeType:
        """Sum of the durations of all sub-waveforms."""
        return self._duration

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'Waveform':
        """Build a new sequence from the channel subsets of all sub-waveforms that define any of `channels`."""
        return SequenceWaveform(
            sub_waveform.unsafe_get_subset_for_channels(channels & sub_waveform.defined_channels)
            for sub_waveform in self._sequenced_waveforms if sub_waveform.defined_channels & channels)
class MultiChannelWaveform(Waveform):
    """A MultiChannelWaveform is a Waveform object that allows combining arbitrary Waveform objects
    into a single waveform defined for several channels.

    The channels defined by the MultiChannelWaveform object are the union of the channels defined by the Waveform
    objects it consists of.

    The following constraints must hold:
     - The durations of all Waveform objects must be equal (compared with `isclose`).
     - No channel may be defined by more than one of the Waveform objects.
    """

    def __init__(self, sub_waveforms: Iterable[Waveform]) -> None:
        """Create a new MultiChannelWaveform instance.

        Nested MultiChannelWaveforms are flattened and the sub-waveforms are sorted by their defined channels to
        make the compare key reproducible.

        Args:
            sub_waveforms (Iterable( Waveform )): The sub waveforms of this MultiChannelWaveform
        Raises:
            ValueError, if no sub waveform is given
            ValueError, if a channel is defined by more than one sub waveform
            ValueError, if subwaveforms have inconsistent durations
        """
        super().__init__()

        # avoid unnecessary multi channel nesting
        def flatten_sub_waveforms(to_flatten):
            for sub_waveform in to_flatten:
                if isinstance(sub_waveform, MultiChannelWaveform):
                    yield from sub_waveform._sub_waveforms
                else:
                    yield sub_waveform

        # sort the waveforms with their defined channels to make compare key reproducible
        def get_sub_waveform_sort_key(waveform):
            # int channel names are stringified so that they can be ordered together with other channel names
            return tuple(sorted(tuple('{}_stringified_numeric_channel'.format(ch) if isinstance(ch, int) else ch
                                      for ch in waveform.defined_channels)))

        # Materialize before testing for emptiness: generators are always truthy, so a truth test on the argument
        # itself cannot detect an empty one.
        flattened = tuple(sorted(flatten_sub_waveforms(sub_waveforms or ()),
                                 key=get_sub_waveform_sort_key))
        if not flattened:
            raise ValueError(
                "MultiChannelWaveform cannot be constructed without channel waveforms."
            )
        self._sub_waveforms = flattened

        self.__defined_channels = set()
        for waveform in self._sub_waveforms:
            if waveform.defined_channels & self.__defined_channels:
                raise ValueError('Channel may not be defined in multiple waveforms',
                                 waveform.defined_channels & self.__defined_channels)
            self.__defined_channels |= waveform.defined_channels

        if not all(isclose(waveform.duration, self._sub_waveforms[0].duration) for waveform in self._sub_waveforms[1:]):
            # meaningful error message: group the channels by (close) duration
            durations = {}

            for waveform in self._sub_waveforms:
                for duration, channels in durations.items():
                    if isclose(waveform.duration, duration):
                        channels.update(waveform.defined_channels)
                        break
                else:
                    # no already known duration is close to this one
                    durations[waveform.duration] = set(waveform.defined_channels)

            raise ValueError(
                "MultiChannelWaveform cannot be constructed from channel waveforms of different durations.",
                durations
            )

    @property
    def duration(self) -> TimeType:
        """Duration of the first sub-waveform."""
        return self._sub_waveforms[0].duration

    def __getitem__(self, key: ChannelID) -> Waveform:
        """Return the sub-waveform that defines channel `key`.

        Raises:
            KeyError: If no sub-waveform defines `key`.
        """
        for waveform in self._sub_waveforms:
            if key in waveform.defined_channels:
                return waveform
        raise KeyError('Unknown channel ID: {}'.format(key), key)

    @property
    def defined_channels(self) -> Set[ChannelID]:
        return self.__defined_channels

    @property
    def compare_key(self) -> Any:
        # sort with channels
        return self._sub_waveforms

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None]=None) -> np.ndarray:
        """Delegate sampling to the sub-waveform that defines `channel`."""
        return self[channel].unsafe_sample(channel, sample_times, output_array)

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'Waveform':
        """Return the part of this waveform that defines `channels`.

        If only one sub-waveform is affected, its subset is returned directly without a MultiChannelWaveform
        wrapper.

        Raises:
            KeyError: If no sub-waveform defines any of `channels`.
        """
        relevant_sub_waveforms = tuple(swf for swf in self._sub_waveforms if swf.defined_channels & channels)
        if len(relevant_sub_waveforms) == 1:
            return relevant_sub_waveforms[0].get_subset_for_channels(channels)
        elif len(relevant_sub_waveforms) > 1:
            return MultiChannelWaveform(
                sub_waveform.get_subset_for_channels(channels & sub_waveform.defined_channels)
                for sub_waveform in relevant_sub_waveforms)
        else:
            raise KeyError('Unknown channels: {}'.format(channels))
class RepetitionWaveform(Waveform):
    """This class represents a body waveform that is played `repetition_count` times in a row."""

    def __init__(self, body: Waveform, repetition_count: int):
        """Create a waveform that repeats `body`.

        Args:
            body: The waveform to repeat.
            repetition_count: How often `body` is played. Must be an int greater than zero.

        Raises:
            ValueError: If `repetition_count` is smaller than 1 or not an instance of int.
        """
        super().__init__()

        self._body = body
        self._repetition_count = checked_int_cast(repetition_count)
        if repetition_count < 1 or not isinstance(repetition_count, int):
            raise ValueError('Repetition count must be an integer >0')

    @property
    def defined_channels(self) -> Set[ChannelID]:
        return self._body.defined_channels

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None]=None) -> np.ndarray:
        """Sample the body once per repetition into the slice of `output_array` that belongs to that repetition.

        See :meth:`Waveform.unsafe_sample` for the meaning of the arguments.
        """
        if output_array is None:
            output_array = np.empty_like(sample_times)
        body_duration = self._body.duration
        time = 0
        last_repetition = self._repetition_count - 1
        for repetition in range(self._repetition_count):
            end = time + body_duration

            # Every window is half open [time, end) so a boundary sample belongs to the following repetition.
            # The last window is closed because a sample at t == duration is valid and has no following
            # repetition; with 'left' it would never be written.
            end_side = 'right' if repetition == last_repetition else 'left'
            indices = slice(np.searchsorted(sample_times, float(time), 'left'),
                            np.searchsorted(sample_times, float(end), end_side))
            # Subtract a plain float (as SequenceWaveform does) so the shifted sample times stay a float array
            # instead of depending on how numpy combines with the type of `time`.
            self._body.unsafe_sample(channel=channel,
                                     sample_times=sample_times[indices] - float(time),
                                     output_array=output_array[indices])
            time = end
        return output_array

    @property
    def compare_key(self) -> Tuple[Any, int]:
        return self._body.compare_key, self._repetition_count

    @property
    def duration(self) -> TimeType:
        """Duration of the body multiplied by the repetition count."""
        return self._body.duration*self._repetition_count

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'RepetitionWaveform':
        """Repeat the channel subset of the body with the same repetition count."""
        return RepetitionWaveform(body=self._body.unsafe_get_subset_for_channels(channels),
                                  repetition_count=self._repetition_count)
class TransformingWaveform(Waveform):
    """Waveform whose channels are obtained by applying a transformation to the channels of an inner waveform."""

    def __init__(self, inner_waveform: Waveform, transformation: Transformation):
        """Wrap `inner_waveform` so that sampling goes through `transformation`.

        Args:
            inner_waveform: The waveform that delivers the input data of the transformation.
            transformation: Called with the sample times and the sampled inner channels.
        """
        self._inner_waveform = inner_waveform
        self._transformation = transformation

        # Transformed data per output channel. It belongs to exactly one sample time array, which is identified
        # by object identity through the weak reference below. The lambda imitates a dead reference.
        self._cached_data = None
        self._cached_times = lambda: None

    @property
    def inner_waveform(self) -> Waveform:
        return self._inner_waveform

    @property
    def transformation(self) -> Transformation:
        return self._transformation

    @property
    def defined_channels(self) -> Set[ChannelID]:
        """The output channels of the transformation for the channels of the inner waveform."""
        inner_channels = self.inner_waveform.defined_channels
        return self.transformation.get_output_channels(inner_channels)

    @property
    def compare_key(self) -> Tuple[Waveform, Transformation]:
        return (self.inner_waveform, self.transformation)

    @property
    def duration(self) -> TimeType:
        return self.inner_waveform.duration

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'SubsetWaveform':
        """Wrap this waveform in a SubsetWaveform restricted to `channels`."""
        return SubsetWaveform(self, channel_subset=channels)

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Sample the required inner channels, transform them and return the data of `channel`.

        The transformed data is cached as long as the very same `sample_times` object is passed. Without an
        output array the cached array itself is returned.
        """
        same_times_as_before = self._cached_times() is sample_times
        if not same_times_as_before:
            # A different time array invalidates everything cached so far.
            self._cached_data = dict()
            self._cached_times = ref(sample_times)

        if channel not in self._cached_data:
            required_channels = self.transformation.get_input_channels({channel})

            sampled_inner = {}
            for inner_channel in required_channels:
                sampled_inner[inner_channel] = self.inner_waveform.unsafe_sample(inner_channel, sample_times)

            transformed = self.transformation(sample_times, sampled_inner)
            self._cached_data.update(transformed)

        channel_data = self._cached_data[channel]
        if output_array is None:
            return channel_data

        output_array[:] = channel_data
        return output_array
class SubsetWaveform(Waveform):
    """View on an inner waveform that only exposes a subset of its channels."""

    def __init__(self, inner_waveform: Waveform, channel_subset: Set[ChannelID]):
        """Store the wrapped waveform and a frozen copy of the exposed channels.

        Args:
            inner_waveform: The waveform that does the actual sampling.
            channel_subset: The channels this waveform reports as defined.
        """
        self._channel_subset = frozenset(channel_subset)
        self._inner_waveform = inner_waveform

    @property
    def inner_waveform(self) -> Waveform:
        return self._inner_waveform

    @property
    def defined_channels(self) -> FrozenSet[ChannelID]:
        return self._channel_subset

    @property
    def duration(self) -> TimeType:
        return self.inner_waveform.duration

    @property
    def compare_key(self) -> Tuple[frozenset, Waveform]:
        return (self.defined_channels, self.inner_waveform)

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform:
        """Ask the inner waveform for the subset instead of stacking another SubsetWaveform on top."""
        wrapped = self.inner_waveform
        return wrapped.get_subset_for_channels(channels)

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None]=None) -> np.ndarray:
        """Forward the call unchanged to the inner waveform."""
        wrapped = self.inner_waveform
        return wrapped.unsafe_sample(channel, sample_times, output_array)