Source code for narchi.propagators.same

"""Propagator classes that preserve the same shape."""

from jsonargparse import Namespace
from typing import List
from .base import BasePropagator, get_shape, create_shape


class SameShapePropagator(BasePropagator):
    """Propagator for blocks whose output shape is identical to their input shape."""

    def initial_checks(self, from_blocks: List[Namespace], block: Namespace):
        """Validates the input blocks before shapes are propagated.

        The base class checks are executed first. Afterwards, if
        ``num_input_blocks`` is None exactly one input block is required,
        otherwise every input block must have the same output shape as the
        first one.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: When num_input_blocks is None and len(from_blocks) != 1.
            ValueError: When num_input_blocks is not None and the output
                shapes of the input blocks are not all equal.
        """
        super().initial_checks(from_blocks, block)

        # Single-input case: only the number of input blocks is verified.
        if self.num_input_blocks is None:
            if len(from_blocks) != 1:
                raise ValueError(f'block[id={block._id}] of type {self.block_class} only accepts one input '
                                 f'block, but got {len(from_blocks)}.')
            return

        # Otherwise the first input is the reference every other input is compared to.
        reference = get_shape('out', from_blocks[0])
        for other in from_blocks[1:]:
            if reference == get_shape('out', other):
                continue
            # First mismatch found: report the output shape of every input block.
            described = ['[id=' + b._id + ']=' + str(get_shape('out', b)) for b in from_blocks]
            in_shapes = ', '.join(described)
            raise ValueError(f'block[id={block._id}] of type {self.block_class} requires all inputs to '
                             f'have the same shape, but got {in_shapes}.')

    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Sets the shape of a block from the output shape of its first input.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.
        """
        first_out_shape = get_shape('out', from_blocks[0])
        block._shape = create_shape(first_out_shape)
class SameShapesPropagator(SameShapePropagator):
    """Same-shape propagator for blocks fed by several inputs that all share one shape, which is preserved."""

    # Overrides the inherited class attribute with the '>1' specification.
    num_input_blocks = '>1'
class SameShapeConsumeDimPropagator(SameShapePropagator):
    """Propagator for blocks whose output shape equals the input shape minus its last dimension, which is consumed."""

    def initial_checks(self, from_blocks: List[Namespace], block: Namespace):
        """Validates the input blocks before shapes are propagated.

        The parent class checks are executed first, then the output shape of
        the first input block is required to have two or more dimensions.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: When len(input_shape) < 2.
        """
        super().initial_checks(from_blocks, block)

        num_dims = len(get_shape('out', from_blocks[0]))
        if num_dims >= 2:
            return
        raise ValueError(f'block[id={block._id}] of type {self.block_class} requires input to have '
                         f'at least two dimensions.')

    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Sets the shape of a block, dropping the last dimension of the input.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.
        """
        source = from_blocks[0]
        full_shape = get_shape('out', source)
        # Same shape with the trailing (consumed) dimension removed.
        reduced_shape = get_shape('out', source)[:-1]
        block._shape = create_shape(full_shape, reduced_shape)