"""Base propagator class and related functions."""
import re
import inspect
from jsonargparse import Namespace, dict_to_namespace, namespace_to_dict
from copy import deepcopy
from typing import List
from ..schemas import auto_tag, block_validator
from ..sympy import is_valid_dim
# Recognizes "more than N" specifications such as '>1'; initial_checks applies it
# to num_input_blocks to decide between an exact count and a lower bound.
gt_regex = re.compile(r'^>[0-9]+$')
def get_shape(key, shape):
    """Gets the shape list for a given key among {'in','out'}.

    Args:
        key: Which shape to select, 'in' or 'out'.
        shape: A list (returned as is), an object indexable by ``key`` such as a
            dict, a Namespace with 'in'/'out' attributes, or a Namespace that
            carries a ``_shape`` attribute holding any of the former.

    Returns:
        The selected shape list.
    """
    has_shape_attr = isinstance(shape, Namespace) and hasattr(shape, '_shape')
    source = shape._shape if has_shape_attr else shape
    if isinstance(source, Namespace):
        # Namespace attributes named 'in'/'out' are reached through vars().
        source = vars(source)
    return source if isinstance(source, list) else source[key]
def create_shape(shape_in, shape_out=None):
    """Creates a shape namespace with 'in' and 'out' attributes and copied shape arrays.

    Args:
        shape_in: The input shape.
        shape_out: The output shape; when None, a copy of ``shape_in`` is used.

    Returns:
        A namespace built with dict_to_namespace whose 'in' and 'out' entries
        are deep copies of the given shapes.
    """
    out_value = shape_in if shape_out is None else shape_out
    # shape_in is copied before the outer deepcopy on purpose: if 'in' and 'out'
    # referenced the same object, deepcopy's memo would reproduce that sharing
    # and the resulting 'in' and 'out' lists would alias each other.
    pair = {'in': deepcopy(shape_in), 'out': out_value}
    return dict_to_namespace(deepcopy(pair))
def set_shape_dim(key, shape, dim, val):
    """Sets a value for a given dimension, shape and key ('in' or 'out').

    The shape list selected by get_shape is modified in place.
    """
    target = get_shape(key, shape)
    target[dim] = val
def shapes_agree(shape_from, shape_to):
    """Checks whether the output shape from a block agrees with input shape of another block.

    Returns:
        bool: Result of comparing the 'out' shape of ``shape_from`` with the
        'in' shape of ``shape_to`` using ==.
    """
    produced = get_shape('out', shape_from)
    expected = get_shape('in', shape_to)
    return produced == expected
def shape_has_auto(shape):
    """Checks whether a shape has any <<auto>> values.

    Args:
        shape: A shape list, or a single dimension given as a string.

    Returns:
        bool: True if any element of the shape equals ``auto_tag``.
    """
    if isinstance(shape, str):
        # A bare string is treated as a one-dimensional shape.
        shape = [shape]
    # Generator instead of a list so any() can stop at the first match.
    return any(x == auto_tag for x in shape)
def check_output_feats_dims(output_feats_dims, block_class, block):
    """Checks the output_feats attribute of a block.

    Args:
        output_feats_dims: Expected number of output_feats entries. Only the
            values 1, 2 and 3 trigger checks; anything else (e.g. False or
            None) skips them.
        block_class (str): Name of the block class, used in error messages.
        block: The block whose output_feats is checked.

    Raises:
        ValueError: If output_feats is required but missing; if
            output_feats_dims is 1 and output_feats is not a valid dimension;
            or if output_feats_dims is 2 or 3 and output_feats is not a list of
            exactly output_feats_dims valid dimensions.
    """
    if output_feats_dims not in {1, 2, 3}:
        return
    if not hasattr(block, 'output_feats'):
        raise ValueError(f'{block_class} propagator expected block[id={block._id}] to include an output_feats attribute.')
    output_feats = block.output_feats
    if output_feats_dims == 1:
        if not is_valid_dim(output_feats):
            raise ValueError(f'{block_class} propagator expected block[id={block._id}] output_feats to be a '
                             f'variable or an int larger than zero.')
    elif (not isinstance(output_feats, list)
          # The length must match; otherwise an empty or short list would pass
          # because all() over an empty iterable is True.
          or len(output_feats) != output_feats_dims
          or not all(is_valid_dim(x) for x in output_feats)):
        raise ValueError(f'{block_class} propagator expected block[id={block._id}] output_feats to be a '
                         f'list with {output_feats_dims} variables or ints larger than zero.')
class BasePropagator:
    """Base class for block shapes propagation."""

    # Name of the block class this propagator handles (set per instance in __init__).
    block_class = None
    # Required number of input blocks: None disables the check, an int requires
    # exactly that many, and a string like '>N' requires strictly more than N.
    num_input_blocks = None
    # Expected dimensionality of output_feats (1, 2 or 3); False disables the check.
    output_feats_dims = False

    def __init__(self, block_class):
        """Initializer for BasePropagator instance.

        Args:
            block_class (str): The name of the block class being propagated.
        """
        self.block_class = block_class

    def initial_checks(self, from_blocks: List[Namespace], block: Namespace):
        """Method that does some initial checks before propagation.

        Extensions of this method in derived classes should always call this
        base method. This base method implements the following checks:

        - That the block validates against the block schema and does not yet
          have a _shape attribute.
        - That the block class is the same as the one expected by the
          propagator.
        - That from_blocks is a list of Namespaces, each having a _shape whose
          output has at least one dimension and no <<auto>> values.
        - That output_feats is valid when output_feats_dims requires it.
        - If num_input_blocks is set, that the number of input blocks matches
          it (an exact int, or a '>N' string meaning more than N).

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: If block fails to validate against schema.
            ValueError: If block already has a _shape attribute.
            ValueError: If block._class != block_class.
            ValueError: If from_blocks is not a list of Namespaces.
            ValueError: If input shape not present, invalid or contains <<auto>>.
            ValueError: If output_feats required by class and not present or invalid.
            ValueError: If len(from_blocks) does not satisfy num_input_blocks.
        """
        try:
            block_validator.validate(namespace_to_dict(block))
        except Exception as ex:
            block_id = block._id if hasattr(block, '_id') else 'None'
            raise ValueError(f'Validation failed for block[id={block_id}] :: {ex}') from ex
        if hasattr(block, '_shape'):
            raise ValueError(f'Propagation only supported for blocks without a _shape attribute, '
                             f'found {block._shape} in block[id={block._id}].')
        if block._class != self.block_class:
            raise ValueError(f'Attempted to propagate block[id={block._id}] of class {block._class} using '
                             f'a {self.block_class} propagator.')
        if not isinstance(from_blocks, list) or not all(isinstance(x, Namespace) for x in from_blocks):
            raise ValueError(f'Expected from_blocks to be of type list[Namespace], not so for blocks '
                             f'connecting to block[id={block._id}].')
        for from_block in from_blocks:
            if not hasattr(from_block, '_shape'):
                raise ValueError(f'{self.block_class} propagator expected from_block[id={from_block._id}] to '
                                 f'include a _shape attribute.')
            # The output shape of each input block is this block's input.
            shape_in = get_shape('out', from_block)
            if len(shape_in) < 1:
                raise ValueError(f'Input block requires to have at least one dimension, zero '
                                 f'found for block[id={from_block._id}] -> block[id={block._id}].')
            if shape_has_auto(shape_in):
                raise ValueError(f'Input block not allowed to have {auto_tag} values in shape, '
                                 f'found for block[id={from_block._id}] -> block[id={block._id}].')
        check_output_feats_dims(self.output_feats_dims, self.block_class, block)
        if self.num_input_blocks is not None:
            invalid = True
            if isinstance(self.num_input_blocks, int) and len(from_blocks) == self.num_input_blocks:
                invalid = False
            elif gt_regex.match(str(self.num_input_blocks)) and len(from_blocks) > int(self.num_input_blocks[1:]):
                invalid = False
            if invalid:
                raise ValueError(f'Blocks of class {self.block_class} only accept {self.num_input_blocks} input '
                                 f'blocks, found {len(from_blocks)} for block[id={block._id}].')

    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Method that propagates shapes to a block.

        This base method should be implemented by all derived classes.
        Implementations may additionally accept any of the keyword parameters
        propagators, ext_vars and cwd; __call__ passes only those that the
        implementation's signature declares.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('This method should be implemented by derived classes.')

    def final_checks(self, from_blocks: List[Namespace], block: Namespace):
        """Method that checks for problems after shapes have been propagated.

        This base method implements checking the output shape don't contain
        <<auto>> values and if there is only a single from_block, that the
        connecting shapes agree. Extensions of this method in derived classes
        should always call this base one.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: If the output shape of block contains <<auto>> values.
            ValueError: If there is a single from_block and its output shape
                differs from the input shape of block.
        """
        if shape_has_auto(get_shape('out', block)):
            raise ValueError(f'Unexpectedly after propagation block has {auto_tag} values '
                             f'in output shape, found for block[id={block._id}].')
        if len(from_blocks) == 1 and not shapes_agree(from_blocks[0], block):
            raise ValueError(f'Shapes do not agree for block[id={from_blocks[0]._id}] connecting to block[id={block._id}].')

    def __call__(
        self,
        from_blocks: List[Namespace],
        block: Namespace,
        propagators: dict = None,
        ext_vars: dict = None,
        cwd: str = None
    ):
        """Propagates shapes to the given block.

        Runs initial_checks, then propagate, then final_checks.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.
            propagators: Dictionary of propagators.
            ext_vars: Dictionary of external variables required to load jsonnet.
                When None, a new empty dict is created for this call.
            cwd: Working directory to resolve relative paths.
        """
        if ext_vars is None:
            # A fresh dict per call. The former shared default ({}) let any
            # mutation done by propagate() leak into subsequent calls.
            ext_vars = {}
        self.initial_checks(from_blocks, block)
        # Forward only the optional arguments that propagate() declares.
        func_param = set(inspect.signature(self.propagate).parameters)
        optional = {'propagators': propagators, 'ext_vars': ext_vars, 'cwd': cwd}
        kwargs = {name: value for name, value in optional.items() if name in func_param}
        self.propagate(from_blocks, block, **kwargs)
        self.final_checks(from_blocks, block)