from typing import Optional, List, Union, Set, Dict, Sequence, Any
from numbers import Real
import itertools
import numbers
import sympy
import numpy as np
from qupulse.utils.sympy import Broadcast
from qupulse.utils.types import ChannelID
from qupulse.expressions import Expression, ExpressionScalar
from qupulse.pulses.conditions import Condition
from qupulse._program.waveforms import TableWaveform, TableWaveformEntry
from qupulse.pulses.parameters import Parameter, ParameterNotProvidedException, ParameterConstraint,\
ParameterConstrainer
from qupulse.pulses.pulse_template import AtomicPulseTemplate, MeasurementDeclaration
from qupulse.pulses.table_pulse_template import TableEntry, EntryInInit
from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform
from qupulse.serialization import Serializer, PulseRegistryType
# Public names of this module (what ``from <module> import *`` exposes).
__all__ = ["PointWaveform", "PointPulseTemplate", "PointPulseEntry", "PointWaveformEntry", "InvalidPointDimension"]
# Point pulses do not define waveform classes of their own: both names below
# are plain aliases of the table waveform classes imported above.
PointWaveform = TableWaveform
PointWaveformEntry = TableWaveformEntry
class PointPulseEntry(TableEntry):
    """A table entry whose value may describe several channels at once.

    The value expression ``v`` may evaluate either to a single number, which
    is then used for every channel, or to a sequence with one value per
    channel.
    """

    def instantiate(self, parameters: Dict[str, numbers.Real], num_channels: int) -> Sequence[PointWaveformEntry]:
        """Evaluate this entry and expand it into one waveform entry per channel.

        Args:
            parameters: Numeric values passed as keyword arguments to the
                time and value expressions of this entry.
            num_channels: Number of channels the point has to provide a
                value for.

        Returns:
            A tuple with ``num_channels`` waveform entries. All of them share
            the evaluated time and this entry's interpolation strategy and
            differ only in their value.

        Raises:
            InvalidPointDimension: If the value evaluates to a sequence whose
                length is not ``num_channels``.
        """
        t = self.t.evaluate_numeric(**parameters)
        vs = self.v.evaluate_numeric(**parameters)

        if isinstance(vs, numbers.Number):
            # A scalar value is replicated for every channel; the dtype of
            # the scalar is kept for the resulting array.
            vs = np.full(num_channels, vs, dtype=type(vs))
        elif len(vs) != num_channels:
            raise InvalidPointDimension(expected=num_channels, received=len(vs))

        return tuple(PointWaveformEntry(t, v, self.interp) for v in vs)
class PointPulseTemplate(AtomicPulseTemplate, ParameterConstrainer):
    """Atomic pulse template defined by a list of points shared by all channels.

    Every entry consists of a time, a value and an interpolation strategy.
    The value of an entry may be a scalar (used for every channel) or hold
    one value per channel, in the order given by ``channel_names``.
    """

    def __init__(self,
                 time_point_tuple_list: List[EntryInInit],
                 channel_names: Sequence[ChannelID],
                 *,
                 parameter_constraints: Optional[List[Union[str, ParameterConstraint]]]=None,
                 measurements: Optional[List[MeasurementDeclaration]]=None,
                 identifier: Optional[str]=None,
                 registry: PulseRegistryType=None) -> None:
        """Create a new point pulse template.

        Args:
            time_point_tuple_list: Entry definitions. Each element is unpacked
                into the arguments of :class:`PointPulseEntry`.
            channel_names: Channels defined by this template. Their order is
                kept and determines which component of a point belongs to
                which channel.
            parameter_constraints: Forwarded to :class:`ParameterConstrainer`.
            measurements: Forwarded to :class:`AtomicPulseTemplate`.
            identifier: Forwarded to :class:`AtomicPulseTemplate`.
            registry: Passed to ``self._register`` after construction.
        """
        AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements)
        ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints)

        self._channels = tuple(channel_names)
        self._entries = [PointPulseEntry(*tpt)
                         for tpt in time_point_tuple_list]

        # Registration happens last so that the template is fully
        # initialised when it is handed to the registry.
        self._register(registry=registry)

    @property
    def defined_channels(self) -> Set[ChannelID]:
        """The channel names of this template as a set."""
        return set(self._channels)

    @property
    def point_pulse_entries(self) -> Sequence[PointPulseEntry]:
        """The entries of this template in the order they were given."""
        return self._entries

    def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]:
        """Return the data required to reconstruct this template.

        Args:
            serializer: Serializer of the old, deprecated serialization
                routines. If given, the data of the base class is discarded
                and only the entries added here are returned.

        Returns:
            A dictionary containing ``time_point_tuple_list`` and
            ``channel_names`` and, if present, ``parameter_constraints`` and
            ``measurements``.
        """
        data = super().get_serialization_data(serializer)

        if serializer:  # compatibility to old serialization routines, deprecated
            data = dict()

        data['time_point_tuple_list'] = [entry.get_serialization_data() for entry in self._entries]
        data['channel_names'] = self._channels
        # Optional fields are only written if they carry information.
        if self.parameter_constraints:
            data['parameter_constraints'] = [str(c) for c in self.parameter_constraints]
        if self.measurement_declarations:
            data['measurements'] = self.measurement_declarations
        return data

    @property
    def duration(self) -> Expression:
        """The time of the last entry.

        Raises:
            IndexError: If the template has no entries.
        """
        return self._entries[-1].t

    @property
    def point_parameters(self) -> Set[str]:
        """Names of all variables used in the time or value of any entry."""
        return set(
            var
            for time, point, *_ in self._entries
            for var in itertools.chain(time.variables, point.variables)
        )

    @property
    def parameter_names(self) -> Set[str]:
        """Union of point, measurement and constraint parameter names."""
        return self.point_parameters | self.measurement_parameters | self.constrained_parameters

    def requires_stop(self,
                      parameters: Dict[str, Parameter],
                      conditions: Dict[str, Condition]) -> bool:
        """Tell whether any parameter of this template requires a stop.

        Args:
            parameters: Mapping from parameter name to parameter object. It
                has to contain every name in :attr:`parameter_names`.
            conditions: Not used by this implementation.

        Returns:
            ``True`` if ``requires_stop`` is set on at least one of the
            parameters used by this template.

        Raises:
            ParameterNotProvidedException: If a lookup raises ``KeyError``,
                i.e. a required parameter is missing from ``parameters``.
        """
        try:
            return any(
                parameters[name].requires_stop
                for name in self.parameter_names
            )
        except KeyError as key_error:
            # str(KeyError('x')) is "'x'" (the repr of the key). Use the key
            # itself so the reported parameter name carries no stray quotes.
            missing = key_error.args[0] if key_error.args else str(key_error)
            raise ParameterNotProvidedException(missing) from key_error

    @property
    def integral(self) -> Dict[ChannelID, ExpressionScalar]:
        """Symbolic integral of every channel over the whole pulse.

        For each pair of consecutive entries the integral expression of the
        interpolation strategy is evaluated with ``t0``/``v0`` taken from the
        earlier and ``t1``/``v1`` taken from the later entry. The results are
        summed up per channel.

        Returns:
            A dictionary mapping each channel to its integral expression.
        """
        # The index ``i`` below runs over ``self._channels``, so the broadcast
        # shape is taken from the same sequence.
        shape = (len(self._channels),)

        expressions = {channel: 0 for channel in self._channels}
        for first_entry, second_entry in zip(self._entries[:-1], self._entries[1:]):
            substitutions = {'t0': first_entry.t.sympified_expression,
                             't1': second_entry.t.sympified_expression}

            # Broadcast lets scalar and per-channel values be indexed alike.
            v0 = sympy.IndexedBase(Broadcast(first_entry.v.underlying_expression, shape))
            v1 = sympy.IndexedBase(Broadcast(second_entry.v.underlying_expression, shape))

            # NOTE(review): the strategy of the *earlier* entry of each pair
            # is used here - confirm that this matches how the waveform
            # interpolates between two entries.
            segment_integral = first_entry.interp.integral.sympified_expression

            for i, channel in enumerate(self._channels):
                substitutions['v0'] = v0[i]
                substitutions['v1'] = v1[i]

                expressions[channel] += segment_integral.subs(substitutions)

        return {channel: ExpressionScalar(expression)
                for channel, expression in expressions.items()}
class InvalidPointDimension(Exception):
    """Raised when a point does not have the expected number of values.

    Attributes:
        expected: The dimension that was required.
        received: The dimension that was actually provided.
    """

    def __init__(self, expected, received):
        super().__init__('Expected a point of dimension {} but received {}'.format(expected, received))
        self.expected = expected
        self.received = received