Source code for qupulse._program.seqc

"""This module contains the ZI HDAWG compatible description of programs. There is no code in here that interacts with
hardware directly.

The public interface to all functionality is given by `HDAWGProgramManager`. This class can create seqc source code
which contains multiple programs and allows switching between these with the user registers of a device,

Furthermore:
- `SEQCNode`: AST of a subset of sequencing C
- `loop_to_seqc`: conversion of `Loop` objects to this subset in a clever way
- `BinaryWaveform`: Bundles functionality of handling segments in a native way.
- `WaveformMemory`: Functionality to sync waveforms to the device (via the LabOne user folder)
- `ProgramWaveformManager` and `HDAWGProgramEntry`: Program wise handling of waveforms and seqc-code
classes that convert `Loop` objects"""

from typing import Optional, Union, Sequence, Dict, Iterator, Tuple, Callable, NamedTuple, MutableMapping, Mapping,\
    Iterable, Any
from types import MappingProxyType
import abc
import itertools
import inspect
import logging
import os.path
import hashlib
from collections import OrderedDict
import string
import warnings
import numbers

import numpy as np
from pathlib import Path

from qupulse.utils.types import ChannelID, TimeType
from qupulse.utils import replace_multiple
from qupulse._program.waveforms import Waveform
from qupulse._program._loop import Loop
from qupulse._program.volatile import VolatileRepetitionCount, VolatileProperty
from qupulse.hardware.awgs.base import ProgramEntry

try:
    import zhinst.utils
except ImportError:
    zhinst = None


__all__ = ["HDAWGProgramManager"]


class BinaryWaveform:
    """This class represents a sampled waveform in the native HDAWG format as returned
    by zhinst.utils.convert_awg_waveform.

    BinaryWaveform.data can be uploaded directly to {device]/awgs/{awg}/waveform/waves/{wf}

    `to_csv_compatible_table` can be used to create a compatible compact csv file (with marker data included)
    """
    __slots__ = ('data',)

    def __init__(self, data: np.ndarray):
        """ always use both channels?

        Args:
            data: data as returned from zhinst.utils.convert_awg_waveform
        """
        n_quantum, remainder = divmod(data.size, 3 * 16)
        assert n_quantum > 1, "Waveform too short (min len is 32)"
        assert remainder == 0, "Waveform has not a valid length"
        assert data.dtype is np.dtype('uint16')
        assert np.all(data[2::3] < 16), "invalid marker data"
        assert data.ndim == 1, "Data not one dimensional"

        self.data = data
        self.data.flags.writeable = False

    @property
    def ch1(self):
        return self.data[::3]

    @property
    def ch2(self):
        return self.data[1::3]

    @property
    def marker_data(self):
        return self.data[2::3]

    @property
    def markers_ch1(self):
        return np.bitwise_and(self.marker_data, 0b0011)

    @property
    def markers_ch2(self):
        return np.bitwise_and(self.marker_data, 0b1100)

    @classmethod
    def from_sampled(cls, ch1: Optional[np.ndarray], ch2: Optional[np.ndarray],
                     markers: Tuple[Optional[np.ndarray], Optional[np.ndarray],
                                    Optional[np.ndarray], Optional[np.ndarray]]) -> 'BinaryWaveform':
        """Combines the sampled and scaled waveform data into a single binary compatible waveform

        Args:
            ch1: sampled waveform scaled to full range (-1., 1.)
            ch2: sampled waveform scaled to full range (-1., 1.)
            markers: (ch1_front_marker, ch1_dio_marker, ch2_front_marker, ch2_dio_marker)

        Returns:

        """
        all_input = (ch1, ch2, *markers)
        assert any(x is not None for x in all_input)
        size = {x.size for x in all_input if x is not None}
        assert len(size) == 1, "Inputs have incompatible dimension"
        size, = size
        if ch1 is None:
            ch1 = np.zeros(size)
        if ch2 is None:
            ch2 = np.zeros(size)
        marker_data = np.zeros(size, dtype=np.uint16)
        for idx, marker in enumerate(markers):
            if marker is not None:
                marker_data += np.uint16((marker > 0) * 2**idx)
        return cls(zhinst.utils.convert_awg_waveform(ch1, ch2, marker_data))

    def __len__(self):
        return self.data.size // 3

    def __eq__(self, other):
        return np.array_equal(self.data, other.data)

    def __hash__(self):
        return hash(bytes(self.data))

    def fingerprint(self) -> str:
        """This fingerprint is runtime independent"""
        return hashlib.sha256(self.data).hexdigest()

    def to_csv_compatible_table(self):
        """The integer values in that file should be 18-bit unsigned integers with the two least significant bits
        being the markers. The values are mapped to 0 => -FS, 262143 => +FS, with FS equal to the full scale.

        >>> np.savetxt(waveform_dir, binary_waveform.to_csv_compatible_table(), fmt='%u')
        """
        table = np.zeros((len(self), 2), dtype=np.uint32)
        table[:, 0] = self.ch1
        table[:, 1] = self.ch2
        np.left_shift(table, 2, out=table)
        table[:, 0] += self.markers_ch1
        table[:, 1] += self.markers_ch2

        return table


class ConcatenatedWaveform:
    def __init__(self):
        """Handle the concatenation of multiple binary waveforms to create a big indexable waveform."""
        self._concatenated = []
        self._as_binary = None

    def __bool__(self):
        return bool(self._concatenated)

    def is_finalized(self):
        return self._as_binary is not None or self._concatenated is None

    def as_binary(self) -> Optional[BinaryWaveform]:
        assert self.is_finalized()
        return self._as_binary

    def append(self, binary_waveform):
        assert not self.is_finalized()
        self._concatenated.append(binary_waveform)

    def finalize(self):
        assert not self.is_finalized()
        if self._concatenated:
            concatenated_data = np.concatenate([wf.data for wf in self._concatenated])
            self._as_binary = BinaryWaveform(concatenated_data)
        else:
            self._concatenated = None

    def clear(self):
        if self._concatenated is None:
            self._concatenated = []
        else:
            self._concatenated.clear()
        self._as_binary = None


class WaveformFileSystem:
    logger = logging.getLogger('qupulse.hdawg.waveforms')

    def __init__(self, path: Path):
        """This class coordinates multiple AWGs (channel pairs) using the same file system to store the waveforms.

        Args:
            path: Waveforms are stored here
        """
        self._required = {}
        self._path = path

    def sync(self, client: 'WaveformMemory', waveforms: Mapping[str, BinaryWaveform], **kwargs):
        """Write the required waveforms to the filesystem."""
        self._required[id(client)] = waveforms
        self._sync(**kwargs)

    def _sync(self, delete=True, write_all=False):
        to_save = {self._path.joinpath(file_name): binary
                   for d in self._required.values()
                   for file_name, binary in d.items()}

        for existing_file in self._path.iterdir():
            if not existing_file.is_file():
                pass
            elif existing_file in to_save:
                if not write_all:
                    self.logger.debug('Skipping %r', existing_file.name)
                    to_save.pop(existing_file)
            elif delete:
                try:
                    self.logger.debug('Deleting %r', existing_file.name)
                    existing_file.unlink()
                except OSError:
                    self.logger.exception("Error deleting: %r", existing_file.name)

        for file_name, binary_waveform in to_save.items():
            table = binary_waveform.to_csv_compatible_table()
            np.savetxt(file_name, table, '%u')
            self.logger.debug('Wrote %r', file_name)


class WaveformMemory:
    """Global waveform "memory" representation (currently the file system)"""
    CONCATENATED_WAVEFORM_TEMPLATE = '{program_name}_concatenated_waveform'
    SHARED_WAVEFORM_TEMPLATE = '{program_name}_shared_waveform_{hash}'
    WF_PLACEHOLDER_TEMPLATE = '*{id}*'
    FILE_NAME_TEMPLATE = '{hash}.csv'

    _WaveInfo = NamedTuple('_WaveInfo', [('wave_name', str),
                                         ('wave_placeholder', str),
                                         ('file_name', str),
                                         ('binary_waveform', BinaryWaveform)])

    def __init__(self):
        self.shared_waveforms = OrderedDict()  # type: MutableMapping[BinaryWaveform, set]
        self.concatenated_waveforms = OrderedDict()  # type: MutableMapping[str, ConcatenatedWaveform]

    def clear(self):
        self.shared_waveforms.clear()
        self.concatenated_waveforms.clear()

    def _shared_waveforms_iter(self) -> Iterator[_WaveInfo]:
        for wf, program_set in self.shared_waveforms.items():
            if program_set:
                wave_hash = wf.fingerprint()
                wave_name = self.SHARED_WAVEFORM_TEMPLATE.format(program_name='_'.join(program_set),
                                                                 hash=wave_hash)
                wave_placeholder = self.WF_PLACEHOLDER_TEMPLATE.format(id=id(program_set))
                file_name = self.FILE_NAME_TEMPLATE.format(hash=wave_hash)
                yield self._WaveInfo(wave_name, wave_placeholder, file_name, wf)

    def _concatenated_waveforms_iter(self) -> Iterator[_WaveInfo]:
        for program_name, concatenated_waveform in self.concatenated_waveforms.items():
            if concatenated_waveform:
                wave_hash = concatenated_waveform.as_binary().fingerprint()
                wave_placeholder = self.WF_PLACEHOLDER_TEMPLATE.format(id=id(concatenated_waveform))
                wave_name = self.CONCATENATED_WAVEFORM_TEMPLATE.format(program_name=program_name)
                file_name = self.FILE_NAME_TEMPLATE.format(hash=wave_hash)
                yield self._WaveInfo(wave_name, wave_placeholder, file_name, concatenated_waveform.as_binary())

    def waveform_name_replacements(self) -> Dict[str, str]:
        """replace place holders of complete seqc program with

        >>> waveform_name_translation = waveform_memory.waveform_name_replacements()
        >>> seqc_program = qupulse.utils.replace_multiple(seqc_program, waveform_name_translation)
        """
        translation = {}
        for wave_info in self._shared_waveforms_iter():
            translation[wave_info.wave_placeholder] = wave_info.wave_name

        for wave_info in self._concatenated_waveforms_iter():
            translation[wave_info.wave_placeholder] = wave_info.wave_name
        return translation

    def waveform_declaration(self) -> str:
        """Produces a string that declares all needed waveforms.
        It is needed to know the waveform index in case we want to update a waveform during playback."""
        declarations = []
        for wave_info in self._concatenated_waveforms_iter():
            declarations.append(
                'wave {wave_name} = "{file_name}";'.format(wave_name=wave_info.wave_name,
                                                           file_name=wave_info.file_name.replace('.csv', ''))
            )

        for wave_info in self._shared_waveforms_iter():
            declarations.append(
                'wave {wave_name} = "{file_name}";'.format(wave_name=wave_info.wave_name,
                                                           file_name=wave_info.file_name.replace('.csv', ''))
            )
        return '\n'.join(declarations)

    def sync_to_file_system(self, file_system: WaveformFileSystem):

        to_save = {wave_info.file_name: wave_info.binary_waveform
                   for wave_info in itertools.chain(self._concatenated_waveforms_iter(),
                                                    self._shared_waveforms_iter())}
        file_system.sync(self, to_save)


class ProgramWaveformManager:
    """Manages waveforms of a program"""
    def __init__(self, name, memory: WaveformMemory):
        self._program_name = name
        self._memory = memory

        assert self._program_name not in self._memory.concatenated_waveforms
        assert all(self._program_name not in programs for programs in self._memory.shared_waveforms.values())
        self._memory.concatenated_waveforms[self._program_name] = ConcatenatedWaveform()

    @property
    def program_name(self) -> str:
        return self._program_name

    def clear_requested(self):
        for programs in self._memory.shared_waveforms.values():
            programs.discard(self._program_name)
        self._memory.concatenated_waveforms[self._program_name].clear()

    def request_shared(self, binary_waveform: BinaryWaveform) -> str:
        """Register waveform if not already registered and return a unique identifier placeholder.

        The unique identifier currently is computed from the id of the set which stores all programs using this
        waveform.
        """
        program_set = self._memory.shared_waveforms.setdefault(binary_waveform, set())
        program_set.add(self._program_name)
        return self._memory.WF_PLACEHOLDER_TEMPLATE.format(id=id(program_set))

    def request_concatenated(self, binary_waveform: BinaryWaveform) -> str:
        """Append the waveform to the concatenated waveform"""
        bin_wf_list = self._memory.concatenated_waveforms[self._program_name]
        bin_wf_list.append(binary_waveform)
        return self._memory.WF_PLACEHOLDER_TEMPLATE.format(id=id(bin_wf_list))

    def finalize(self):
        self._memory.concatenated_waveforms[self._program_name].finalize()

    def prepare_delete(self):
        """Delete all references in waveform memory to this program. Cannot be used afterwards."""
        self.clear_requested()
        del self._memory.concatenated_waveforms[self._program_name]


class UserRegister:
    """This class is a helper class to avoid errors due to 0 and 1 based register indexing"""
    __slots__ = ('_zero_based_value',)

    def __init__(self, *, zero_based_value: int = None, one_based_value: int = None):
        assert None in (zero_based_value, one_based_value)
        assert isinstance(zero_based_value, int) or isinstance(one_based_value, int)

        if one_based_value is not None:
            assert one_based_value > 0, "A one based value needs to be larger zero"
            self._zero_based_value = one_based_value - 1
        else:
            self._zero_based_value = zero_based_value

    @classmethod
    def from_seqc(cls, value: int) -> 'UserRegister':
        return cls(zero_based_value=value)

    def to_seqc(self) -> int:
        return self._zero_based_value

    @classmethod
    def from_labone(cls, value: int) -> 'UserRegister':
        return cls(zero_based_value=value)

    def to_labone(self) -> int:
        return self._zero_based_value

    @classmethod
    def from_web_interface(cls, value: int) -> 'UserRegister':
        return cls(one_based_value=value)

    def to_web_interface(self) -> int:
        return self._zero_based_value + 1

    def __hash__(self):
        return hash(self._zero_based_value)

    def __eq__(self, other):
        return self._zero_based_value == getattr(other, '_zero_based_value', None)

    def __repr__(self):
        return 'UserRegister(zero_based_value={zero_based_value})'.format(zero_based_value=self._zero_based_value)

    def __format__(self, format_spec: str) -> str:
        if format_spec in ('zero_based', 'seqc', 'labone', 'lab_one'):
            return str(self.to_seqc())
        elif format_spec in ('one_based', 'web', 'web_interface'):
            return str(self.to_web_interface())
        elif format_spec in ('repr', 'r'):
            return repr(self)
        else:
            raise ValueError('Invalid format spec for UserRegister: ', format_spec)


class UserRegisterManager:
    """This class keeps track of the user registered that are used in a certain context"""
    def __init__(self, available: Iterable[UserRegister], name_template: str):
        assert 'register' in (x[1] for x in string.Formatter().parse(name_template))

        self._available = set(available)
        self._name_template = name_template
        self._used = {}

    def request(self, obj) -> str:
        """Request a user register name to store object. If an object that evaluates equal to obj was requested before
        the name name is returned.

        Args:
            obj: Object to store

        Returns:
            Name of the variable with the user register

        Raises:
            Value error if no register is available
        """
        for register, registered_obj in self._used.items():
            if obj == registered_obj:
                return self._name_template.format(register=register)
        if self._available:
            register = self._available.pop()
            self._used[register] = obj
            return self._name_template.format(register=register)
        else:
            raise ValueError("No register available for %r" % obj)

    def iter_used_register_names(self) -> Iterator[Tuple[UserRegister, str]]:
        """

        Returns:
            An iterator over (register index, register name) pairs
        """
        return ((register, self._name_template.format(register=register)) for register in self._used.keys())

    def iter_used_register_values(self) -> Iterable[Tuple[UserRegister, Any]]:
        return self._used.items()


class HDAWGProgramEntry(ProgramEntry):
    USER_REG_NAME_TEMPLATE = 'user_reg_{register:seqc}'

    def __init__(self, loop: Loop, selection_index: int, waveform_memory: WaveformMemory, program_name: str,
                 channels: Tuple[Optional[ChannelID], Optional[ChannelID]],
                 markers: Tuple[Optional[ChannelID], Optional[ChannelID], Optional[ChannelID], Optional[ChannelID]],
                 amplitudes: Tuple[float, float],
                 offsets: Tuple[float, float],
                 voltage_transformations: Tuple[Optional[Callable], Optional[Callable]],
                 sample_rate: TimeType):
        super().__init__(loop, channels=channels, markers=markers,
                         amplitudes=amplitudes,
                         offsets=offsets,
                         voltage_transformations=voltage_transformations,
                         sample_rate=sample_rate)
        for waveform, (sampled_channels, sampled_markers) in self._waveforms.items():
            self._waveforms[waveform] = BinaryWaveform.from_sampled(*sampled_channels, sampled_markers)

        self._waveform_manager = ProgramWaveformManager(program_name, waveform_memory)
        self.selection_index = selection_index
        self._trigger_wait_code = None
        self._seqc_node = None
        self._seqc_source = None
        self._var_declarations = None
        self._user_registers = None
        self._user_register_source = None

    def compile(self,
                min_repetitions_for_for_loop: int,
                min_repetitions_for_shared_wf: int,
                indentation: str,
                trigger_wait_code: str,
                available_registers: Iterable[UserRegister]):
        """Compile the loop representation to an internal sequencing c one using `loop_to_seqc`

        Args:
            min_repetitions_for_for_loop: See `loop_to_seqc`
            min_repetitions_for_shared_wf: See `loop_to_seqc`
            indentation: Each line is prefixed with this
            trigger_wait_code: The code is put before the playback start
            available_registers
        Returns:

        """
        pos_var_name = 'pos'

        if self._seqc_node:
            self._waveform_manager.clear_requested()

        user_registers = UserRegisterManager(available_registers, self.USER_REG_NAME_TEMPLATE)

        self._seqc_node = loop_to_seqc(self._loop,
                                       min_repetitions_for_for_loop=min_repetitions_for_for_loop,
                                       min_repetitions_for_shared_wf=min_repetitions_for_shared_wf,
                                       waveform_to_bin=self.get_binary_waveform,
                                       user_registers=user_registers)

        self._user_register_source = '\n'.join(
            '{indentation}var {user_reg_name} = getUserReg({register});'.format(indentation=indentation,
                                                                                user_reg_name=user_reg_name,
                                                                                register=register.to_seqc())
            for register, user_reg_name in user_registers.iter_used_register_names()
        )
        self._user_registers = user_registers

        self._var_declarations = '{indentation}var {pos_var_name} = 0;'.format(pos_var_name=pos_var_name,
                                                                               indentation=indentation)
        self._trigger_wait_code = indentation + trigger_wait_code
        self._seqc_source = '\n'.join(self._seqc_node.to_source_code(self._waveform_manager,
                                                                     map(str, itertools.count(1)),
                                                                     line_prefix=indentation,
                                                                     pos_var_name=pos_var_name))
        self._waveform_manager.finalize()

    @property
    def seqc_node(self) -> 'SEQCNode':
        assert self._seqc_node is not None, "compile not called"
        return self._seqc_node

    @property
    def seqc_source(self) -> str:
        assert self._seqc_source is not None, "compile not called"
        return '\n'.join([self._var_declarations,
                          self._user_register_source,
                          self._trigger_wait_code,
                          self._seqc_source])

    def volatile_repetition_counts(self) -> Iterable[Tuple[UserRegister, VolatileRepetitionCount]]:
        """
        Returns:
            An iterator over the register and parameter
        """
        assert self._user_registers is not None, "compile not called"
        return self._user_registers.iter_used_register_values()

    @property
    def name(self) -> str:
        return self._waveform_manager.program_name

    def parse_to_seqc(self, waveform_memory):
        raise NotImplementedError()

    def get_binary_waveform(self, waveform: Waveform) -> BinaryWaveform:
        return self._waveforms[waveform]

    def prepare_delete(self):
        """Delete all references to this program. Cannot be used afterwards"""
        self._waveform_manager.prepare_delete()
        self._seqc_node = None
        self._seqc_source = None


[docs]class HDAWGProgramManager: """This class contains everything that is needed to create the final seqc program and provides an interface to write the required waveforms to the file system. It does not talk to the device.""" GLOBAL_CONSTS = dict(PROG_SEL_REGISTER=UserRegister(zero_based_value=0), TRIGGER_REGISTER=UserRegister(zero_based_value=1), TRIGGER_RESET_MASK=bin(1 << 15), PROG_SEL_NONE=0, NO_RESET_MASK=bin(1 << 15), PROG_SEL_MASK=bin((1 << 15) - 1), IDLE_WAIT_CYCLES=300) PROGRAM_FUNCTION_NAME_TEMPLATE = '{program_name}_function' INIT_PROGRAM_SWITCH = '// INIT program switch.\nvar prog_sel = 0;' WAIT_FOR_SOFTWARE_TRIGGER = "waitForSoftwareTrigger();" SOFTWARE_WAIT_FOR_TRIGGER_FUNCTION_DEFINITION = ( 'void waitForSoftwareTrigger() {\n' ' while (true) {\n' ' var trigger_register = getUserReg(TRIGGER_REGISTER);\n' ' if (trigger_register & TRIGGER_RESET_MASK) setUserReg(TRIGGER_REGISTER, 0);\n' ' if (trigger_register) return;\n' ' }\n' '}\n' ) def __init__(self): self._waveform_memory = WaveformMemory() self._programs = OrderedDict() # type: MutableMapping[str, HDAWGProgramEntry] @property def waveform_memory(self): return self._waveform_memory def _get_low_unused_index(self): existing = {entry.selection_index for entry in self._programs.values()} for idx in itertools.count(): if idx not in existing and idx != self.GLOBAL_CONSTS['PROG_SEL_NONE']: return idx
[docs] def add_program(self, name: str, loop: Loop, channels: Tuple[Optional[ChannelID], Optional[ChannelID]], markers: Tuple[Optional[ChannelID], Optional[ChannelID], Optional[ChannelID], Optional[ChannelID]], amplitudes: Tuple[float, float], offsets: Tuple[float, float], voltage_transformations: Tuple[Optional[Callable], Optional[Callable]], sample_rate: TimeType): """Register the given program and translate it to seqc. TODO: Add an interface to change the trigger mode Args: name: Human readable name of the program (used f.i. for the function name) loop: The program to upload channels: see AWG.upload markers: see AWG.upload amplitudes: Used to sample the waveforms offsets: Used to sample the waveforms voltage_transformations: see AWG.upload sample_rate: Used to sample the waveforms """ assert name not in self._programs selection_index = self._get_low_unused_index() # TODO: verify total number of registers available_registers = [UserRegister.from_seqc(idx) for idx in range(2, 16)] program_entry = HDAWGProgramEntry(loop, selection_index, self._waveform_memory, name, channels, markers, amplitudes, offsets, voltage_transformations, sample_rate) # TODO: de-hardcode these parameters and put compilation in seperate function program_entry.compile(20, 1000, ' ', self.WAIT_FOR_SOFTWARE_TRIGGER, available_registers=available_registers) self._programs[name] = program_entry
[docs] def get_register_values(self, name: str) -> Mapping[UserRegister, int]: return {register: int(parameter) for register, parameter in self._programs[name].volatile_repetition_counts()}
[docs] def get_register_values_to_update_volatile_parameters(self, name: str, parameters: Mapping[str, numbers.Number]) -> Mapping[UserRegister, int]: """ Args: name: Program name parameters: new values for volatile parameters Returns: A dict user_register->value that reflects the new parameter values """ program_entry = self._programs[name] result = {} for register, volatile_repetition in program_entry.volatile_repetition_counts(): new_value = volatile_repetition.update_volatile_dependencies(parameters) result[register] = new_value return result
@property def programs(self) -> Mapping[str, HDAWGProgramEntry]: return MappingProxyType(self._programs)
[docs] def remove(self, name: str) -> None: self._programs.pop(name).prepare_delete()
[docs] def clear(self) -> None: self._waveform_memory.clear() self._programs.clear()
[docs] def name_to_index(self, name: str) -> int: assert self._programs[name].name == name return self._programs[name].selection_index
[docs] def to_seqc_program(self) -> str: lines = [] for const_name, const_val in self.GLOBAL_CONSTS.items(): if isinstance(const_val, (int, str)): const_repr = str(const_val) else: const_repr = const_val.to_seqc() lines.append('const {const_name} = {const_repr};'.format(const_name=const_name, const_repr=const_repr)) lines.append(self._waveform_memory.waveform_declaration()) lines.append('\n//function used by manually triggered programs') lines.append(self.SOFTWARE_WAIT_FOR_TRIGGER_FUNCTION_DEFINITION) replacements = self._waveform_memory.waveform_name_replacements() lines.append('\n// program definitions') for program_name, program in self.programs.items(): program_function_name = self.PROGRAM_FUNCTION_NAME_TEMPLATE.format(program_name=program_name) lines.append('void {program_function_name}() {{'.format(program_function_name=program_function_name)) lines.append(replace_multiple(program.seqc_source, replacements)) lines.append('}\n') lines.append(self.INIT_PROGRAM_SWITCH) lines.append('\n//runtime block') lines.append('while (true) {') lines.append(' // read program selection value') lines.append(' prog_sel = getUserReg(PROG_SEL_REGISTER);') lines.append(' if (!(prog_sel & NO_RESET_MASK)) setUserReg(PROG_SEL_REGISTER, 0);') lines.append(' prog_sel = prog_sel & PROG_SEL_MASK;') lines.append(' ') lines.append(' switch (prog_sel) {') for program_name, program_entry in self.programs.items(): program_function_name = self.PROGRAM_FUNCTION_NAME_TEMPLATE.format(program_name=program_name) lines.append(' case {selection_index}:'.format(selection_index=program_entry.selection_index)) lines.append(' {program_function_name}();'.format(program_function_name=program_function_name)) lines.append(' waitWave();') lines.append(' default:') lines.append(' wait(IDLE_WAIT_CYCLES);') lines.append(' }') lines.append('}') return '\n'.join(lines)
def find_sharable_waveforms(node_cluster: Sequence['SEQCNode']) -> Optional[Sequence[bool]]: """Expects nodes to have a compatible stepping TODO: encode in type system? """ waveform_playbacks = list(node_cluster[0].iter_waveform_playbacks()) candidates = [True] * len(waveform_playbacks) for node in itertools.islice(node_cluster, 1, None): candidates_left = False for idx, (wf, node_wf) in enumerate(zip(waveform_playbacks, node.iter_waveform_playbacks())): if candidates[idx]: candidates[idx] = wf == node_wf candidates_left = candidates_left or candidates[idx] if not candidates_left: return None return candidates def mark_sharable_waveforms(node_cluster: Sequence['SEQCNode'], sharable_waveforms: Sequence[bool]): for node in node_cluster: for sharable, wf_playback in zip(sharable_waveforms, node.iter_waveform_playbacks()): if sharable: wf_playback.shared = True def to_node_clusters(loop: Union[Sequence[Loop], Loop], loop_to_seqc_kwargs: dict) -> Sequence[Sequence['SEQCNode']]: """transform to seqc recursively noes and cluster them if they have compatible stepping""" assert len(loop) > 1 node_clusters = [] last_node = loop_to_seqc(loop[0], **loop_to_seqc_kwargs) current_nodes = [last_node] # compress all nodes in clusters of the same stepping for child in itertools.islice(loop, 1, None): current_node = loop_to_seqc(child, **loop_to_seqc_kwargs) if last_node.same_stepping(current_node): current_nodes.append(current_node) else: node_clusters.append(current_nodes) current_nodes = [current_node] last_node = current_node node_clusters.append(current_nodes) return node_clusters def loop_to_seqc(loop: Loop, min_repetitions_for_for_loop: int, min_repetitions_for_shared_wf: int, waveform_to_bin: Callable[[Waveform], BinaryWaveform], user_registers: UserRegisterManager) -> 'SEQCNode': assert min_repetitions_for_for_loop <= min_repetitions_for_shared_wf # At which point do we switch from indexed to shared if loop.is_leaf(): node = WaveformPlayback(waveform_to_bin(loop.waveform)) elif len(loop) == 1: node = loop_to_seqc(loop[0], min_repetitions_for_for_loop=min_repetitions_for_for_loop, min_repetitions_for_shared_wf=min_repetitions_for_shared_wf, waveform_to_bin=waveform_to_bin, user_registers=user_registers) else: node_clusters = to_node_clusters(loop, dict(min_repetitions_for_for_loop=min_repetitions_for_for_loop, min_repetitions_for_shared_wf=min_repetitions_for_shared_wf, waveform_to_bin=waveform_to_bin, user_registers=user_registers)) seqc_nodes = [] # identify shared waveforms in node clusters for node_cluster in node_clusters: if len(node_cluster) < min_repetitions_for_for_loop: seqc_nodes.extend(node_cluster) else: if len(node_cluster) >= min_repetitions_for_shared_wf: sharable_waveforms = find_sharable_waveforms(node_cluster) if sharable_waveforms: mark_sharable_waveforms(node_cluster, sharable_waveforms) seqc_nodes.append(SteppingRepeat(node_cluster)) node = Scope(seqc_nodes) if loop.volatile_repetition: register_var = user_registers.request(loop.repetition_definition) return Repeat(scope=node, repetition_count=register_var) elif loop.repetition_count != 1: return Repeat(scope=node, repetition_count=loop.repetition_count) else: return node class SEQCNode(metaclass=abc.ABCMeta): __slots__ = () INDENTATION = ' ' @abc.abstractmethod def samples(self) -> int: pass @abc.abstractmethod def same_stepping(self, other: 'SEQCNode'): pass @abc.abstractmethod def iter_waveform_playbacks(self) -> Iterator['WaveformPlayback']: pass def _get_single_indexed_playback(self) -> Optional['WaveformPlayback']: """Returns None if there is no or if there are more than one indexed playbacks""" # detect if there is only a single indexed playback single_indexed_playback = None for playback in self.iter_waveform_playbacks(): if not playback.shared: if single_indexed_playback is None: single_indexed_playback = playback else: break else: return single_indexed_playback return None @abc.abstractmethod def _visit_nodes(self, waveform_manager: ProgramWaveformManager): """push all concatenated waveforms in the waveform manager""" @abc.abstractmethod def to_source_code(self, waveform_manager: ProgramWaveformManager, node_name_generator: Iterator[str], line_prefix: str, pos_var_name: str, advance_pos_var: bool = True): """besides creating the source code, this function registers all needed waveforms to the program manager 1. shared waveforms 2. concatenated waveforms in the correct order Args: waveform_manager: node_name_generator: generates unique names of nodes line_prefix: pos_var_name: advance_pos_var: Indexed playback will not advance the position if set to False. This is used internally to optimize repeat statements with a single indexed playback. Returns: """ def __eq__(self, other): """Compare objects based on __slots__""" assert getattr(self, '__dict__', None) is None return type(self) == type(other) and all(getattr(self, attr) == getattr(other, attr) for base_class in inspect.getmro(type(self)) for attr in getattr(base_class, '__slots__', ())) class Scope(SEQCNode): """Sequence of nodes""" __slots__ = ('nodes',) def __init__(self, nodes: Sequence[SEQCNode] = ()): self.nodes = list(nodes) def samples(self): return sum(node.samples() for node in self.nodes) def iter_waveform_playbacks(self) -> Iterator['WaveformPlayback']: for node in self.nodes: yield from node.iter_waveform_playbacks() def same_stepping(self, other: 'Scope'): return (type(other) is Scope and len(self.nodes) == len(other.nodes) and all(n1.same_stepping(n2) for n1, n2 in zip(self.nodes, other.nodes))) def _visit_nodes(self, waveform_manager: ProgramWaveformManager): for node in self.nodes: node._visit_nodes(waveform_manager) def to_source_code(self, waveform_manager: ProgramWaveformManager, node_name_generator: Iterator[str], line_prefix: str, pos_var_name: str, advance_pos_var: bool = True): for node in self.nodes: yield from node.to_source_code(waveform_manager, line_prefix=line_prefix, pos_var_name=pos_var_name, node_name_generator=node_name_generator, advance_pos_var=advance_pos_var) class Repeat(SEQCNode): """""" __slots__ = ('repetition_count', 'scope') INITIAL_POSITION_NAME_TEMPLATE = 'init_pos_{node_name}' FOR_LOOP_NAME_TEMPLATE = 'idx_{node_name}' class _AdvanceStrategy: """describes what happens how this node interacts with the position variable""" INITIAL_RESET = 'initial_reset' POST_ADVANCE = 'post_advance' IGNORE = 'ignore' def __init__(self, repetition_count: Union[int, str], scope: SEQCNode): """ Args: repetition_count: A const integer value or a string that is expected to be a "var" scope: The repeated scope """ if isinstance(repetition_count, int): assert repetition_count > 1 else: assert isinstance(repetition_count, str) and repetition_count.isidentifier() self.repetition_count = repetition_count self.scope = scope def samples(self): return self.scope.samples() def same_stepping(self, other: 'Repeat'): return (type(self) == type(other) and self.repetition_count == other.repetition_count and self.scope.same_stepping(other.scope)) def iter_waveform_playbacks(self) -> Iterator['WaveformPlayback']: return self.scope.iter_waveform_playbacks() def _visit_nodes(self, waveform_manager: ProgramWaveformManager): self.scope._visit_nodes(waveform_manager) def _get_position_advance_strategy(self): """Deduct the optimal position advance strategy: There is more than one indexed playback -> position needs to be advanced during each iteration and set back to initial value at the begin of each new iteration There is exactly one indexed playback -> The position is not advanced in the body but needs to be advanced after all repetitions are done There is no indexed playback -> We do not care about the position at all """ self_samples = self.samples() if self_samples > 0: single_playback = self.scope._get_single_indexed_playback() if single_playback is None or single_playback.samples() != self_samples: # TODO: I am not sure whether the 'single_playback.samples() != self_samples' is necessary # there is more than one indexed playback return self._AdvanceStrategy.INITIAL_RESET else: # there is only a single indexed playback return self._AdvanceStrategy.POST_ADVANCE else: # there is no indexed playback return self._AdvanceStrategy.IGNORE def to_source_code(self, waveform_manager: ProgramWaveformManager, node_name_generator: Iterator[str], line_prefix: str, pos_var_name: str, advance_pos_var: bool = True): body_prefix = line_prefix + self.INDENTATION advance_strategy = self._get_position_advance_strategy() if advance_pos_var else self._AdvanceStrategy.IGNORE inner_advance_pos_var = advance_strategy == self._AdvanceStrategy.INITIAL_RESET def get_node_name(): """Helper to assert node name only generated when needed and only generated once""" if getattr(get_node_name, 'node_name', None) is None: get_node_name.node_name = next(node_name_generator) return get_node_name.node_name if advance_strategy == self._AdvanceStrategy.INITIAL_RESET: initial_position_name = self.INITIAL_POSITION_NAME_TEMPLATE.format(node_name=get_node_name()) # store initial position yield '{line_prefix}var {init_pos_name} = {pos_var_name};'.format(line_prefix=line_prefix, init_pos_name=initial_position_name, pos_var_name=pos_var_name) if isinstance(self.repetition_count, int): yield '{line_prefix}repeat({repetition_count}) {{'.format(line_prefix=line_prefix, repetition_count=self.repetition_count) else: # repeat requires a const-expression so we need to use a for loop for user reg vars assert isinstance(self.repetition_count, str) loop_var = self.FOR_LOOP_NAME_TEMPLATE.format(node_name=get_node_name()) yield '{line_prefix}var {loop_var};'.format(line_prefix=line_prefix, loop_var=loop_var) yield ('{line_prefix}for({loop_var} = 0; ' '{loop_var} < {repetition_count}; ' '{loop_var} = {loop_var} + 1) {{').format(line_prefix=line_prefix, loop_var=loop_var, repetition_count=self.repetition_count) if advance_strategy == self._AdvanceStrategy.INITIAL_RESET: yield ('{body_prefix}{pos_var_name} = {init_pos_name};' '').format(body_prefix=body_prefix, pos_var_name=pos_var_name, init_pos_name=initial_position_name) yield from self.scope.to_source_code(waveform_manager, line_prefix=body_prefix, pos_var_name=pos_var_name, node_name_generator=node_name_generator, advance_pos_var=inner_advance_pos_var) yield '{line_prefix}}}'.format(line_prefix=line_prefix) if advance_strategy == self._AdvanceStrategy.POST_ADVANCE: yield '{line_prefix}{pos_var_name} = {pos_var_name} + {samples};'.format(line_prefix=line_prefix, pos_var_name=pos_var_name, samples=self.samples()) class SteppingRepeat(SEQCNode): STEPPING_REPEAT_COMMENT = ' // stepping repeat' __slots__ = ('node_cluster',) def __init__(self, node_cluster: Sequence[SEQCNode]): self.node_cluster = node_cluster def samples(self) -> int: return self.repetition_count * self.node_cluster[0].samples() @property def repetition_count(self): return len(self.node_cluster) def iter_waveform_playbacks(self) -> Iterator['WaveformPlayback']: for node in self.node_cluster: yield from node.iter_waveform_playbacks() def same_stepping(self, other: 'SteppingRepeat'): return (type(other) is SteppingRepeat and len(self.node_cluster) == len(other.node_cluster) and self.node_cluster[0].same_stepping(other.node_cluster[0])) def _visit_nodes(self, waveform_manager: ProgramWaveformManager): for node in self.node_cluster: node._visit_nodes(waveform_manager) def to_source_code(self, waveform_manager: ProgramWaveformManager, node_name_generator: Iterator[str], line_prefix: str, pos_var_name: str, advance_pos_var: bool = True): body_prefix = line_prefix + self.INDENTATION repeat_open = '{line_prefix}repeat({repetition_count}) {{' + self.STEPPING_REPEAT_COMMENT yield repeat_open.format(line_prefix=line_prefix, repetition_count=self.repetition_count) yield from self.node_cluster[0].to_source_code(waveform_manager, line_prefix=body_prefix, pos_var_name=pos_var_name, node_name_generator=node_name_generator, advance_pos_var=advance_pos_var) # register remaining concatenated waveforms for node in itertools.islice(self.node_cluster, 1, None): node._visit_nodes(waveform_manager) yield '{line_prefix}}}'.format(line_prefix=line_prefix) class WaveformPlayback(SEQCNode): ADVANCE_DISABLED_COMMENT = ' // advance disabled do to parent repetition' __slots__ = ('waveform', 'shared') def __init__(self, waveform: BinaryWaveform, shared: bool = False): self.waveform = waveform self.shared = shared def samples(self): if self.shared: return 0 else: return len(self.waveform) def same_stepping(self, other: 'WaveformPlayback'): same_type = type(self) is type(other) and self.shared == other.shared if self.shared: return same_type and self.waveform == other.waveform else: return same_type and len(self.waveform) == len(other.waveform) def iter_waveform_playbacks(self) -> Iterator['WaveformPlayback']: yield self def _visit_nodes(self, waveform_manager: ProgramWaveformManager): if not self.shared: waveform_manager.request_concatenated(self.waveform) def to_source_code(self, waveform_manager: ProgramWaveformManager, node_name_generator: Iterator[str], line_prefix: str, pos_var_name: str, advance_pos_var: bool = True): if self.shared: yield '{line_prefix}playWave({waveform});'.format(waveform=waveform_manager.request_shared(self.waveform), line_prefix=line_prefix) else: wf_name = waveform_manager.request_concatenated(self.waveform) wf_len = len(self.waveform) play_cmd = '{line_prefix}playWaveIndexed({wf_name}, {pos_var_name}, {wf_len});' if advance_pos_var: advance_cmd = ' {pos_var_name} = {pos_var_name} + {wf_len};' else: advance_cmd = self.ADVANCE_DISABLED_COMMENT yield (play_cmd + advance_cmd).format(wf_name=wf_name, wf_len=wf_len, pos_var_name=pos_var_name, line_prefix=line_prefix)