"""Classes related to rendering of architectures."""

import os
import re
import itertools
import textwrap
from importlib.util import find_spec
from jsonargparse import ActionJsonSchema, Namespace, namespace_to_dict, Path
from jsonargparse.typing import NonNegativeInt
from typing import Union
from .propagators.base import get_shape
from import get_blocks_dict, add_ids_prefix
from .module import ModuleArchitecture
from .graph import parse_graph
from .schemas import id_separator
from .sympy import sympify_variable

# Import spec of the optional pygraphviz package, or None when it is not
# installed; within this file only its truthiness is checked before rendering.
pygraphviz_available = find_spec(name='pygraphviz')

class ModuleArchitectureRenderer(ModuleArchitecture):
    """Class for instantiating a ModuleArchitectureRenderer objects useful for creating module architecture diagrams."""
    # NOTE: the class docstring above is reused at runtime as the CLI parser
    # description (see get_config_parser), so editing it changes --help output.

    @staticmethod
    def get_config_parser():
        """Returns a ModuleArchitectureRenderer configuration parser.

        The parser extends the one from ModuleArchitecture with a group of
        rendering related options.
        """
        parser = ModuleArchitecture.get_config_parser()
        parser.description = ModuleArchitectureRenderer.__doc__

        # render options #
        group_render = parser.add_argument_group('Rendering related options')
        group_render.add_argument('--save_pdf',
            default=False,
            type=bool,
            help='Whether to write rendered pdf file to output directory.')
        group_render.add_argument('--save_gv',
            default=False,
            type=bool,
            help='Whether to write graphviz file to output directory.')
        # For JSON Schema objects the type of the values is constrained with
        # 'additionalProperties'; 'items' only applies to arrays and was ignored.
        group_render.add_argument('--block_attrs',
            default={'Default': 'shape=box',
                     'Input': 'shape=box, style=rounded, penwidth=1.5',
                     'Output': 'shape=box, style=rounded, peripheries=2',
                     'Nested': 'shape=box, style=dashed',
                     'Shared': 'style=filled',
                     'Reshape': 'shape=hexagon',
                     'Identity': 'shape=circle, width=0',
                     'Add': 'shape=circle, margin=0, width=0'},
            action=ActionJsonSchema(schema={'type': 'object', 'additionalProperties': {'type': 'string'}}),
            help='Attributes for block nodes.')
        group_render.add_argument('--block_labels',
            default={'Identity': '', 'Add': '+'},
            action=ActionJsonSchema(schema={'type': 'object', 'additionalProperties': {'type': 'string'}}),
            help='Fixed labels for block nodes.')
        group_render.add_argument('--edge_attrs',
            default='fontsize=10',
            help='Attributes for edges.')
        group_render.add_argument('--nested_depth',
            default=NonNegativeInt(3),
            type=NonNegativeInt,
            help='Maximum depth for nested subblocks to render. Set to 0 for unlimited.')
        group_render.add_argument('--full_ids',
            default=False,
            type=bool,
            help='Whether block IDs should include parent prefix.')
        group_render.add_argument('--layout_prog',
            choices=['dot', 'neato', 'twopi', 'circo', 'fdp'],
            default='dot',
            help='The graphviz layout method to use.')

        return parser

    @staticmethod
    def _parse_attrs(attrs):
        """Parses an attribute string like 'shape=box, style=rounded' into a dict.

        Items are separated by a comma followed by optional spaces. Empty items
        (e.g. from an empty string or a trailing comma) are skipped, and only the
        first '=' of an item separates the name from the value, so values may
        themselves contain '='. Surrounding whitespace is stripped.

        Args:
            attrs: The attribute string.

        Returns:
            dict: Mapping from attribute name to attribute value (both str).

        Raises:
            ValueError: If a non-empty item does not contain '='.
        """
        parsed = {}
        for item in re.split(', *', attrs):
            if not item.strip():
                continue
            name, sep, value = item.partition('=')
            if not sep:
                raise ValueError(f'Invalid attribute "{item}" in "{attrs}", expected name=value.')
            parsed[name.strip()] = value.strip()
        return parsed

    def apply_config(self, cfg: Union[str, dict, Namespace]):
        """Applies a configuration to the ModuleArchitectureRenderer instance.

        After delegating to ModuleArchitecture.apply_config, parses the
        block_attrs and edge_attrs attribute strings (when present as attributes
        of cfg) into dicts stored in self.block_attrs and self.edge_attrs.

        Args:
            cfg: Path to config file or config object.
        """
        super().apply_config(cfg)
        # NOTE(review): for a str path or a plain dict cfg, hasattr() is False,
        # so block/edge attrs are not updated here (original '@todo support also dict').
        if hasattr(cfg, 'block_attrs'):
            source = cfg.block_attrs
            # block_attrs may arrive as a dict or as a Namespace.
            items = source.items() if hasattr(source, 'items') else vars(source).items()
            block_attrs = {block: self._parse_attrs(attrs) for block, attrs in items}
            # A 'Default' entry is required by _set_block_attrs as fallback.
            if 'Default' not in block_attrs:
                block_attrs['Default'] = {'shape': 'box'}
            self.block_attrs = block_attrs
        if hasattr(cfg, 'edge_attrs'):
            self.edge_attrs = self._parse_attrs(cfg.edge_attrs)

    @staticmethod
    def _set_architecture_description(graph, architecture):
        """Sets the architecture description to a graph as a top-left label.

        The description is wrapped at 100 characters using HTML-like line breaks.
        Nothing is done if the architecture has no _description.
        """
        if hasattr(architecture, '_description'):
            description = '<BR />'.join(textwrap.wrap(architecture._description, width=100))
            graph.graph_attr['label'] = f'<{description}>'
            graph.graph_attr['labelloc'] = 't'
            graph.graph_attr['labeljust'] = 'l'

    @staticmethod
    def _set_node_description(graph, node, full_ids=False):
        """Sets a node's label to its id, plus its description if it has one.

        Args:
            graph: Graph containing a node with id node._id.
            node: Object with an _id and optionally a _description.
            full_ids: Whether to keep the parent prefix of the id in the label.
        """
        node_id = node._id if full_ids else node._id.split(id_separator)[-1]
        description = node_id
        if hasattr(node, '_description'):
            description = '<BR />'.join(textwrap.wrap(node._description, width=50))
            description = f'<{node_id}<FONT POINT-SIZE="6"><BR />{description}</FONT>>'
        graph.get_node(node._id).attr['label'] = description

    def _set_edge_label(self, graph, blocks, node_from, node_to, subblock=False):
        """Sets the output shape of the source block as the label of an edge.

        Also applies the configured edge attributes. Nothing is done if the
        source block has no _shape.

        Args:
            graph: Graph containing the edge node_from -> node_to.
            blocks: Dict mapping block ids to blocks.
            node_from: Id of the source node.
            node_to: Id of the target node.
            subblock: Unused; kept for call-site compatibility.
        """
        block_from = blocks[node_from]
        if not hasattr(block_from, '_shape'):
            return
        shape = ' × '.join(str(sympify_variable(d)) for d in get_shape('out', block_from))
        edge = graph.get_edge(node_from, node_to)
        edge.attr['label'] = ' '+shape
        for name, value in self.edge_attrs.items():
            edge.attr[name] = value

    @staticmethod
    def _get_fixed_label(block_labels, block_class):
        """Returns the fixed label configured for block_class, or None if none.

        block_labels may be a dict or a Namespace. Only actual entries are
        considered (hasattr() on a dict never matches entries and can match
        methods such as 'items').
        """
        if block_labels is None:
            return None
        mapping = block_labels if isinstance(block_labels, dict) else vars(block_labels)
        return mapping.get(block_class)

    @staticmethod
    def _format_prop_value(val):
        """Converts a block property value to str, turning Namespaces into dicts."""
        if isinstance(val, Namespace):
            val = namespace_to_dict(val)
        elif isinstance(val, list):
            val = [namespace_to_dict(v) if isinstance(v, Namespace) else v for v in val]
        return str(val)

    def _set_block_label(self, graph, block, graph_attr=False, full_ids=False):
        """Sets a block's label including its id and properties.

        If a fixed label is configured for the block's class it is used as is.
        Otherwise the label is the block name (or class) followed, in a small
        font, by its id, path, id_share and public properties.

        Args:
            graph: Graph (or subgraph) to set the label on.
            block: The block object.
            graph_attr: If True the label is set on the graph itself, otherwise
                on the node with id block._id.
            full_ids: Whether to keep the parent prefix of the id in the label.
        """
        label = self._get_fixed_label(getattr(self.cfg, 'block_labels', None), block._class)
        if label is None:
            exclude = {'output_feats', 'graph', 'input', 'output', 'architecture'}
            name = block._name if hasattr(block, '_name') else block._class
            props = ''
            if hasattr(block, '_id'):
                block_id = block._id if full_ids else block._id.split(id_separator)[-1]
                props += f'<BR />id: {block_id}'
            # A tuple (not a set) keeps the order deterministic across runs.
            for key in ('_path', '_id_share'):
                if hasattr(block, key):
                    props += f'<BR />{key[1:]}: {getattr(block, key)}'
            for k, v in vars(block).items():
                if k.startswith('_') or k in exclude:
                    continue
                if block._class in {'Sequential', 'Group'} and k == 'blocks':
                    # Only the number of subblocks, they are rendered as nodes.
                    props += f'<BR />{k}: {len(v)}'
                else:
                    props += f'<BR />{k}: {self._format_prop_value(v)}'
            label = f'<{name}<FONT POINT-SIZE="6">{props}</FONT>>' if props else name
        if graph_attr:
            graph.graph_attr['label'] = label
        else:
            graph.get_node(block._id).attr['label'] = label

    def _set_block_attrs(self, graph, blocks, block_class=None, graph_attr=False):
        """Sets graph style attributes to blocks.

        Uses the attributes of the block's class (or of block_class if given),
        falling back to 'Default'. Blocks with _id_share additionally get the
        'Shared' attributes, which take precedence.

        Args:
            graph: Graph (or subgraph) to set the attributes on.
            blocks: Iterable of blocks.
            block_class: Class name whose attributes to use instead of block._class.
            graph_attr: If True attributes are set on the graph itself,
                otherwise on the node with id block._id.
        """
        block_attrs = self.block_attrs
        for block in blocks:
            attrs_class = block._class if block_class is None else block_class
            attrs = dict(block_attrs[attrs_class] if attrs_class in block_attrs else block_attrs['Default'])
            if hasattr(block, '_id_share') and 'Shared' in block_attrs:
                attrs.update(block_attrs['Shared'])
            if not attrs:
                continue
            target = graph.graph_attr if graph_attr else graph.get_node(block._id).attr
            for name, value in attrs.items():
                target[name] = value

    def create_graph(self):
        """Creates a pygraphviz graph of the architecture using the current configuration.

        Returns:
            AGraph: The (not yet laid out) pygraphviz graph.

        Raises:
            ImportError: If pygraphviz is not installed.
        """
        architecture = self.architecture
        blocks = self.blocks
        full_ids = getattr(self.cfg, 'full_ids', False)

        ## Create raw graph ##
        if not pygraphviz_available:
            raise ImportError('pygraphviz package is required by create_graph method.')
        from pygraphviz import AGraph
        graph = AGraph('\n'.join(['digraph {']+architecture.graph+['}']))

        ## Add architecture description ##
        self._set_architecture_description(graph, architecture)

        ## Set attributes of blocks ##
        self._set_block_attrs(graph, architecture.inputs, block_class='Input')
        self._set_block_attrs(graph, architecture.outputs, block_class='Output')
        self._set_block_attrs(graph, architecture.blocks)

        ## Add input/output descriptions ##
        for node in itertools.chain(architecture.inputs, architecture.outputs):
            self._set_node_description(graph, node, full_ids=full_ids)

        ## Add tensor shapes to edges ##
        for node_from, node_to in graph.edges():
            self._set_edge_label(graph, blocks, node_from, node_to)

        ## Set block properties ##
        for block in architecture.blocks:
            self._set_block_label(graph, block, full_ids=full_ids)

        ## Create subgraphs ##
        self._add_subgraphs(graph, architecture.blocks, dict(blocks), depth=2)

        return graph

    def _add_subgraphs(self, graph, blocks, subblocks_dict, depth, parent_graph=None):
        """Adds subgraphs to a graph if the depth is not higher than the configured value.

        Each Sequential, Group or Module block node is replaced by a cluster
        subgraph containing its subblocks, recursing into nested blocks.

        Args:
            graph: The root graph, where nodes and edges are added.
            blocks: Blocks at the current nesting level.
            subblocks_dict: Dict of blocks by id, updated in place with subblocks.
            depth: Current nesting depth.
            parent_graph: Graph to which clusters are added (defaults to graph).

        Raises:
            ValueError: If a nested block has outgoing edges but no subblocks.
        """
        max_depth = self.cfg.nested_depth
        if max_depth != 0 and depth > max_depth:
            return
        if parent_graph is None:
            parent_graph = graph
        full_ids = self.cfg.full_ids

        for block in [b for b in blocks if b._class in {'Sequential', 'Group', 'Module'}]:
            cluster_id = block._id

            ## Remove edges and node ##
            edges = graph.edges(block._id)
            edges_from = [(u, v) for u, v in edges if v == block._id]
            edges_to = [(u, v) for u, v in edges if u == block._id]
            for edge in edges:
                graph.remove_edge(*edge)
            graph.remove_node(block._id)

            ## Create subgraph cluster ##
            subgraph = parent_graph.add_subgraph(name='cluster_'+block._id, labeljust='r', labelloc='t')
            self._set_block_label(subgraph, block, graph_attr=True, full_ids=full_ids)
            self._set_block_attrs(subgraph, [block], block_class='Nested', graph_attr=True)

            ## Handle Module ##
            if block._class == 'Module':
                module_arch = block.architecture
                subblocks_dict.update(get_blocks_dict(module_arch.inputs+module_arch.outputs))
                # Connect the incoming edge to the module's (first) input node.
                input_id = module_arch.inputs[0]._id
                subgraph.add_node(input_id)
                self._set_node_description(graph, subblocks_dict[input_id], full_ids=full_ids)
                self._set_block_attrs(graph, [subblocks_dict[input_id]], block_class='Input')
                graph.add_edge(edges_from[0][0], input_id)
                self._set_edge_label(graph, subblocks_dict, edges_from[0][0], input_id, subblock=True)
                # Connect the module's (first) output node to the outgoing edge.
                output_id = module_arch.outputs[0]._id
                subgraph.add_node(output_id)
                self._set_node_description(graph, subblocks_dict[output_id], full_ids=full_ids)
                self._set_block_attrs(graph, [subblocks_dict[output_id]], block_class='Output')
                graph.add_edge(output_id, edges_to[0][1])
                self._set_edge_label(graph, subblocks_dict, output_id, edges_to[0][1], subblock=True)
                # From here on the module's architecture is treated like a group
                # whose subblocks start from the module input; outgoing edges
                # are already connected from the output node.
                block = module_arch
                edges_from[0] = (input_id, edges_from[0][1])
                edges_to = []

            ## Add subblocks nodes and edges ##
            if not self.cfg.propagated:
                add_ids_prefix(block, [])
            subblocks_dict.update(get_blocks_dict(block.blocks))
            blocks_from = [subblocks_dict[edges_from[0][0]]]
            topological_predecessors = parse_graph(blocks_from, block)
            last_subblock_id = None
            for subblock_id, prev_ids in topological_predecessors.items():
                subgraph.add_node(subblock_id)
                subblock = subblocks_dict[subblock_id]
                if hasattr(subblock, '_class'):
                    self._set_block_label(subgraph, subblock, full_ids=full_ids)
                for node_id_prev in prev_ids:
                    graph.add_edge(node_id_prev, subblock_id)
                    self._set_edge_label(graph, subblocks_dict, node_id_prev, subblock_id, subblock=True)
                last_subblock_id = subblock_id
            self._set_block_attrs(graph, block.blocks)

            ## Add final edges ##
            # Outgoing edges now leave from the last subblock in the iteration
            # order of parse_graph (presumably the group's output node).
            if edges_to and last_subblock_id is None:
                raise ValueError(f'Unable to connect outgoing edges of {cluster_id}: it has no subblocks.')
            for _, node_to in edges_to:
                graph.add_edge(last_subblock_id, node_to)
                self._set_edge_label(graph, subblocks_dict, last_subblock_id, node_to, subblock=True)

            ## Add subgraphs ##
            self._add_subgraphs(graph, block.blocks, subblocks_dict, depth=depth+1, parent_graph=subgraph)

    def render(
        self,
        architecture: Union[str, Path] = None,
        out_render: Union[str, Path] = None,
        cfg: Namespace = None,
    ):
        """Renders the architecture diagram optionally writing to the given file path.

        Depending on the configuration, the unlaid-out graph is written as a .gv
        file and/or the laid-out graph as a .pdf file in the output directory,
        both named after the architecture id.

        Args:
            architecture: Path to a jsonnet architecture file.
            out_render: Path where to write the rendered diagram with a valid
                extension for pygraphviz to determine the type.
            cfg: Configuration to apply before rendering.

        Returns:
            AGraph: pygraphviz graph object.
        """
        if cfg is not None:
            self.apply_config(cfg)
        if architecture is not None:
            self.load_architecture(architecture)

        graph = self.create_graph()

        def out_path(extension):
            # Resolved lazily so outdir is only needed when something is saved there.
            outdir = self.cfg.outdir if isinstance(self.cfg.outdir, str) else self.cfg.outdir()
            return os.path.join(outdir, self.architecture._id + extension)

        if self.cfg.save_gv:
            out_gv = out_path('.gv')
            self._check_overwrite(out_gv)
            graph.write(out_gv)

        graph.layout(prog=self.cfg.layout_prog)

        if self.cfg.save_pdf:
            out_pdf = out_path('.pdf')
            self._check_overwrite(out_pdf)
            graph.draw(out_pdf)

        if out_render is not None:
            if not isinstance(out_render, str):
                out_render = out_render()
            self._check_overwrite(out_render)
            graph.draw(out_render)

        return graph