Source code for narchi.module

"""Classes related to neural network module architectures."""

import os
import json
from copy import deepcopy
from jsonargparse import (
    ArgumentParser,
    Namespace,
    Path,
    get_config_read_mode,
    namespace_to_dict,
    dict_to_namespace,
    ActionConfigFile,
    ActionJsonnet,
    ActionJsonnetExtVars,
)
from jsonargparse.typing import Path_dw
from typing import List, Optional, Union
from .schemas import auto_tag, narchi_validator, propagated_validator
from .graph import parse_graph
from .sympy import sympify_variable
from .propagators.base import BasePropagator, get_shape, create_shape, shapes_agree
from .propagators.group import get_blocks_dict, propagate_shapes, add_ids_prefix
from .instantiators.common import import_object
from . import __version__


class ModuleArchitecture:
    """Class for instantiating ModuleArchitecture objects."""

    path = None
    jsonnet = None
    architecture = None
    propagators = 'default'
    blocks = None
    topological_predecessors = None
    @staticmethod
    def get_config_parser():
        """Returns a ModuleArchitecture configuration parser."""
        parser = ArgumentParser(
            error_handler=None,
            description=ModuleArchitecture.__doc__,
            version=__version__)
        parser.add_argument('--cfg',
            action=ActionConfigFile,
            help='Path to a configuration file.')

        # loading options #
        group_load = parser.add_argument_group('Loading related options')
        group_load.add_argument('--validate',
            default=True,
            type=bool,
            help='Whether to validate architecture against narchi schema.')
        group_load.add_argument('--propagate',
            default=True,
            type=bool,
            help='Whether to propagate shapes in architecture.')
        group_load.add_argument('--propagated',
            default=False,
            type=bool,
            help='Whether architecture has already been propagated.')
        group_load.add_argument('--propagators',
            help='Overrides default propagators.')
        group_load.add_argument('--ext_vars',
            action=ActionJsonnetExtVars(),
            help='External variables required to load jsonnet.')
        group_load.add_argument('--cwd',
            help='Current working directory to load inner referenced files. Default None uses '
                 'directory of main architecture file.')
        group_load.add_argument('--parent_id',
            default='',
            help='Identifier of parent module.')

        # output options #
        group_out = parser.add_argument_group('Output related options')
        group_out.add_argument('--overwrite',
            default=False,
            type=bool,
            help='Whether to overwrite existing files.')
        group_out.add_argument('--outdir',
            default='.',
            type=Path_dw,
            help='Directory where to write output files.')
        group_out.add_argument('--save_json',
            default=False,
            type=bool,
            help='Whether to write the architecture (up to the last successful step: jsonnet load, '
                 'schema validation, parsing) in json format to the output directory.')

        return parser
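
    # Illustrative note (not part of the original source): with no overrides, the parser
    # above yields roughly these defaults, which apply_config() stores in self.cfg
    # (exact Namespace contents depend on the jsonargparse version in use):
    #   validate=True, propagate=True, propagated=False, propagators=None,
    #   cwd=None, parent_id='', overwrite=False, outdir='.', save_json=False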

    def __init__(
        self,
        architecture: Union[str, Path] = None,
        cfg: Union[str, dict, Namespace] = None,
        parser: ArgumentParser = None,
    ):
        """Initializer for ModuleArchitecture class.

        Args:
            architecture: Path to a jsonnet architecture file.
            cfg: Path to config file or config object.
            parser: Parser object in case it is an extension of get_config_parser().
        """
        if parser is None:
            parser = self.get_config_parser()
        self.parser = parser
        self.apply_config(cfg)
        if architecture is not None:
            self.load_architecture(architecture)

    def apply_config(self, cfg: Union[str, dict, Namespace]):
        """Applies a configuration to the ModuleArchitecture instance.

        Args:
            cfg: Path to config file or config object.
        """
        if cfg is None:
            self.cfg = self.parser.get_defaults()
        elif isinstance(cfg, (str, Path)):
            self.cfg_file = cfg
            self.cfg = self.parser.parse_path(cfg)
        elif isinstance(cfg, Namespace):
            self.parser.check_config(cfg)
            self.cfg = cfg
        elif isinstance(cfg, dict):
            cfg = dict(cfg)
            if 'propagators' in cfg and isinstance(cfg['propagators'], dict):
                self.propagators = cfg.pop('propagators')
            if not hasattr(self, 'cfg'):
                self.cfg = self.parser.parse_object(cfg)
            else:
                self.cfg = self.parser.parse_object(cfg, cfg_base=self.cfg, defaults=False)
        else:
            raise ValueError(f'Unexpected configuration object: {cfg}')

        if self.propagators == 'default':
            self.propagators = import_object('narchi.blocks.propagators')
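
    # Illustrative note (not part of the original source): cfg may also be given as a
    # plain dict; a 'propagators' entry that is itself a dict is popped out and stored
    # on self.propagators before the rest is parsed. Hypothetical example:
    #   arch = ModuleArchitecture('model.jsonnet',
    #                             cfg={'outdir': 'generated', 'save_json': True,
    #                                  'propagators': custom_propagators_dict})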

    def load_architecture(self, architecture: Optional[Union[str, Path]]):
        """Loads an architecture file.

        Args:
            architecture: Path to a jsonnet architecture file.
        """
        self.path = None
        self.jsonnet = None
        self.architecture = None
        self.blocks = None
        self.topological_predecessors = None

        ## Initialize with given ModuleArchitecture ##
        if isinstance(architecture, ModuleArchitecture):
            self.path = architecture.path
            self.jsonnet = architecture.jsonnet
            self.blocks = architecture.blocks
            self.topological_predecessors = architecture.topological_predecessors
            self.cfg.propagated = architecture.cfg.propagated
            architecture = architecture.architecture

        ## Load jsonnet file or snippet ##
        if isinstance(architecture, (str, Path)):
            self.path = Path(architecture, mode=get_config_read_mode(), cwd=self.cfg.cwd)
            self.cfg.cwd = os.path.dirname(self.path())
            self.jsonnet = self.path.get_content()
            architecture = ActionJsonnet(schema=None).parse(self.path, ext_vars=self.cfg.ext_vars)
            if not isinstance(architecture, Namespace):
                architecture = dict_to_namespace(architecture)
            if not hasattr(architecture, '_id'):
                architecture._id = os.path.splitext(os.path.basename(self.path()))[0]
        if not isinstance(architecture, Namespace):
            raise ValueError(f'{type(self).__name__} expected architecture to be either a path or a namespace.')
        self.architecture = architecture

        ## Validate prior to propagation ##
        self.validate()

        ## Check inputs and outputs independent of blocks ##
        isect_ids = set(b._id for b in architecture.blocks).intersection(b._id for b in architecture.inputs+architecture.outputs)
        if isect_ids:
            raise ValueError(f'{type(self).__name__} inputs/outputs not allowed to be blocks {isect_ids}.')

        ## Create dictionary of blocks ##
        if not self.blocks and all(hasattr(architecture, x) for x in ['inputs', 'outputs', 'blocks']):
            if self.cfg.parent_id:
                architecture._id = self.cfg.parent_id
                add_ids_prefix(architecture, architecture.inputs+architecture.outputs, skip_io=False)
            self.blocks = get_blocks_dict(architecture.inputs + architecture.blocks)

        ## Propagate shapes ##
        if self.cfg.propagate:
            if not self.cfg.propagated:
                self.propagate()
            elif self.topological_predecessors is None:
                self.topological_predecessors = parse_graph(architecture.inputs, architecture)
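
    # Illustrative note (not part of the original source): the parsed jsonnet is expected
    # to provide at least the fields referenced in this file; a minimal sketch, with the
    # full block schema omitted (see the narchi json schemas for the complete format):
    #   {"_id": "my_module",
    #    "inputs":  [{"_id": "input",  "_shape": ["<<variable:L>>", 16]}],
    #    "blocks":  [{"_id": "fc", ...}],
    #    "outputs": [{"_id": "output", "_shape": [...]}]}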

    def validate(self):
        """Validates the architecture against the narchi or propagated schema."""
        if not self.cfg.validate:
            return
        try:
            if self.cfg.propagated:
                propagated_validator.validate(namespace_to_dict(self.architecture))
            else:
                narchi_validator.validate(namespace_to_dict(self.architecture))
        except Exception as ex:
            self.write_json_outdir()
            source = 'Propagated' if self.cfg.propagated else 'Pre-propagated'
            raise type(ex)(f'{source} architecture failed to validate against schema :: {ex}') from ex

    def propagate(self):
        """Propagates the shapes of the neural network module architecture."""
        if self.cfg.propagated:
            raise RuntimeError(f'Not possible to propagate an already propagated {type(self).__name__}.')

        architecture = self.architecture

        ## Parse graph getting node mapping in topological order ##
        topological_predecessors = parse_graph(architecture.inputs, architecture)
        output_ids = {b._id for b in architecture.outputs}
        if next(reversed(topological_predecessors)) not in output_ids:
            raise ValueError(f'In module[id={architecture._id}] expected one of output nodes {output_ids} to be the last in the graph.')

        ## Propagate shapes for the architecture blocks ##
        try:
            propagate_shapes(self.blocks,
                             topological_predecessors,
                             propagators=self.propagators,
                             ext_vars=self.cfg.ext_vars,
                             cwd=self.cfg.cwd,
                             skip_ids=output_ids)
        except Exception as ex:
            self.write_json_outdir()
            raise ex

        for output_block in architecture.outputs:
            ## Get pre-output blocks ##
            pre_output_block_id = next(v[0] for k, v in topological_predecessors.items() if k == output_block._id)
            try:
                pre_output_block = next(b for b in architecture.blocks if b._id == pre_output_block_id)
            except StopIteration as ex:
                block_ids = {b._id for b in architecture.blocks}
                raise ValueError(f'In module[id={architecture._id}] pre-output block[id={pre_output_block_id}] not found among ids={block_ids}.') from ex

            ## Automatic output dimensions ##
            for dim, val in enumerate(output_block._shape):
                if val == auto_tag:
                    output_block._shape[dim] = get_shape('out', pre_output_block)[dim]

            ## Check that output shape agrees ##
            if not shapes_agree(pre_output_block, output_block):
                self.write_json_outdir()
                raise ValueError(f'In module[id={architecture._id}] pre-output block[id={pre_output_block._id}] and output '
                                 f'shape do not agree: {pre_output_block._shape.out} vs. {output_block._shape}.')

        ## Update properties ##
        self.topological_predecessors = topological_predecessors
        self.cfg.propagated = True

        ## Set propagated shape ##
        in_shape = architecture.inputs[0]._shape
        out_shape = architecture.outputs[0]._shape
        architecture._shape = create_shape(in_shape, out_shape)

        ## Validate result ##
        self.validate()

        ## Write json file if requested ##
        self.write_json_outdir()

    def write_json(self, json_path):
        """Writes the current state of the architecture in json format to the given path."""
        with open(json_path if isinstance(json_path, str) else json_path(), 'w') as f:
            architecture = namespace_to_dict(self.architecture)
            f.write(json.dumps(architecture,
                               indent=2,
                               sort_keys=True,
                               ensure_ascii=False))

    def _check_overwrite(self, path):
        """Raises IOError if overwrite not set and path already exists."""
        if not self.cfg.overwrite and os.path.isfile(path):
            raise IOError(f'Refusing to overwrite existing file: {path}')

    def write_json_outdir(self):
        """Writes the current state of the architecture in json format to the configured output directory."""
        if not self.cfg.save_json or self.cfg.outdir is None or not hasattr(self, 'architecture'):
            return
        outdir = self.cfg.outdir if isinstance(self.cfg.outdir, str) else self.cfg.outdir()
        out_path = os.path.join(outdir, f'{self.architecture._id}.json')
        self._check_overwrite(out_path)
        self.write_json(out_path)
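

# Illustrative usage sketch (not part of the original source); the file name and the
# external variable are hypothetical:
#
#     arch = ModuleArchitecture('fasttext.jsonnet',
#                               cfg={'ext_vars': {'num_classes': 5}, 'propagate': True})
#     print(arch.architecture._shape)   # propagated input/output shape of the module
#     arch.write_json('fasttext.json')  # dump the propagated architecture to a file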


class ModulePropagator(BasePropagator):
    """Propagator for complete modules."""

    num_input_blocks = 1

    def propagate(
        self,
        from_blocks: List[Namespace],
        block: Namespace,
        propagators: dict = None,
        ext_vars: Namespace = {},
        cwd: str = None,
    ):
        """Method that propagates shapes through a module.

        Args:
            from_blocks: The input blocks.
            block: The block to propagate its shapes.
            propagators: Dictionary of propagators.
            ext_vars: External variables required to load jsonnet.
            cwd: Working directory to resolve relative paths.

        Raises:
            ValueError: If no propagator found for some block.
        """
        block_ext_vars = deepcopy(ext_vars)
        if ext_vars is None:
            block_ext_vars = Namespace()
        elif isinstance(ext_vars, dict):
            block_ext_vars = Namespace(**block_ext_vars)
        if hasattr(block, '_ext_vars'):
            vars(block_ext_vars).update(vars(block._ext_vars))
        cfg = {'ext_vars': block_ext_vars,
               'cwd': cwd,
               'parent_id': block._id,
               'propagate': False,
               'propagators': propagators}
        module = ModuleArchitecture(block._path, cfg=cfg)
        self.connect_input(from_blocks, block, module)
        module.propagate()
        block._shape = module.architecture._shape
        delattr(module.architecture, '_shape')
        block.architecture = module.architecture

    @staticmethod
    def connect_input(from_blocks, block, module):
        """Checks fixed dimensions agree and replaces the module's variable dimensions."""
        from_shape = get_shape('out', from_blocks[0])
        to_shape = module.architecture.inputs[0]._shape
        assert len(from_shape) == len(to_shape)
        for dim in range(len(to_shape)):
            from_dim = sympify_variable(from_shape[dim])
            to_dim = sympify_variable(to_shape[dim])
            if len(to_dim.free_symbols) > 0:
                to_shape[dim] = f'<<variable:{from_dim}>>'
            elif from_dim != to_dim:
                raise ValueError(f'Shape dim {dim} does not agree for block[id={from_blocks[0]._id}] connecting '
                                 f'to block[id={block._id}].')
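
    # Illustrative note (not part of the original source): assuming sympify_variable maps
    # '<<variable:L>>' to the free symbol L, connecting a previous block with output shape
    # ['<<variable:L>>', 48] to a sub-module whose first input has shape ['<<variable:N>>', 48]
    # rewrites the sub-module input shape to ['<<variable:L>>', 48]; a fixed-dimension
    # mismatch such as 48 vs. 64 would instead raise the ValueError above.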