"""Propagator classes for recurrent blocks."""
from jsonargparse import Namespace
from typing import List
from .base import BasePropagator, get_shape, create_shape, set_shape_dim
from ..schemas import auto_tag
class RnnPropagator(BasePropagator):
    """Propagator for recurrent style blocks."""

    # Exactly one upstream block feeds an RNN-style block.
    num_input_blocks = 1
    # output_feats is a single integer (one feature dimension).
    output_feats_dims = 1

    def initial_checks(self, from_blocks: List[Namespace], block: Namespace):
        """Method that does some initial checks before propagation.

        Calls the base class checks and makes sure that the input shape has two
        dimensions and that block includes a valid output_feats attribute.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: When block.output_feats not valid.
            ValueError: When len(from_block[0]._shape) != 2.
        """
        super().initial_checks(from_blocks, block)
        shape_in = get_shape('out', from_blocks[0])
        # Recurrent blocks expect (sequence, features)-style 2D input.
        if len(shape_in) != 2:
            raise ValueError(f'{block._class} blocks require input shape to have 2 dimensions, but got {shape_in}.')

    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Method that propagates shapes to a block.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: When bidirectional==True and output_feats not even.
        """
        ## Set default values ##
        if not hasattr(block, 'bidirectional'):
            block.bidirectional = False
        from_shape = get_shape('out', from_blocks[0])
        output_feats = block.output_feats
        ## Validate before mutating block so a failure leaves it unchanged ##
        # (previously _shape was assigned before this check, leaving the block
        # partially mutated when the ValueError was raised)
        if block.bidirectional and output_feats % 2 != 0:
            raise ValueError(f'For bidirectional {block._class} expected output_feats to be even, but got {output_feats}.')
        ## Set hidden size: each direction contributes half the output features ##
        block.hidden_size = output_feats // (2 if block.bidirectional else 1)
        ## Initialize block._shape ##
        block._shape = create_shape(from_shape, [auto_tag, output_feats])
        ## Propagate first dimension (e.g. variable sequence length) from input ##
        set_shape_dim('out', block, 0, from_shape[0])