Source code for narchi.sympy

"""Functions for symbolic operations."""

import re
import sympy
from sympy.core import numbers
from typing import List, Union
from .schemas import variable_pattern


variable_regex = re.compile('^'+variable_pattern+'$')


[docs]def is_valid_dim(value): """Checks whether value is an int > 0 or str that follows variable_regex.pattern.""" if isinstance(value, int) or variable_regex.match(str(value)): if isinstance(value, int) and value <= 0: return False return True return False
[docs]def sympify_variable(value): """Returns the sympyfied object for the given value.""" var_match = variable_regex.match(str(value)) if var_match: value_sympy = sympy.sympify(var_match[1]) elif isinstance(value, (int, str)): value_sympy = sympy.sympify(value) else: raise ValueError(f'Expected input to be an int or a valid sympy expression or a string with ' f'pattern: {variable_regex.pattern}, but got {value}.') return value_sympy
[docs]def get_nonrational_variable(value): """Returns either an int or a string variable.""" if isinstance(value, numbers.Integer): return int(value) elif isinstance(value, numbers.Rational): raise ValueError(f'Obtained a rational result: {value}.') return '<<variable:'+str(value).replace(' ', '')+'>>'
[docs]def variable_operate(value: Union[str, int], operation: Union[str, int]) -> Union[str, int]: """Performs a symbolic operation on a given value. Args: value: The value to operate on, either an int or a variable, e.g. "<<variable:W/2+H/4>>". operation: Operation to apply on value, either int or expression, e.g. "__input__/3". Returns: The result of the operation. Raises: ValueError: * If operation is not int nor a valid expression. * If value is not an int or a string that follows variable_regex.pattern. * If value is not a valid expression or contains "__input__" as a free symbol. """ ## Handle operation ## if isinstance(operation, int): return operation elif not isinstance(operation, str): raise ValueError('Expected operation to be an int or a string.') operation_sympy = sympify_variable(operation) ## Handle input ## value_sympy = sympify_variable(value) value_symbs = {str(s) for s in value_sympy.free_symbols} if '__input__' in value_symbs: raise ValueError(f'Value must not contain "__input__" as a free symbol, but got {input}.') ## Substitute input into operation ## output_sympy = operation_sympy.subs({'__input__': value_sympy}) ## Return operation result ## return get_nonrational_variable(output_sympy)
[docs]def variables_aggregate(values: List[Union[str, int]], operation: str) -> Union[str, int]: """Performs a symbolic aggregation operation over all input values. Args: values: List of values to operate on. operation: One of '+'=sum, '*'=prod. Returns: The result of the operation. Raises: ValueError: If any value is not an int or a string that follows variable_regex.pattern. """ operations = {'+', '*'} if operation not in operations: raise ValueError(f'Expected operation to be one of {operations}, got {operation}.') if not isinstance(values, list) or len(values) < 1 or not all(isinstance(v, (str, int)) for v in values): raise ValueError(f'Expected values to be a list containing int or str elements, got {values}.') value = values[0] for num in range(1, len(values)): value_sympy = sympify_variable(value) value = variable_operate(values[num], f'__input__{operation}({value_sympy})') return value
[docs]def sum(values: List[Union[str, int]]) -> Union[str, int]: """Performs a symbolic sum of all input values. Args: values: List of values to operate on. Returns: The result of the operation. Raises: ValueError: If any value is not an int or a string that follows variable_regex.pattern. """ return variables_aggregate(values, '+')
[docs]def prod(values: List[Union[str, int]]) -> Union[str, int]: """Performs a symbolic product of all input values. Args: values: List of values to operate on. Returns: The result of the operation. Raises: ValueError: If any value is not an int or a string that follows variable_regex.pattern. """ return variables_aggregate(values, '*')
[docs]def divide(numerator: Union[str, int], denominator: Union[str, int]) -> Union[str, int]: """Performs a symbolic division. Args: numerator: Value for numerator. denominator: Value for denominator. Returns: The result of the operation. Raises: ValueError: If any value is not an int or a string that follows variable_regex.pattern. """ numerator = sympify_variable(numerator) denominator = sympify_variable(denominator) result = numerator/denominator return get_nonrational_variable(result)
[docs]def conv_out_length(length: Union[str, int], kernel: int, stride: int, padding: int, dilation: int) -> Union[str, int]: """Performs a symbolic calculation of the output length of a convolution. Args: length: Length of the input, either an int or a variable. kernel: Size of the kernel in the direction of length. stride: Stride size in the direction of the length. padding: Padding added at both sides in the direction of the length. dilation: Dilation size in the direction of the length. Returns: The result of the operation. """ operation_sympy = sympify_variable('1+(length+2*padding-dilation*(kernel-1)-1)/stride') output_sympy = operation_sympy.subs({'length': sympify_variable(length), 'kernel': sympify_variable(kernel), 'stride': sympify_variable(stride), 'padding': sympify_variable(padding), 'dilation': sympify_variable(dilation)}) frac_sympy = output_sympy.subs({s: 0 for s in output_sympy.free_symbols}) output_sympy = output_sympy - frac_sympy + numbers.Integer(frac_sympy) return get_nonrational_variable(output_sympy)