"""Propagator classes for fixed output blocks."""
from jsonargparse import Namespace
from typing import List, Union
from .base import BasePropagator, get_shape, create_shape
[docs]class AddFixedPropagator(BasePropagator):
"""Propagator for blocks that adds fixed dimensions."""
fixed_dims = 1
[docs] def __init__(self, block_class: str, fixed_dims: int = 1):
"""Initializer for AddFixedPropagator instance.
Args:
block_class: The name of the block class being propagated.
fixed_dims: Number of fixed dimensions.
Raises:
ValueError: If fixed_dims not int > 0.
"""
super().__init__(block_class)
if not isinstance(fixed_dims, int) or not fixed_dims > 0:
raise ValueError(f'{type(self).__name__} requires fixed_dims to be an int > 0.')
self.fixed_dims = fixed_dims
[docs] def propagate(self, from_blocks: List[Namespace], block: Namespace):
"""Method that propagates shapes to a block.
Args:
from_blocks: The input blocks.
block: The block to propagate its shapes.
"""
from_shape = get_shape('out', from_blocks[0])
if self.fixed_dims == 1:
to_shape = from_shape + [block.output_feats]
else:
to_shape = from_shape + block.output_feats
block._shape = create_shape(from_shape, to_shape)
[docs]class FixedOutputPropagator(BasePropagator):
"""Propagator for fixed output size blocks."""
num_input_blocks = 1
unfixed_dims = 'any'
output_feats_dims = 1
[docs] def __init__(
self,
block_class: str,
unfixed_dims: Union[int, str] = 'any',
fixed_dims: int = 1,
):
"""Initializer for FixedOutputPropagator instance.
Args:
block_class: The name of the block class being propagated.
unfixed_dims: Number of unfixed dimensions.
fixed_dims: Number of fixed dimensions.
Raises:
ValueError: If fixed_dims not int > 0.
ValueError: If unfixed_dims not "any" or int > 0.
"""
super().__init__(block_class)
if not ((isinstance(unfixed_dims, int) and unfixed_dims > 0) or unfixed_dims == 'any'):
raise ValueError(f'{type(self).__name__} requires unfixed_dims to be "any" or an int > 0.')
if not isinstance(fixed_dims, int) or not fixed_dims > 0:
raise ValueError(f'{type(self).__name__} requires fixed_dims to be an int > 0.')
self.unfixed_dims = unfixed_dims
self.output_feats_dims = fixed_dims
[docs] def initial_checks(self, from_blocks: List[Namespace], block: Namespace):
"""Method that does some initial checks before propagation.
Calls the base class checks and makes sure that the input shape has at
least (fixed_dims+1) dimensions if unfixed_dims=="any" or exactly
(fixed_dims+fixed_dims) dimensions if unfixed_dims is int.
Args:
from_blocks: The input blocks.
block: The block to propagate its shapes.
Raises:
ValueError: When fixed_dims and unfixed_dims do not agree with from_block[0]._shape.
"""
super().initial_checks(from_blocks, block)
from_shape = get_shape('out', from_blocks[0])
msg = (f'{block._class} propagator requires input shape to have %s %d dimensions, but '
f'block[id={from_blocks[0]._id}] -> block[id={block._id}] has {len(from_shape)}.')
if self.unfixed_dims == 'any' and len(from_shape) < self.output_feats_dims:
raise ValueError(msg % ('at least', self.output_feats_dims))
if isinstance(self.unfixed_dims, int) and len(from_shape) != self.output_feats_dims+self.unfixed_dims:
raise ValueError(msg % ('exactly', self.output_feats_dims+self.unfixed_dims))
[docs] def propagate(self, from_blocks: List[Namespace], block: Namespace):
"""Method that propagates shapes to a block.
Args:
from_blocks: The input blocks.
block: The block to propagate its shapes.
"""
from_shape = get_shape('out', from_blocks[0])
if self.output_feats_dims == 1:
to_shape = from_shape[0:-self.output_feats_dims] + [block.output_feats]
else:
to_shape = from_shape[0:-self.output_feats_dims] + block.output_feats
block._shape = create_shape(from_shape, to_shape)