"""This module defines plotting functionality for instantiated PulseTemplates using matplotlib.

Classes:
    - PlottingNotPossibleException.
Functions:
    - render: Sample a pulse program into time and voltage arrays.
    - plot: Plot a pulse using matplotlib.
"""
from typing import Dict, Tuple, Any, Optional, Set, List, Union
from numbers import Real
import numpy as np
import warnings
import operator
import itertools
from qupulse.utils.types import ChannelID, MeasurementWindow, has_type_interface
from qupulse.pulses.pulse_template import PulseTemplate
from qupulse.pulses.parameters import Parameter
from qupulse._program.waveforms import Waveform
from qupulse._program._loop import Loop, to_waveform
# Names exported by a star-import of this module.
__all__ = ["render", "plot", "PlottingNotPossibleException"]
def render(program: Union[Loop],
           sample_rate: Real = 10.0,
           render_measurements: bool = False,
           time_slice: Optional[Tuple[Real, Real]] = None,
           plot_channels: Optional[Set[ChannelID]] = None) -> Tuple[np.ndarray, Dict[ChannelID, np.ndarray],
                                                                    List[MeasurementWindow]]:
    """'Renders' a pulse program.

    Samples all contained waveforms into an array according to the control flow of the program.

    Args:
        program: The pulse (sub)program to render. Must provide the Loop type interface.
        sample_rate: The sample rate in GHz.
        render_measurements: If True, the third return value is a list of measurement windows. Otherwise it is
            an empty list.
        time_slice: The time slice (start, end) to be rendered. If None, the entire pulse will be rendered.
        plot_channels: Only channels in this set are rendered. If None, all will.

    Returns:
        A tuple (times, voltages, measurements). times is a numpy.ndarray of length sample_count containing the
        time values. voltages is a dictionary with one numpy.ndarray of length sample_count per rendered channel
        containing the sampled voltage values of that channel. measurements is a list of all measurement
        windows overlapping the rendered time span, each represented by a tuple (name, start_time, duration).
        If the program yields no waveform, times is an empty array and voltages an empty dictionary.

    Raises:
        ValueError: If program does not provide the Loop type interface or if time_slice is negative or has its
            end before its start.
        PlottingNotPossibleException: If fewer than two samples would be rendered.
    """
    if has_type_interface(program, Loop):
        waveform, measurements = _render_loop(program, render_measurements=render_measurements)
    else:
        raise ValueError('Cannot render an object of type %r' % type(program), program)

    if waveform is None:
        return np.array([]), dict(), measurements

    if plot_channels is None:
        channels = waveform.defined_channels
    else:
        channels = waveform.defined_channels & plot_channels

    if time_slice is None:
        start_time, end_time = 0, waveform.duration
    elif time_slice[1] < time_slice[0] or time_slice[0] < 0 or time_slice[1] < 0:
        raise ValueError("time_slice is not valid.")
    else:
        start_time, end_time, *_ = time_slice

    # keep only measurement windows that overlap the rendered time span
    measurements = [(name, begin, length)
                    for name, begin, length in measurements
                    if begin < end_time and begin + length > start_time]

    sample_count = (end_time - start_time) * sample_rate + 1
    # Absorb floating point noise (e.g. 57.99999999999999) before validating and converting the count.
    # Otherwise a count that passes the integrality check below would still be truncated to one sample less.
    rounded_count = round(float(sample_count), 10)
    if rounded_count < 2:
        raise PlottingNotPossibleException(pulse=None,
                                           description='cannot render sequence with less than 2 data points')
    if rounded_count.is_integer():
        num_samples = int(rounded_count)
    else:
        warnings.warn("Sample count {sample_count} is not an integer. Will be rounded (this changes the sample rate).".format(sample_count=sample_count))
        # genuinely fractional counts are truncated towards zero (unchanged behavior)
        num_samples = int(sample_count)

    times = np.linspace(float(start_time), float(end_time), num=num_samples, dtype=float)
    # move the last sample time one representable float towards the previous sample so it is strictly
    # smaller than end_time
    times[-1] = np.nextafter(times[-1], times[-2])

    voltages = {ch: np.empty_like(times)
                for ch in channels}
    for ch, ch_voltage in voltages.items():
        waveform.get_sampled(channel=ch, sample_times=times, output_array=ch_voltage)

    return times, voltages, measurements
def _render_loop(loop: Loop,
                 render_measurements: bool,) -> Tuple[Waveform, List[MeasurementWindow]]:
    """Convert a Loop program into one waveform plus its measurement windows.

    This is the Loop specific part of render(). The measurement windows are returned as
    (name, begin, length) tuples sorted by their begin; the list is empty if
    render_measurements is false.
    """
    waveform = to_waveform(loop)

    if not render_measurements:
        return waveform, []

    # flatten the {name: (begins, lengths)} mapping into individual windows
    windows = [(name, begin, length)
               for name, (begins, lengths) in loop.get_measurement_windows().items()
               for begin, length in zip(begins, lengths)]
    windows.sort(key=operator.itemgetter(1))
    return waveform, windows
def plot(pulse: PulseTemplate,
         parameters: Optional[Dict[str, Parameter]]=None,
         sample_rate: Real=10,
         axes: Any=None,
         show: bool=True,
         plot_channels: Optional[Set[ChannelID]]=None,
         plot_measurements: Optional[Set[str]]=None,
         stepped: bool=True,
         maximum_points: int=10**6,
         time_slice: Optional[Tuple[Real, Real]]=None,
         **kwargs) -> Any:  # pragma: no cover
    """Plots a pulse using matplotlib.

    The given pulse template will first be turned into a pulse program (represented by a Loop object) with the
    provided parameters. The render() function is then invoked to obtain voltage samples over the entire duration
    of the pulse (or over time_slice) which are then plotted in a matplotlib figure.

    Args:
        pulse: The pulse to be plotted.
        parameters: An optional mapping of parameter names to Parameter objects.
        sample_rate: The rate with which the waveforms are sampled for the plot in
            samples per time unit. (default = 10)
        axes: matplotlib Axes object the pulse will be drawn into if provided. If omitted a new figure is created.
        show: If true, the figure will be shown.
        plot_channels: If specified only channels from this set will be plotted. If omitted all channels will be.
        plot_measurements: If specified measurements in this set will be plotted. If omitted no measurements will
            be.
        stepped: If true the step method of the axes is used for plotting, otherwise its plot method.
        maximum_points: If the sampled waveform has more points than this, it is not plotted.
        time_slice: The time slice to be plotted. If None, the entire pulse will be shown.
        kwargs: Forwarded to the step/plot call of each channel. Overwrites other settings.

    Returns:
        The matplotlib figure in which the pulse is rendered, or None if the sampled pulse has more than
        maximum_points points.

    Raises:
        PlottingNotPossibleException: Raised by render() if fewer than two samples would be rendered.
        All exceptions possibly raised while creating the program or rendering it.
    """
    from matplotlib import pyplot as plt

    channels = pulse.defined_channels

    if parameters is None:
        parameters = dict()

    # identity mappings: keep all channel and measurement names as defined by the pulse
    program = pulse.create_program(parameters=parameters,
                                   channel_mapping={ch: ch for ch in channels},
                                   measurement_mapping={w: w for w in pulse.measurement_names})

    if program is not None:
        times, voltages, measurements = render(program,
                                               sample_rate,
                                               render_measurements=bool(plot_measurements),
                                               time_slice=time_slice)
    else:
        times, voltages, measurements = np.array([]), dict(), []

    duration = 0
    if times.size == 0:
        warnings.warn("Pulse to be plotted is empty!")
    elif times.size > maximum_points:
        # todo [2018-05-30]: since it results in an empty return value this should arguably be an exception, not just a warning
        warnings.warn("Sampled pulse of size {wf_len} is lager than {max_points}".format(wf_len=times.size,
                                                                                         max_points=maximum_points))
        return None
    else:
        duration = times[-1]

    if time_slice is None:
        time_slice = (0, duration)

    legend_handles = []
    if axes is None:
        # plot to figure
        figure = plt.figure()
        axes = figure.add_subplot(111)

    if plot_channels is not None:
        voltages = {ch: voltage
                    for ch, voltage in voltages.items()
                    if ch in plot_channels}

    for ch_name, voltage in voltages.items():
        label = 'channel {}'.format(ch_name)
        # user supplied kwargs take precedence over the defaults set here
        if stepped:
            line, = axes.step(times, voltage, **{**dict(where='post', label=label), **kwargs})
        else:
            line, = axes.plot(times, voltage, **{**dict(label=label), **kwargs})
        legend_handles.append(line)

    if plot_measurements:
        # group the requested measurement windows by name as (begin, end) pairs
        measurement_dict = dict()
        for name, begin, length in measurements:
            if name in plot_measurements:
                measurement_dict.setdefault(name, []).append((begin, begin+length))

        # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in matplotlib 3.9
        color_map = plt.get_cmap('plasma')
        meas_colors = {name: color_map(i/len(measurement_dict))
                       for i, name in enumerate(measurement_dict.keys())}
        for name, begin_end_list in measurement_dict.items():
            for begin, end in begin_end_list:
                poly = axes.axvspan(begin, end, alpha=0.2, label=name, edgecolor='black', facecolor=meas_colors[name])
            # one legend entry per measurement name; every list holds at least one window so poly is bound
            legend_handles.append(poly)

    axes.legend(handles=legend_handles)

    max_voltage = max((max(channel, default=0) for channel in voltages.values()), default=0)
    min_voltage = min((min(channel, default=0) for channel in voltages.values()), default=0)

    # add some margins in the presentation
    axes.set_xlim(-0.5+time_slice[0], time_slice[1] + 0.5)
    voltage_difference = max_voltage-min_voltage
    if voltage_difference>0:
        axes.set_ylim(min_voltage - 0.1*voltage_difference, max_voltage + 0.1*voltage_difference)
    axes.set_xlabel('Time (ns)')
    axes.set_ylabel('Voltage (a.u.)')

    if pulse.identifier:
        axes.set_title(pulse.identifier)

    if show:
        axes.get_figure().show()
    return axes.get_figure()
class PlottingNotPossibleException(Exception):
    """Indicates that plotting is not possible, e.g. because the sequencing process did not translate
    the entire given PulseTemplate structure or because there are too few data points to render.

    Attributes are documented in __init__.
    """

    def __init__(self, pulse, description = None) -> None:
        """
        Args:
            pulse: The pulse that could not be plotted. May be None if it is not known.
            description: Optional human readable reason. If None a generic message is used.
        """
        # Forward the arguments to Exception so that ``args`` is populated. Exceptions are
        # re-created from ``args`` by pickle and copy, which would otherwise fail because
        # ``pulse`` is a required argument.
        super().__init__(pulse, description)
        self.pulse = pulse
        self.description = description

    def __str__(self) -> str:
        if self.description is None:
            return "Plotting is not possible. There are parameters which cannot be computed."
        else:
            return "Plotting is not possible: %s." % self.description