Source code for narchi.propagators.concat

"""Propagator classes for concatenating."""

from jsonargparse import Namespace
from typing import List
from .base import BasePropagator, get_shape, create_shape
from ..sympy import sum


[docs]class ConcatenatePropagator(BasePropagator): """Propagator for concatenating along a given dimension.""" num_input_blocks = '>1'
[docs] def initial_checks(self, from_blocks: List[Namespace], block: Namespace): """Method that does some initial checks before propagation. Calls the base class checks and makes sure that the dim attribute is valid and agrees with the input dimensions. Args: from_blocks: The input blocks. block: The block to propagate its shapes. Raises: ValueError: When block does not have a valid dim attribute that agrees with input dimensions. """ super().initial_checks(from_blocks, block) shape_0 = get_shape('out', from_blocks[0]) dim = block.dim if block.dim >= 0 else len(shape_0)+block.dim if dim < 0 or dim >= len(shape_0): raise ValueError(f'Value of dim attribute ({block.dim}) in block[id={block._id}] does not ' f'agree with the input dimensions coming from block[id={from_blocks[0]._id}].') for n in range(1, len(from_blocks)): shape_n = get_shape('out', from_blocks[n]) if len(shape_0) != len(shape_n) or \ any(shape_0[k] != shape_n[k] for k in range(len(shape_0)) if k != dim): raise ValueError(f'{self.block_class} expects all inputs to have the same shape except along ' f'the concatenating dimension, differs for block[id={from_blocks[n]._id}] ' f'connecting to block[id={block._id}], {shape_0} vs. {shape_n}.')
[docs] def propagate(self, from_blocks: List[Namespace], block: Namespace): """Method that propagates shapes to a block. Args: from_blocks: The input blocks. block: The block to propagate its shapes. """ shape_in = list(get_shape('out', from_blocks[0])) shape_in[block.dim] = None shape_out = list(shape_in) shape_out[block.dim] = sum([get_shape('out', b)[block.dim] for b in from_blocks]) block._shape = create_shape(shape_in, shape_out)