Source code for narchi.propagators.reshape

"""Propagator classes for reshaping."""

from jsonargparse import Namespace
from jsonargparse import namespace_to_dict as n2d
from typing import List
from .base import BasePropagator, get_shape, create_shape
from ..sympy import prod, divide
from ..schemas import auto_tag, reshape_validator


def check_reshape_spec(reshape_spec):
    """Validates a reshape_spec and collects the dimension indexes it references.

    The spec is first validated with ``reshape_validator``. Unless the spec is
    the string ``'flatten'``, every element is then inspected:

    - an int (or str) element contributes itself as an index,
    - a list element contributes each of its items,
    - any other element is treated as a single-key mapping (unflatten) whose
      key, converted to int, is the index and whose value is a list of
      dimensions that may contain ``auto_tag`` at most once.

    Args:
        reshape_spec: Either ``'flatten'`` or an iterable of spec elements.

    Returns:
        The list of referenced indexes in order of appearance (empty for
        ``'flatten'``).

    Raises:
        ValueError: If an unflatten definition has more than one ``auto_tag``
            or if the sorted indexes are not exactly ``0..len(indexes)-1``.
    """
    reshape_validator.validate(reshape_spec)

    # Nothing to index for a plain flatten.
    if reshape_spec == 'flatten':
        return []

    indexes = []
    for entry in reshape_spec:
        if isinstance(entry, list):
            indexes.extend(entry)
        elif isinstance(entry, (int, str)):
            indexes.append(entry)
        else:
            key = next(iter(entry.keys()))
            indexes.append(int(key))
            dims = entry[key]
            num_auto = sum(1 for dim in dims if dim == auto_tag)
            if num_auto > 1:
                raise ValueError(f'At most one {auto_tag} is allowed in unflatten definition ({dims}).')

    # Every input dimension must be referenced exactly once.
    ordered = sorted(indexes)
    if ordered != list(range(len(indexes))):
        raise ValueError(f'Invalid indexes range ({ordered}) in reshape_spec.')

    return indexes
def norm_reshape_spec(reshape_spec):
    """Returns a reshape_spec whose Namespace elements are replaced by dicts.

    Args:
        reshape_spec: A string spec or an iterable of spec elements.

    Returns:
        The same string when a string is given; otherwise a new list in which
        each ``Namespace`` element has been converted with
        ``namespace_to_dict`` and every other element is kept as is (the same
        object, not a copy).
    """
    if isinstance(reshape_spec, str):
        return reshape_spec

    normalized = []
    for element in reshape_spec:
        if isinstance(element, Namespace):
            element = n2d(element)
        normalized.append(element)
    return normalized
class ReshapePropagator(BasePropagator):
    """Propagator for reshaping which could involve any of: permute, flatten and unflatten."""

    num_input_blocks = 1

    def initial_checks(self, from_blocks: List[Namespace], block: Namespace):
        """Method that does some initial checks before propagation.

        Calls the base class checks and makes sure that the reshape_spec
        attribute is valid and agrees with the input dimensions.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.

        Raises:
            ValueError: When block does not have a valid reshape_spec attribute
                that agrees with input dimensions.
        """
        super().initial_checks(from_blocks, block)

        # A plain flatten references every input dimension implicitly, so
        # there is nothing further to verify.
        if block.reshape_spec == 'flatten':
            return

        reshape_spec = norm_reshape_spec(block.reshape_spec)
        try:
            idxs = check_reshape_spec(reshape_spec)
        except Exception as ex:
            # Any validation failure is reported as a ValueError tied to the block id.
            raise ValueError(f'Invalid reshape_spec attribute in block[id={block._id}] :: {ex}') from ex

        shape_in = get_shape('out', from_blocks[0])
        if len(idxs) != len(shape_in):
            raise ValueError(f'Number of dimensions indexes in reshape_spec attribute of block[id={block._id}] does '
                             f'not agree with the input dimensions coming from block[id={from_blocks[0]._id}].')

    def propagate(self, from_blocks: List[Namespace], block: Namespace):
        """Method that propagates shapes to a block.

        The output shape is built element by element from the reshape_spec:

        - an int element selects the input dimension with that index,
        - a list element yields the product of the listed input dimensions,
        - a dict element (single key) expands the input dimension given by the
          key into the listed dimensions, where an ``auto_tag`` entry is
          replaced by the input dimension divided by the product of the
          remaining entries.

        A reshape_spec equal to ``'flatten'`` is handled as a single list with
        all the input dimension indexes. The result is stored in
        ``block._shape``.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.
        """
        shape_in = get_shape('out', from_blocks[0])
        shape_out = []

        if block.reshape_spec == 'flatten':
            reshape_spec = [list(range(len(shape_in)))]
        else:
            reshape_spec = norm_reshape_spec(block.reshape_spec)

        for val in reshape_spec:
            if isinstance(val, int):
                shape_out.append(shape_in[val])
            elif isinstance(val, list):
                shape_out.append(prod([shape_in[x] for x in val]))
            elif isinstance(val, dict):
                idx = next(iter(val.keys()))
                in_dim = shape_in[int(idx)]
                # Work on a copy: norm_reshape_spec passes plain dict elements
                # through by reference, so writing the resolved value into the
                # original list would overwrite auto_tag in block.reshape_spec
                # and make a later propagation reuse a stale dimension.
                dims = list(val[idx])
                if auto_tag in dims:
                    auto_idx = dims.index(auto_tag)
                    nonauto = prod([x for x in dims if x != auto_tag])
                    dims[auto_idx] = divide(in_dim, nonauto)
                shape_out.extend(dims)

        block._shape = create_shape(shape_in, shape_out)