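"""This module defines plotting functionality for instantiated PulseTemplates using matplotlib.
Classes:
- PlottingNotPossibleException: Indicates that a pulse cannot be rendered or plotted.
Functions:
- render: Sample a pulse program into arrays of time and voltage values.
- plot: Plot a pulse using matplotlib.
"""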
from typing import Dict, Tuple, Any, Optional, Set, List, Union
from numbers import Real
import numpy as np
import warnings
import operator
import itertools
from qupulse._program import waveforms
from qupulse.utils.types import ChannelID, MeasurementWindow, has_type_interface
from qupulse.pulses.pulse_template import PulseTemplate
from qupulse.pulses.parameters import Parameter
from qupulse._program.waveforms import Waveform
from qupulse._program._loop import Loop, to_waveform
__all__ = ["render", "plot", "PlottingNotPossibleException"]
def render(program: Union[Loop],
           sample_rate: Real = 10.0,
           render_measurements: bool = False,
           time_slice: Tuple[Real, Real] = None,
           plot_channels: Optional[Set[ChannelID]] = None) -> Tuple[np.ndarray, Dict[ChannelID, np.ndarray],
                                                                    List[MeasurementWindow]]:
    """'Renders' a pulse program.

    Samples all contained waveforms into an array according to the control flow of the program.

    Args:
        program: The pulse (sub)program to render. Must provide the Loop interface.
        sample_rate: The sample rate in GHz.
        render_measurements: If True, the third return value is a list of measurement windows.
        time_slice: The time slice to be rendered. If None, the entire pulse will be shown.
        plot_channels: Only channels in this set are rendered. If None, all will.

    Returns:
        A tuple (times, voltages, measurements). times is a numpy.ndarray of dimensions sample_count
        containing the time values. voltages is a dictionary of one numpy.ndarray of dimensions sample_count per
        rendered channel containing corresponding sampled voltage values for that channel.
        measurements is a sequence of all measurements where each measurement is represented by a tuple
        (name, start_time, duration). If the program yields no waveform, times is an empty array and
        voltages an empty dictionary.

    Raises:
        ValueError: If program does not provide the Loop interface or if time_slice is not a valid,
            non-negative interval.
        PlottingNotPossibleException: If fewer than two samples would be rendered.
    """
    if has_type_interface(program, Loop):
        waveform, measurements = _render_loop(program, render_measurements=render_measurements)
    else:
        raise ValueError('Cannot render an object of type %r' % type(program), program)

    if waveform is None:
        return np.array([]), dict(), measurements

    if plot_channels is None:
        channels = waveform.defined_channels
    else:
        channels = waveform.defined_channels & plot_channels

    if time_slice is None:
        start_time, end_time = 0, waveform.duration
    elif time_slice[1] < time_slice[0] or time_slice[0] < 0 or time_slice[1] < 0:
        raise ValueError("time_slice is not valid.")
    else:
        start_time, end_time, *_ = time_slice

        # An explicit slice only keeps the measurement windows that overlap it.
        measurements = [(name, begin, length)
                        for name, begin, length in measurements
                        if begin < end_time and begin + length > start_time]

    sample_count = (end_time - start_time) * sample_rate + 1

    # Floating point noise can leave the count marginally below an integer (e.g. 435.99999999999994).
    # Truncating such a value would silently drop one sample, so snap it to the integer it represents.
    snapped_count = round(float(sample_count), 10)
    count_is_integer = snapped_count.is_integer()
    if count_is_integer:
        sample_count = int(snapped_count)

    if sample_count < 2:
        raise PlottingNotPossibleException(pulse=None,
                                           description='cannot render sequence with less than 2 data points')
    if not count_is_integer:
        warnings.warn("Sample count {sample_count} is not an integer. Will be rounded (this changes the sample rate).".format(sample_count=sample_count))

    times = np.linspace(float(start_time), float(end_time), num=int(sample_count), dtype=float)
    # Move the last sample time by the smallest representable step towards its predecessor
    # (presumably so that it lies strictly before end_time -- TODO confirm against Waveform.get_sampled).
    times[-1] = np.nextafter(times[-1], times[-2])

    voltages = {ch: waveforms._ALLOCATION_FUNCTION(times, **waveforms._ALLOCATION_FUNCTION_KWARGS)
                for ch in channels}
    for ch, ch_voltage in voltages.items():
        waveform.get_sampled(channel=ch, sample_times=times, output_array=ch_voltage)

    return times, voltages, measurements
def _render_loop(loop: Loop,
                 render_measurements: bool) -> Tuple[Waveform, List[MeasurementWindow]]:
    """Convert a Loop program into one waveform plus its measurement windows.

    This is the Loop specific part of render. The returned list holds (name, begin, length) tuples
    ordered by begin time; it is empty if render_measurements is False.
    """
    waveform = to_waveform(loop)

    if not render_measurements:
        return waveform, []

    # get_measurement_windows maps each name to a (begins, lengths) pair; flatten it into single windows.
    windows = [(name, begin, length)
               for name, (begins, lengths) in loop.get_measurement_windows().items()
               for begin, length in zip(begins, lengths)]
    windows.sort(key=lambda window: window[1])
    return waveform, windows
def plot(pulse: PulseTemplate,
         parameters: Dict[str, Parameter]=None,
         sample_rate: Real=10,
         axes: Any=None,
         show: bool=True,
         plot_channels: Optional[Set[ChannelID]]=None,
         plot_measurements: Optional[Set[str]]=None,
         stepped: bool=True,
         maximum_points: int=10**6,
         time_slice: Tuple[Real, Real]=None,
         **kwargs) -> Any:  # pragma: no cover
    """Plots a pulse using matplotlib.

    The given pulse template will first be turned into a pulse program (represented by a Loop object) with the provided
    parameters. The render() function is then invoked to obtain voltage samples over the entire duration of the pulse which
    are then plotted in a matplotlib figure.

    Args:
        pulse: The pulse to be plotted.
        parameters: An optional mapping of parameter names to Parameter
            objects.
        sample_rate: The rate with which the waveforms are sampled for the plot in
            samples per time unit. (default = 10)
        axes: matplotlib Axes object the pulse will be drawn into if provided
        show: If true, the figure will be shown
        plot_channels: If specified only channels from this set will be plotted. If omitted all channels will be.
        stepped: If true pyplot.step is used for plotting
        plot_measurements: If specified measurements in this set will be plotted. If omitted no measurements will be.
        maximum_points: If the sampled waveform is bigger, it is not plotted
        time_slice: The time slice to be plotted. If None, the entire pulse will be shown.
        kwargs: Forwarded to pyplot. Overwrites other settings.

    Returns:
        matplotlib.pyplot.Figure instance in which the pulse is rendered, or None if the sampled pulse
        has more than maximum_points samples (a warning is issued in that case).

    Raises:
        PlottingNotPossibleException if the sequencing is interrupted before it finishes, e.g.,
            because a parameter value could not be evaluated
        all Exceptions possibly raised during sequencing
    """
    # Imported lazily so that the module can be imported without matplotlib being available.
    from matplotlib import pyplot as plt

    channels = pulse.defined_channels

    if parameters is None:
        parameters = dict()

    program = pulse.create_program(parameters=parameters,
                                   channel_mapping={ch: ch for ch in channels},
                                   measurement_mapping={w: w for w in pulse.measurement_names})

    if program is not None:
        times, voltages, measurements = render(program,
                                               sample_rate,
                                               render_measurements=bool(plot_measurements),
                                               time_slice=time_slice)
    else:
        times, voltages, measurements = np.array([]), dict(), []

    duration = 0
    if times.size == 0:
        warnings.warn("Pulse to be plotted is empty!")
    elif times.size > maximum_points:
        # todo [2018-05-30]: since it results in an empty return value this should arguably be an exception, not just a warning
        warnings.warn("Sampled pulse of size {wf_len} is lager than {max_points}".format(wf_len=times.size,
                                                                                         max_points=maximum_points))
        return None
    else:
        duration = times[-1]

    if time_slice is None:
        time_slice = (0, duration)

    legend_handles = []
    if axes is None:
        # plot to figure
        figure = plt.figure()
        axes = figure.add_subplot(111)

    if plot_channels is not None:
        voltages = {ch: voltage
                    for ch, voltage in voltages.items()
                    if ch in plot_channels}

    for ch_name, voltage in voltages.items():
        label = 'channel {}'.format(ch_name)
        # User supplied kwargs are merged last so that they override the defaults.
        if stepped:
            line, = axes.step(times, voltage, **{**dict(where='post', label=label), **kwargs})
        else:
            line, = axes.plot(times, voltage, **{**dict(label=label), **kwargs})
        legend_handles.append(line)

    if plot_measurements:
        measurement_dict = dict()
        for name, begin, length in measurements:
            if name in plot_measurements:
                measurement_dict.setdefault(name, []).append((begin, begin+length))

        # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in matplotlib 3.9.
        color_map = plt.get_cmap('plasma')
        meas_colors = {name: color_map(i/len(measurement_dict))
                       for i, name in enumerate(measurement_dict.keys())}
        for name, begin_end_list in measurement_dict.items():
            for begin, end in begin_end_list:
                poly = axes.axvspan(begin, end, alpha=0.2, label=name, edgecolor='black', facecolor=meas_colors[name])
            # One legend entry per measurement name; all of its windows share label and color.
            legend_handles.append(poly)

    axes.legend(handles=legend_handles)

    # Reduce each channel with numpy instead of iterating over up to maximum_points samples in Python.
    # Empty channels contribute 0, as does the absence of any channel.
    max_voltage = max((np.max(voltage) if len(voltage) else 0 for voltage in voltages.values()), default=0)
    min_voltage = min((np.min(voltage) if len(voltage) else 0 for voltage in voltages.values()), default=0)

    # add some margins in the presentation
    axes.set_xlim(-0.5+time_slice[0], time_slice[1] + 0.5)
    voltage_difference = max_voltage-min_voltage
    if voltage_difference>0:
        axes.set_ylim(min_voltage - 0.1*voltage_difference, max_voltage + 0.1*voltage_difference)
    axes.set_xlabel('Time (ns)')
    axes.set_ylabel('Voltage (a.u.)')

    if pulse.identifier:
        axes.set_title(pulse.identifier)

    if show:
        axes.get_figure().show()
    return axes.get_figure()
class PlottingNotPossibleException(Exception):
    """Indicates that plotting is not possible because the sequencing process did not translate
    the entire given PulseTemplate structure.

    Args:
        pulse: The pulse that could not be plotted. May be None.
        description: Optional reason that is included in the string representation.
    """

    def __init__(self, pulse, description = None) -> None:
        # Forward the arguments to Exception so that they end up in self.args. Pickling and copying
        # rebuild an exception by calling its class with self.args, which fails if they are missing.
        super().__init__(pulse, description)
        self.pulse = pulse
        self.description = description

    def __str__(self) -> str:
        """Return the message, using the description if one was given."""
        if self.description is None:
            return "Plotting is not possible. There are parameters which cannot be computed."
        else:
            return "Plotting is not possible: %s." % self.description