"""Propagator classes for convolution blocks."""
from jsonargparse import Namespace
from typing import List
from .base import BasePropagator, get_shape, create_shape, set_shape_dim, check_output_feats_dims
from ..sympy import conv_out_length
from ..schemas import auto_tag
class ConvPropagator(BasePropagator):
    """Propagator for convolution style blocks.

    The output shape is ``[num_features] + conv_dims`` spatial dimensions.
    ``num_features`` is read from ``block.output_feats``
    (``num_features_source='output_feats'``) or copied from the first
    dimension of the input shape (``num_features_source='from_shape'``). The
    spatial dimensions are computed with ``conv_out_length`` from the block's
    kernel_size, stride, padding and dilation.
    """

    num_input_blocks = 1
    # Where the output feature count comes from: 'output_feats' reads
    # block.output_feats, 'from_shape' reuses the first input dimension.
    num_features_source = 'output_feats'
    conv_dims = None

    def __init__(self, block_class: str, conv_dims: int):
        """Initializer for ConvPropagator instance.

        Args:
            block_class: The name of the block class being propagated.
            conv_dims: Number of dimensions for the convolution.

        Raises:
            ValueError: If num_features_source is not one of
                {'output_feats', 'from_shape'}, or if conv_dims is not an
                int > 0.
        """
        super().__init__(block_class)
        valid_num_features_source = {'output_feats', 'from_shape'}
        if self.num_features_source not in valid_num_features_source:
            raise ValueError(f'{type(self).__name__} only allows num_features_source to be one of {valid_num_features_source}.')
        if not isinstance(conv_dims, int) or conv_dims < 1:
            raise ValueError(f'{type(self).__name__} only allows conv_dims to be an int > 0.')
        self.conv_dims = conv_dims

    def initial_checks(self, from_blocks: List[Namespace], block: Namespace):
        """Method that does some initial checks before propagation.

        Calls the base class checks and makes sure that the input shape agrees
        with the convolution dimensions, i.e. it has one leading features
        dimension plus ``conv_dims`` spatial dimensions.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: When conv_dims does not agree with from_block[0]._shape.
        """
        super().initial_checks(from_blocks, block)
        shape_in = get_shape('out', from_blocks[0])
        # The "-1" discounts the leading features dimension.
        if len(shape_in)-1 != self.conv_dims:
            raise ValueError(f'{block._class} blocks require input shape to have {self.conv_dims+1} dimensions, '
                             f'but got {shape_in}.')

    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Method that propagates shapes to a block.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: When block.output_feats not valid.
            NotImplementedError: If num_features_source is not one of {"from_shape", "output_feats"}.
        """
        ## Set default values ##
        kernel = block.kernel_size
        stride = getattr(block, 'stride', 1)
        padding = getattr(block, 'padding', 0)
        dilation = getattr(block, 'dilation', 1)

        ## Initialize block._shape ##
        from_shape = get_shape('out', from_blocks[0])
        if self.num_features_source == 'from_shape':
            num_features = from_shape[0]
        elif self.num_features_source == 'output_feats':
            check_output_feats_dims(1, self.block_class, block)
            num_features = block.output_feats
        else:
            # __init__ validates this, but the attribute may be changed
            # afterwards; fail loudly instead of leaving block._shape unset.
            raise NotImplementedError(
                f'{type(self).__name__} got unsupported num_features_source={self.num_features_source!r}, '
                f'expected "from_shape" or "output_feats".')
        # Spatial output dimensions start as <<auto>> and are resolved below.
        auto_dims = [auto_tag] * self.conv_dims
        block._shape = create_shape(from_shape, [num_features]+auto_dims)

        ## Calculate and set <<auto>> output dimensions ##
        shape_in = get_shape('in', block)
        for dim, val in enumerate(get_shape('out', block)):
            if val == auto_tag:
                out_length = conv_out_length(shape_in[dim], kernel, stride, padding, dilation)
                set_shape_dim('out', block, dim, out_length)
class PoolPropagator(ConvPropagator):
    """Propagator for pooling style blocks.

    Behaves like ConvPropagator, except that the number of output features is
    copied from the first dimension of the input shape instead of being read
    from a ``output_feats`` attribute of the block.
    """

    # Pooling does not change the feature count, so take it from the input.
    num_features_source = 'from_shape'