Source code for narchi.propagators.group

"""Propagator classes for groups of blocks."""

import re
import inspect
from jsonargparse import Namespace
from typing import Dict, List
from .base import BasePropagator, get_shape, create_shape
from ..graph import parse_graph
from ..schemas import id_separator


[docs]def get_blocks_dict(blocks: List[dict]) -> Dict[str, dict]: """Function that creates a dictionary of blocks using _id as keys. Args: blocks: List of blocks objects. Returns: Dictionary of blocks. """ blocks_dict = {} for block in blocks: if block._id in blocks_dict: raise ValueError(f'Duplicate block id: {block._id}.') blocks_dict[block._id] = block return blocks_dict
[docs]def add_ids_prefix(block, io_blocks, skip_io=True): """Adds to block id a prefix consisting of parent id and separator as defined in propagated schema.""" prefix = block._id + id_separator for num, subblock in enumerate(block.blocks): if hasattr(block, '_class') and block._class == 'Sequential' and not hasattr(subblock, '_id'): subblock._id = prefix + str(num) else: subblock._id = prefix + subblock._id if hasattr(block, 'input'): block.input = prefix + block.input if hasattr(block, 'output'): block.output = prefix + block.output if hasattr(block, 'inputs') and not skip_io: for node in block.inputs: node._id = prefix + node._id if hasattr(block, 'outputs') and not skip_io: for node in block.outputs: node._id = prefix + node._id if hasattr(block, 'graph'): skip_ids = set() if skip_io else {b._id for b in io_blocks} re_nodes = re.compile(' +-> +') for num, graph_line in enumerate(block.graph): nodes = re_nodes.split(graph_line) nodes = [n if n in skip_ids else prefix+n for n in nodes] block.graph[num] = ' -> '.join(nodes)
[docs]def propagate_shapes( blocks_dict: Dict[str, dict], topological_predecessors: Dict[str, List[str]], propagators: dict, ext_vars: dict, cwd: str, skip_ids: set = None, ): """Function that propagates shapes in blocks based on a connections mapping. Args: blocks_dict: Dictionary of blocks. topological_predecessors: Mapping of block IDs to its input blocks IDs. propagators: Dictionary of propagators. ext_vars: Dictionary of external variables required to load jsonnet. cwd: Working directory to resolve relative paths. skip_ids: Blocks that should be skipped in propagation. Raises: ValueError: If there graph references an undefined block. ValueError: If no propagator found for some block. """ if skip_ids is None: skip_ids = set() for node_to, nodes_from in topological_predecessors.items(): if node_to in skip_ids: continue from_blocks = [blocks_dict[n] for n in nodes_from] if node_to not in blocks_dict: block_ids = {k for k in blocks_dict.keys()} raise ValueError(f'Graph references block[id={node_to}] which is not found among ids={block_ids}.') block = blocks_dict[node_to] if block._class not in propagators: raise ValueError(f'No propagator found for block[id={block._id}] of type {block._class}.') propagator = propagators[block._class] func_param = {x.name for x in inspect.signature(propagator).parameters.values()} kwargs = {} if 'propagators' in func_param: kwargs['propagators'] = propagators if 'ext_vars' in func_param: kwargs['ext_vars'] = ext_vars if 'cwd' in func_param: kwargs['cwd'] = cwd propagator(from_blocks, block, **kwargs) return blocks_dict
[docs]class SequentialPropagator(BasePropagator): """Propagator for a sequence of blocks.""" num_input_blocks = 1
[docs] def propagate( self, from_blocks: List[Namespace], block: Namespace, propagators: dict, ext_vars: dict, cwd: str = None, ): """Method that propagates shapes in the given block. Args: from_blocks: The input blocks. block: The block to propagate its shapes. propagators: Dictionary of propagators. ext_vars: Dictionary of external variables required to load jsonnet. cwd: Working directory to resolve relative paths. Raises: ValueError: If there are multiple blocks with the same id. ValueError: If no propagator found for some block. """ add_ids_prefix(block, from_blocks) blocks = get_blocks_dict(from_blocks + block.blocks) topological_predecessors = parse_graph(from_blocks, block) try: propagate_shapes(blocks, topological_predecessors, propagators=propagators, ext_vars=ext_vars, cwd=cwd) except Exception as ex: raise type(ex)(f'block[id={block._id}]: {ex}') from ex in_shape = get_shape('out', from_blocks[0]) out_shape = get_shape('out', block.blocks[-1]) block._shape = create_shape(in_shape, out_shape)
[docs]class GroupPropagator(SequentialPropagator): """Propagator for a sequence of blocks."""
[docs] def propagate( self, from_blocks: List[Namespace], block: Namespace, propagators: dict, ext_vars: dict, cwd: str = None, ): """Method that propagates shapes in the given block. Args: from_blocks: The input blocks. block: The block to propagate its shapes. propagators: Dictionary of propagators. ext_vars: Dictionary of external variables required to load jsonnet. cwd: Working directory to resolve relative paths. Raises: ValueError: If there are multiple blocks with the same id. ValueError: If there graph references an undefined block. ValueError: If no propagator found for some block. """ add_ids_prefix(block, from_blocks) blocks = get_blocks_dict(from_blocks + block.blocks) topological_predecessors = parse_graph(from_blocks, block) try: propagate_shapes(blocks, topological_predecessors, propagators=propagators, ext_vars=ext_vars, cwd=cwd) except Exception as ex: raise type(ex)(f'block[id={block._id}]: {ex}') from ex in_shape = get_shape('out', from_blocks[0]) out_shape = get_shape('out', next(x for x in block.blocks if x._id==block.output)) block._shape = create_shape(in_shape, out_shape)