# Source code for narchi.propagators.conv

"""Propagator classes for convolution blocks."""

from jsonargparse import Namespace
from typing import List
from .base import BasePropagator, get_shape, create_shape, set_shape_dim, check_output_feats_dims
from ..sympy import conv_out_length
from ..schemas import auto_tag


class ConvPropagator(BasePropagator):
    """Propagator for convolution style blocks.

    The propagated output shape has the number of features as its first
    dimension (taken from ``block.output_feats`` or copied from the input
    shape, depending on ``num_features_source``) followed by ``conv_dims``
    dimensions whose lengths are computed with ``conv_out_length``.
    """
    # Number of input blocks this propagator expects (presumably enforced by
    # BasePropagator.initial_checks — confirm in base module).
    num_input_blocks = 1
    # Source of the output feature count: 'output_feats' reads
    # block.output_feats, 'from_shape' copies the input's first dimension.
    num_features_source = 'output_feats'
    # Number of convolved dimensions; assigned per instance in __init__.
    conv_dims = None

    def __init__(self, block_class: str, conv_dims: int):
        """Initializer for ConvPropagator instance.

        Args:
            block_class: The name of the block class being propagated.
            conv_dims: Number of dimensions for the convolution.

        Raises:
            ValueError: If num_features_source is not a supported value, or
                if conv_dims is not an int > 0 (bools are rejected).
        """
        super().__init__(block_class)
        valid_num_features_source = {'output_feats', 'from_shape'}
        if self.num_features_source not in valid_num_features_source:
            raise ValueError(f'{type(self).__name__} only allows num_features_source to be one of {valid_num_features_source}.')
        # bool is a subclass of int, so it must be excluded explicitly.
        if isinstance(conv_dims, bool) or not isinstance(conv_dims, int) or conv_dims < 1:
            raise ValueError(f'{type(self).__name__} only allows conv_dims to be an int > 0.')
        self.conv_dims = conv_dims

    def initial_checks(self, from_blocks: List[Namespace], block: Namespace):
        """Method that does some initial checks before propagation.

        Calls the base class checks and makes sure that the input shape
        agrees with the convolution dimensions, i.e. that it has one feature
        dimension plus conv_dims convolved dimensions.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: When conv_dims does not agree with from_blocks[0]._shape.
        """
        super().initial_checks(from_blocks, block)
        shape_in = get_shape('out', from_blocks[0])
        if len(shape_in)-1 != self.conv_dims:
            raise ValueError(f'{block._class} blocks require input shape to have {self.conv_dims+1} dimensions, '
                             f'but got {shape_in}.')

    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Method that propagates shapes to a block.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: When block.output_feats not valid.
            NotImplementedError: If num_features_source is not one of
                {"from_shape", "output_feats"}.
        """
        ## Set default values ##
        kernel = block.kernel_size
        stride = getattr(block, 'stride', 1)
        padding = getattr(block, 'padding', 0)
        dilation = getattr(block, 'dilation', 1)

        ## Initialize block._shape ##
        # Feature count first, then one <<auto>> placeholder per convolved
        # dimension; the placeholders are resolved below.
        auto_dims = [auto_tag for _ in range(self.conv_dims)]
        from_shape = get_shape('out', from_blocks[0])
        if self.num_features_source == 'from_shape':
            num_features = from_shape[0]
        elif self.num_features_source == 'output_feats':
            check_output_feats_dims(1, self.block_class, block)
            num_features = block.output_feats
        else:
            # Reachable only if num_features_source was changed after __init__
            # validated it; fail loudly instead of leaving block._shape unset.
            raise NotImplementedError(f'{type(self).__name__} does not support num_features_source={self.num_features_source!r}.')
        block._shape = create_shape(from_shape, [num_features]+auto_dims)

        ## Calculate and set <<auto>> output dimensions ##
        # The input shape does not change inside the loop, so fetch it once.
        shape_in = get_shape('in', block)
        for dim, val in enumerate(get_shape('out', block)):
            if val == auto_tag:
                out_length = conv_out_length(shape_in[dim], kernel, stride, padding, dilation)
                set_shape_dim('out', block, dim, out_length)
class PoolPropagator(ConvPropagator):
    """Propagator for pooling style blocks.

    Behaves like ConvPropagator except that the number of output features is
    copied from the first dimension of the input shape, so the propagated
    block does not need an ``output_feats`` attribute.
    """

    # Pooling preserves the feature count of its input.
    num_features_source = 'from_shape'