Source code for qupulse.expressions.wrapper

"""This module contains the function :py:func:`make_wrappers`, which defines wrapper classes for expression protocol
implementations that expose only the methods of the protocol.
It is used to find code that relies on expression implementation details."""

import math
from typing import Sequence, Any, Mapping, Union, Tuple
from numbers import Real

import numpy as np

from qupulse.expressions import protocol, sympy


def make_wrappers(expr: type, expr_scalar: type, expr_vector: type) -> Tuple[type, type, type]:
    """Create wrappers for expression base, scalar and vector types that only expose the methods defined in the
    corresponding expression protocol classes.

    Only the methods defined on the wrapper classes below are forwarded to the wrapped implementation object.
    Binary operands and constructor arguments that are themselves wrappers are unwrapped before they are handed to
    the implementation.

    The vector wrapper is currently not implemented: it only exposes the base expression methods and ``expr_vector``
    is not used.

    Args:
        expr: Expression base type of the implementation
        expr_scalar: Expression scalar type of the implementation
        expr_vector: Expression vector type of the implementation (currently unused)

    Returns:
        A tuple of (base, scalar, vector) types that wrap the given types.
    """

    def _unwrap(obj):
        # Operands may be wrappers created here or plain implementation objects / numbers.
        return getattr(obj, '_wrapped', obj)

    class ExpressionWrapper(protocol.Expression):
        """Wraps an instance of ``expr`` and forwards the base expression protocol methods to it."""

        def __init__(self, x):
            # Unwrap first so that wrapping a wrapper does not pass the wrapper itself to the implementation type.
            self._wrapped: protocol.Expression = expr(_unwrap(x))

        @classmethod
        def make(cls, expression_or_dict, numpy_evaluation=None) -> 'ExpressionWrapper':
            # numpy_evaluation is accepted for signature compatibility but is not forwarded.
            return cls(expression_or_dict)

        @property
        def underlying_expression(self) -> Any:
            return self._wrapped.underlying_expression

        def __hash__(self) -> int:
            # Same hash as the wrapped object, consistent with __eq__ comparing against the wrapped object.
            return hash(self._wrapped)

        def __eq__(self, other):
            return self._wrapped == _unwrap(other)

        @property
        def variables(self) -> Sequence[str]:
            return self._wrapped.variables

        def evaluate_in_scope(self, scope: Mapping) -> Union[Real, np.ndarray]:
            return self._wrapped.evaluate_in_scope(scope)

        def evaluate_symbolic(self, substitutions: Mapping[str, Any]) -> 'ExpressionWrapper':
            """Substitute a part of the expression for another.

            The result is wrapped in the same wrapper class as ``self`` so that, e.g., a scalar wrapper keeps its
            scalar and ordering methods.
            """
            return type(self)(self._wrapped.evaluate_symbolic(substitutions))

        def evaluate_time_dependent(self, scope: Mapping) -> Union[protocol.Expression, Real, np.ndarray]:
            """Evaluate to a time dependent expression or a constant."""
            # NOTE(review): the result is returned as produced by the implementation, i.e. NOT wrapped.
            return self._wrapped.evaluate_time_dependent(scope)

        def get_serialization_data(self):
            return self._wrapped.get_serialization_data()

    class ExpressionScalarWrapper(ExpressionWrapper, protocol.ExpressionScalar):
        """Wraps an instance of ``expr_scalar`` and additionally forwards scalar arithmetic and ordering.

        Arithmetic results are wrapped in ``ExpressionScalarWrapper`` again; conversions (``float``/``int``) and
        comparisons return the implementation's result directly.
        """

        def __init__(self, x):
            # ExpressionWrapper.__init__ only sets ``_wrapped``; calling it would build a throw-away ``expr``
            # instance that is overwritten immediately, so the attribute is assigned directly instead.
            self._wrapped: protocol.ExpressionScalar = expr_scalar(_unwrap(x))

        # Scalar
        def __add__(self, other):
            return ExpressionScalarWrapper(self._wrapped + _unwrap(other))

        def __sub__(self, other):
            return ExpressionScalarWrapper(self._wrapped - _unwrap(other))

        def __mul__(self, other):
            return ExpressionScalarWrapper(self._wrapped * _unwrap(other))

        def __truediv__(self, other):
            return ExpressionScalarWrapper(self._wrapped / _unwrap(other))

        def __floordiv__(self, other):
            return ExpressionScalarWrapper(self._wrapped // _unwrap(other))

        def __ceil__(self):
            # math.ceil delegates to the wrapped object's __ceil__.
            return ExpressionScalarWrapper(math.ceil(self._wrapped))

        def __floor__(self):
            # math.floor delegates to the wrapped object's __floor__.
            return ExpressionScalarWrapper(math.floor(self._wrapped))

        def __float__(self):
            return float(self._wrapped)

        def __int__(self):
            return int(self._wrapped)

        def __abs__(self):
            return ExpressionScalarWrapper(abs(self._wrapped))

        # Ordered
        def __lt__(self, other):
            return self._wrapped < _unwrap(other)

        def __le__(self, other):
            return self._wrapped <= _unwrap(other)

        def __gt__(self, other):
            return self._wrapped > _unwrap(other)

        def __ge__(self, other):
            return self._wrapped >= _unwrap(other)

    class ExpressionVectorWrapper(ExpressionWrapper):
        """Placeholder for vector expressions: currently only exposes the base expression methods and wraps
        ``expr`` rather than ``expr_vector``."""

    return ExpressionWrapper, ExpressionScalarWrapper, ExpressionVectorWrapper