from typing import Optional, List, Union, Set, Dict, Sequence, Any, Tuple
from numbers import Real
import itertools
import numbers
import sympy
import numpy as np
from qupulse.utils.sympy import IndexedBroadcast
from qupulse.utils.types import ChannelID
from qupulse.expressions import Expression, ExpressionScalar
from qupulse.program.waveforms import TableWaveform, TableWaveformEntry
from qupulse.pulses.parameters import ParameterConstraint, ParameterConstrainer
from qupulse.pulses.pulse_template import AtomicPulseTemplate, MeasurementDeclaration
from qupulse.pulses.table_pulse_template import TableEntry, EntryInInit
from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform
from qupulse.serialization import Serializer, PulseRegistryType
# Names exported by ``from <this module> import *``.
__all__ = [
    "PointWaveform",
    "PointPulseTemplate",
    "PointPulseEntry",
    "PointWaveformEntry",
    "InvalidPointDimension",
]

# Point pulses reuse the table waveform classes unchanged; these aliases only
# give them point-specific names.
PointWaveform = TableWaveform
PointWaveformEntry = TableWaveformEntry
class PointPulseEntry(TableEntry):
    """One support point of a point pulse.

    Relies on the ``t`` (time), ``v`` (value) and ``interp`` (interpolation)
    attributes it gets from :class:`TableEntry`.
    """

    def instantiate(self, parameters: Dict[str, numbers.Real], num_channels: int) -> Sequence[PointWaveformEntry]:
        """Evaluate this entry and expand it into one waveform entry per channel.

        Args:
            parameters: Scope handed to ``evaluate_in_scope`` of the time and
                value expressions.
            num_channels: Number of channels the value has to cover.

        Returns:
            A tuple with ``num_channels`` waveform entries that all share the
            evaluated time and this entry's interpolation.

        Raises:
            InvalidPointDimension: If the evaluated value is not a plain number
                and its length differs from ``num_channels``.
        """
        time = self.t.evaluate_in_scope(parameters)
        values = self.v.evaluate_in_scope(parameters)

        if not isinstance(values, numbers.Number):
            # Vector valued point: it must provide exactly one value per channel.
            received = len(values)
            if received != num_channels:
                raise InvalidPointDimension(expected=num_channels, received=received)
        else:
            # Scalar valued point: the same value is used for every channel.
            values = (values,) * num_channels

        waveform_entries = [PointWaveformEntry(time, value, self.interp)
                            for value in values]
        return tuple(waveform_entries)
class PointPulseTemplate(AtomicPulseTemplate, ParameterConstrainer):
    """Atomic pulse template described by a list of points shared by several channels.

    Every point is stored as a :class:`PointPulseEntry`. The value of a point is
    either indexable (one element per channel) or not indexable, in which case
    it is broadcast to all channels.
    """

    def __init__(self,
                 time_point_tuple_list: List[EntryInInit],
                 channel_names: Sequence[ChannelID],
                 *,
                 parameter_constraints: Optional[List[Union[str, ParameterConstraint]]]=None,
                 measurements: Optional[List[MeasurementDeclaration]]=None,
                 identifier: Optional[str]=None,
                 registry: PulseRegistryType=None) -> None:
        """Create a new point pulse template.

        Args:
            time_point_tuple_list: The points of the pulse. Each element is
                unpacked positionally into :class:`PointPulseEntry`.
            channel_names: The channels this template defines; the order fixes
                which element of a vector valued point belongs to which channel.
            parameter_constraints: Forwarded to :class:`ParameterConstrainer`.
            measurements: Forwarded to :class:`AtomicPulseTemplate`.
            identifier: Forwarded to :class:`AtomicPulseTemplate`.
            registry: Forwarded to ``self._register``.
        """
        AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements)
        ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints)

        self._channels = tuple(channel_names)
        self._entries = [PointPulseEntry(*tpt)
                         for tpt in time_point_tuple_list]

        self._register(registry=registry)

    def _channel_value_extractor(self, channel_index: int):
        """Return a callable mapping an entry value to the expression of one channel.

        The callable first tries to index the value's ``underlying_expression``
        with ``channel_index``. If indexing raises ``TypeError`` the expression
        is wrapped in an :class:`IndexedBroadcast` instead.

        The shape is derived from ``self._channels`` because that is the
        sequence the channel indices refer to. ``channel_index`` is bound as an
        argument, so the callable does not depend on a surrounding loop variable.
        """
        shape = (len(self._channels),)

        def extract(value):
            try:
                return value.underlying_expression[channel_index]
            except TypeError:
                return IndexedBroadcast(value.underlying_expression, shape, channel_index)

        return extract

    @property
    def defined_channels(self) -> Set[ChannelID]:
        """The channel names of this template as a set."""
        return set(self._channels)

    @property
    def point_pulse_entries(self) -> Sequence[PointPulseEntry]:
        """The entries of this template (the internal list itself, not a copy)."""
        return self._entries

    def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]:
        """Collect the constructor arguments needed to rebuild this template.

        Args:
            serializer: Only used by the deprecated serialization routines. If
                given, the dictionary returned by the base class is discarded
                and a fresh one is filled instead.

        Returns:
            A dictionary with the keys ``time_point_tuple_list`` and
            ``channel_names`` plus ``parameter_constraints`` and ``measurements``
            if those are non-empty.
        """
        data = super().get_serialization_data(serializer)

        if serializer:  # compatibility to old serialization routines, deprecated
            data = dict()

        data['time_point_tuple_list'] = [entry.get_serialization_data() for entry in self._entries]
        data['channel_names'] = self._channels
        if self.parameter_constraints:
            data['parameter_constraints'] = [str(c) for c in self.parameter_constraints]
        if self.measurement_declarations:
            data['measurements'] = self.measurement_declarations

        return data

    @property
    def duration(self) -> Expression:
        """The time expression of the last entry."""
        return self._entries[-1].t

    @property
    def point_parameters(self) -> Set[str]:
        """All variables that occur in the time or value expression of any entry."""
        return {
            var
            for time, point, *_ in self._entries
            for var in itertools.chain(time.variables, point.variables)
        }

    @property
    def parameter_names(self) -> Set[str]:
        """Union of point, measurement and constrained parameters."""
        return self.point_parameters | self.measurement_parameters | self.constrained_parameters

    @property
    def integral(self) -> Dict[ChannelID, ExpressionScalar]:
        """Per-channel result of ``TableEntry._sequence_integral`` over the entries.

        An additional entry at time 0 that carries the value of the first entry
        is put in front of the sequence before it is handed on.
        """
        expressions = {}
        for i, channel in enumerate(self._channels):
            value_trafo = self._channel_value_extractor(i)
            # Built inside the loop so that nothing is evaluated for a template
            # without channels, exactly as before.
            pre_entry = TableEntry(0, self._entries[0].v, None)
            entries = [pre_entry] + self._entries
            expressions[channel] = TableEntry._sequence_integral(entries, expression_extractor=value_trafo)
        return expressions

    def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]:
        """Per-channel result of ``TableEntry._sequence_as_expression``.

        The first entry's value is passed as ``pre_value`` and the last entry's
        value as ``post_value``; ``self._AS_EXPRESSION_TIME`` is used as ``t``.
        """
        t = self._AS_EXPRESSION_TIME
        expressions = {}
        for i, channel in enumerate(self._channels):
            value_trafo = self._channel_value_extractor(i)
            pre_value = value_trafo(self._entries[0].v)
            post_value = value_trafo(self._entries[-1].v)
            pw = TableEntry._sequence_as_expression(self._entries,
                                                    expression_extractor=value_trafo,
                                                    t=t,
                                                    post_value=post_value,
                                                    pre_value=pre_value)
            expressions[channel] = pw
        return expressions

    @property
    def initial_values(self) -> Dict[ChannelID, ExpressionScalar]:
        """For each channel, the first entry's value broadcast and indexed by channel position."""
        shape = (len(self._channels),)
        return {
            ch: ExpressionScalar(IndexedBroadcast(self._entries[0].v, shape, ch_idx))
            for ch_idx, ch in enumerate(self._channels)
        }

    @property
    def final_values(self) -> Dict[ChannelID, ExpressionScalar]:
        """For each channel, the last entry's value broadcast and indexed by channel position."""
        shape = (len(self._channels),)
        return {
            ch: ExpressionScalar(IndexedBroadcast(self._entries[-1].v, shape, ch_idx))
            for ch_idx, ch in enumerate(self._channels)
        }
class InvalidPointDimension(Exception):
    """Signals that a point has a different number of values than expected.

    Attributes:
        expected: The dimension that was required.
        received: The dimension that was actually found.
    """

    def __init__(self, expected, received):
        self.expected = expected
        self.received = received
        message = 'Expected a point of dimension {} but received {}'.format(expected, received)
        super().__init__(message)