Source code for qupulse.expressions.sympy

"""
This module defines the class Expression to represent mathematical expression as well as
corresponding exception classes.
"""
import operator
from typing import Any, Dict, Union, Sequence, Callable, TypeVar, Type, Mapping
from numbers import Number
import warnings
import functools
import array
import itertools

import sympy
import numpy

from qupulse.serialization import AnonymousSerializable
from qupulse.utils.sympy import sympify, to_numpy, recursive_substitution, evaluate_lambdified,\
    get_most_simple_representation, get_variables, evaluate_lamdified_exact_rational
from qupulse.utils.types import TimeType

import qupulse.expressions

__all__ = ["Expression", "ExpressionScalar", "ExpressionVector"]


_ExpressionType = TypeVar('_ExpressionType', bound='Expression')


ALLOWED_NUMERIC_SCALAR_TYPES = (float, numpy.number, int, complex, bool, numpy.bool_, TimeType)


def _parse_evaluate_numeric(result) -> Union[Number, numpy.ndarray]:
    """Tries to parse the result as a scalar if possible. Falls back to an array otherwise.
    Raises:
        ValueError if scalar result is not parsable
    """
    allowed_scalar = ALLOWED_NUMERIC_SCALAR_TYPES

    if isinstance(result, allowed_scalar):
        # fast path for regular evaluations
        return result
    if isinstance(result, tuple):
        result, = result
    elif isinstance(result, numpy.ndarray):
        result = result[()]

    if isinstance(result, allowed_scalar):
        return result
    if isinstance(result, sympy.Float):
        return float(result)
    elif isinstance(result, sympy.Integer):
        return int(result)

    if isinstance(result, numpy.ndarray):
        # allow numeric vector values
        return _parse_evaluate_numeric_vector(result)
    raise ValueError("Non numeric result", result)


def _parse_evaluate_numeric_vector(vector_result: numpy.ndarray) -> numpy.ndarray:
    allowed_scalar = ALLOWED_NUMERIC_SCALAR_TYPES
    if not issubclass(vector_result.dtype.type, allowed_scalar):
        obj_types = set(map(type, vector_result.flat))
        if all(issubclass(obj_type, sympy.Integer) for obj_type in obj_types):
            vector_result = vector_result.astype(numpy.int64)
        elif all(issubclass(obj_type, (sympy.Integer, sympy.Float)) for obj_type in obj_types):
            vector_result = vector_result.astype(float)
        else:
            raise ValueError("Could not parse vector result", vector_result)
    return vector_result


def _flat_iter(arr):
    if len(arr.shape) > 1:
        for sub_arr in arr:
            yield from _flat_iter(sub_arr)
    else:
        yield from arr


class _ExpressionMeta(type):
    """Metaclass that forwards calls to Expression(...) to Expression.make(...) to make subclass objects"""
    def __call__(cls: Type[_ExpressionType], *args, **kwargs) -> _ExpressionType:
        if cls is Expression:
            return cls.make(*args, **kwargs)
        else:
            return type.__call__(cls, *args, **kwargs)


[docs]class Expression(AnonymousSerializable, metaclass=_ExpressionMeta): """Base class for expressions.""" def __init__(self, *args, **kwargs): self._expression_lambda = None def _parse_evaluate_numeric_arguments(self, eval_args: Mapping[str, Number]) -> Dict[str, Number]: try: return {v: eval_args[v] for v in self.variables} except KeyError as key_error: if type(key_error).__module__.startswith('qupulse'): # we forward qupulse errors, I down like this raise else: raise qupulse.expressions.ExpressionVariableMissingException(key_error.args[0], self) from key_error
[docs] def evaluate_in_scope(self, scope: Mapping) -> Union[Number, numpy.ndarray]: """Evaluate the expression by taking the variables from the given scope (typically of type Scope but it can be any mapping.) Args: scope: Returns: """ raise NotImplementedError("")
[docs] def evaluate_numeric(self, **kwargs) -> Union[Number, numpy.ndarray]: return self.evaluate_in_scope(kwargs)
def __float__(self): if self.variables: return NotImplemented else: e = self.evaluate_numeric() return float(e)
[docs] def evaluate_symbolic(self, substitutions: Mapping[Any, Any]) -> 'Expression': if len(substitutions)==0: return self return Expression.make(recursive_substitution(sympify(self.underlying_expression), substitutions))
def _evaluate_to_time_dependent(self, scope: Mapping) -> Union['Expression', Number, numpy.ndarray]: try: return self.evaluate_numeric(**scope, t=sympy.symbols('t')) except qupulse.expressions.NonNumericEvaluation as non_num: return ExpressionScalar(non_num.non_numeric_result) except TypeError: return self.evaluate_symbolic(scope) @property def variables(self) -> Sequence[str]: """ Get all free variables in the expression. Returns: A collection of all free variables occurring in the expression. """ raise NotImplementedError()
[docs] @classmethod def make(cls: Type[_ExpressionType], expression_or_dict, numpy_evaluation=None) -> Union['ExpressionScalar', 'ExpressionVector', _ExpressionType]: """Backward compatible expression generation""" if numpy_evaluation is not None: warnings.warn('numpy_evaluation keyword argument is deprecated and ignored.') if isinstance(expression_or_dict, dict): expression = expression_or_dict['expression'] elif isinstance(expression_or_dict, cls): return expression_or_dict else: expression = expression_or_dict if cls is Expression: if isinstance(expression, (list, tuple, numpy.ndarray, sympy.NDimArray, array.array)): return ExpressionVector(expression) else: return ExpressionScalar(expression) else: return cls(expression)
@property def underlying_expression(self) -> Union[sympy.Expr, numpy.ndarray]: raise NotImplementedError()
[docs]class ExpressionVector(Expression): """N-dimensional expression. TODO: write doc! TODO: write tests! """ sympify_vector = numpy.vectorize(sympify) def __init__(self, expression_vector: Sequence): super().__init__() if isinstance(expression_vector, sympy.NDimArray): expression_shape = expression_vector.shape expression_items = tuple(_flat_iter(expression_vector)) else: expression_ndarray = self.sympify_vector(expression_vector) expression_items = tuple(expression_ndarray.flat) expression_shape = expression_ndarray.shape self._expression_items = expression_items self._expression_shape = expression_shape self._lambdified_items = [None] * len(self._expression_items) variables = set(itertools.chain.from_iterable(map(get_variables, self._expression_items))) self._variables = tuple(sorted(variables)) @property def variables(self) -> Sequence[str]: return self._variables
[docs] def evaluate_in_scope(self, scope: Mapping) -> numpy.ndarray: parsed_kwargs = self._parse_evaluate_numeric_arguments(scope) flat_result = [] for idx, expr in enumerate(self._expression_items): result, self._lambdified_items[idx] = evaluate_lambdified(expr, self.variables, parsed_kwargs, lambdified=self._lambdified_items[idx]) flat_result.append(result) result = numpy.array(flat_result).reshape(self._expression_shape) try: return _parse_evaluate_numeric_vector(result) except ValueError as err: raise qupulse.expressions.NonNumericEvaluation(self, result, scope) from err
[docs] def get_serialization_data(self) -> Sequence[str]: serialized_items = list(map(get_most_simple_representation, self._expression_items)) if len(self._expression_shape) == 0: return serialized_items[0] elif len(self._expression_shape) == 1: return serialized_items else: return numpy.array(serialized_items).reshape(self._expression_shape).tolist()
def __getstate__(self): return self.get_serialization_data() def __setstate__(self, state): self.__init__(state) def __str__(self): return str(self.get_serialization_data()) def __repr__(self): return f'ExpressionVector({self.get_serialization_data()!r})' def _sympy_(self): return sympy.NDimArray(self.to_ndarray()) def __eq__(self, other): if not isinstance(other, Expression): try: other = Expression.make(other) except (ValueError, TypeError): return NotImplemented if isinstance(other, ExpressionScalar): return self._expression_shape in ((), (1,)) and self._expression_items[0] == other.sympified_expression else: return self._expression_shape == other._expression_shape and \ self._expression_items == other._expression_items def __hash__(self): if self._expression_shape in ((), (1,)): return hash(self._expression_items[0]) else: return hash((self._expression_items, self._expression_shape)) def __getitem__(self, item) -> Expression: if len(self._expression_shape) == 0: assert item == () expr, = self._expression_items return ExpressionScalar(expr) if len(self._expression_shape) == 1: return ExpressionScalar(self._expression_items[item]) else: return ExpressionVector(self.to_ndarray()[item])
[docs] def to_ndarray(self) -> numpy.ndarray: return numpy.array(self._expression_items).reshape(self._expression_shape)
@property def underlying_expression(self) -> numpy.ndarray: return self.to_ndarray()
[docs]class ExpressionScalar(Expression): """A scalar mathematical expression instantiated from a string representation. TODO: update doc! TODO: write tests! """
[docs] def __init__(self, ex: Union[str, Number, sympy.Expr]) -> None: """Create an Expression object. Receives the mathematical expression which shall be represented by the object as a string which will be parsed using py_expression_eval. For available operators, functions and constants see SymPy documentation Args: ex (string): The mathematical expression represented as a string """ super().__init__() if isinstance(ex, sympy.Expr): self._original_expression = None self._sympified_expression = ex self._variables = get_variables(self._sympified_expression) elif isinstance(ex, ExpressionScalar): self._original_expression = ex._original_expression self._sympified_expression = ex._sympified_expression self._variables = ex._variables elif isinstance(ex, (int, float)): if isinstance(ex, numpy.float64): ex = float(ex) self._original_expression = ex self._sympified_expression = sympify(ex) self._variables = () else: self._original_expression = ex self._sympified_expression = sympify(ex) self._variables = get_variables(self._sympified_expression) self._exact_rational_lambdified = None
def __float__(self): if isinstance(self._original_expression, float): return self._original_expression else: return super().__float__() @property def underlying_expression(self) -> sympy.Expr: return self._sympified_expression def __str__(self) -> str: return str(self._sympified_expression) def __repr__(self) -> str: if self._original_expression is None: return f"ExpressionScalar('{self._sympified_expression!r}')" else: return f"ExpressionScalar({self._original_expression!r})" def __format__(self, format_spec): if format_spec == '': return str(self) return format(float(self), format_spec) @property def variables(self) -> Sequence[str]: return self._variables @classmethod def _sympify(cls, other: Union['ExpressionScalar', Number, sympy.Expr]) -> sympy.Expr: return other._sympified_expression if isinstance(other, cls) else sympify(other) @classmethod def _extract_sympified(cls, other: Union['ExpressionScalar', Number, sympy.Expr]) \ -> Union['ExpressionScalar', Number, sympy.Expr]: return getattr(other, '_sympified_expression', other) def __lt__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> Union[bool, None]: result = self._sympified_expression < self._extract_sympified(other) return None if isinstance(result, sympy.Rel) else bool(result) def __gt__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> Union[bool, None]: result = self._sympified_expression > self._extract_sympified(other) return None if isinstance(result, sympy.Rel) else bool(result) def __ge__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> Union[bool, None]: result = self._sympified_expression >= self._extract_sympified(other) return None if isinstance(result, sympy.Rel) else bool(result) def __le__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> Union[bool, None]: result = self._sympified_expression <= self._extract_sympified(other) return None if isinstance(result, sympy.Rel) else bool(result) def __eq__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> bool: """Enable comparisons with Numbers""" # sympy's __eq__ checks for structural equality to be consistent regarding __hash__ so we do that too # see https://github.com/sympy/sympy/issues/18054#issuecomment-566198899 return self._sympified_expression == self._sympify(other) def __hash__(self) -> int: return hash(self._sympified_expression) def __add__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': return self.make(self._sympified_expression.__add__(self._extract_sympified(other))) def __radd__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': return self.make(self._sympify(other).__radd__(self._sympified_expression)) def __sub__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': return self.make(self._sympified_expression.__sub__(self._extract_sympified(other))) def __rsub__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': return self.make(self._sympified_expression.__rsub__(self._extract_sympified(other))) def __mul__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': return self.make(self._sympified_expression.__mul__(self._extract_sympified(other))) def __rmul__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': return self.make(self._sympified_expression.__rmul__(self._extract_sympified(other))) def __truediv__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': return self.make(self._sympified_expression.__truediv__(self._extract_sympified(other))) def __rtruediv__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': return self.make(self._sympified_expression.__rtruediv__(self._extract_sympified(other))) def __floordiv__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': return self.make(self._sympified_expression.__floordiv__(self._extract_sympified(other))) def __rfloordiv__(self, other: Union['ExpressionScalar', Number, sympy.Expr]) -> 'ExpressionScalar': return self.make(self._sympified_expression.__rfloordiv__(self._extract_sympified(other))) def __neg__(self) -> 'ExpressionScalar': return self.make(self._sympified_expression.__neg__()) def __pos__(self): return self.make(self._sympified_expression.__pos__()) def _sympy_(self): return self._sympified_expression @property def original_expression(self) -> Union[str, Number]: if self._original_expression is None: return str(self._sympified_expression) else: return self._original_expression @property def sympified_expression(self) -> sympy.Expr: return self._sympified_expression
[docs] def get_serialization_data(self) -> Union[str, float, int]: serialized = get_most_simple_representation(self._sympified_expression) if isinstance(serialized, str): return self.original_expression else: return serialized
def __getstate__(self): return self.get_serialization_data() def __setstate__(self, state): self.__init__(state)
[docs] def is_nan(self) -> bool: return sympy.sympify('nan') == self._sympified_expression
[docs] def evaluate_with_exact_rationals(self, scope: Mapping) -> Union[Number, numpy.ndarray]: parsed_kwargs = self._parse_evaluate_numeric_arguments(scope) result, self._exact_rational_lambdified = evaluate_lamdified_exact_rational(self.sympified_expression, self.variables, parsed_kwargs, self._exact_rational_lambdified) try: return _parse_evaluate_numeric(result) except ValueError as err: raise qupulse.expressions.NonNumericEvaluation(self, result, scope) from err
[docs] def evaluate_in_scope(self, scope: Mapping) -> Union[Number, numpy.ndarray]: parsed_kwargs = self._parse_evaluate_numeric_arguments(scope) result, self._expression_lambda = evaluate_lambdified(self.underlying_expression, self.variables, parsed_kwargs, lambdified=self._expression_lambda) try: return _parse_evaluate_numeric(result) except ValueError as err: raise qupulse.expressions.NonNumericEvaluation(self, result, scope) from err