import os
import os.path as opath
import textwrap
from collections import ChainMap
from importlib import import_module
from io import StringIO
from typing import List

from yapf.yapflib.yapf_api import FormatCode


# Source code utilities
# =====================
def format_source(input_source):
    """
    Use yapf to format a string containing Python source code

    Parameters
    ----------
    input_source : str
        String containing Python source code

    Returns
    -------
    str
        String containing yapf-formatted python source code
    """
    style_config = {
        'based_on_style': 'google',
        'DEDENT_CLOSING_BRACKETS': True,
        'COLUMN_LIMIT': 79,
    }
    # FormatCode returns a (formatted_source, changed) pair; only the
    # formatted source is of interest here
    formatted_source, _ = FormatCode(input_source, style_config=style_config)
    return formatted_source


def format_and_write_source_py(py_source, filepath):
    """
    Format Python source code and write to a file, creating parent
    directories as needed.

    Parameters
    ----------
    py_source : str
        String containing valid Python source code. If string is empty,
        no file will be written.
    filepath : str
        Full path to the file to be written

    Returns
    -------
    None

    Raises
    ------
    Exception
        Any exception raised while formatting is re-raised unchanged
        after the offending source has been printed.
    """
    if not py_source:
        # Nothing to write
        return

    try:
        formatted_source = format_source(py_source)
    except Exception:
        # Dump the unformattable source to make the failure debuggable,
        # then re-raise with the original traceback intact
        print(py_source)
        raise

    # Make dir if needed
    # ------------------
    # A bare file name has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    filedir = opath.dirname(filepath)
    if filedir:
        os.makedirs(filedir, exist_ok=True)

    # Write file
    # ----------
    # Python source files are UTF-8 by default, so write UTF-8 explicitly
    # rather than relying on the platform's locale encoding.
    with open(filepath, 'wt', encoding='utf-8') as f:
        f.write(formatted_source)


def build_from_imports_py(imports_info):
    """
    Build a string containing a series of `from X import Y` lines

    Parameters
    ----------
    imports_info : str or list of (str, str or list of str)
        List of import info

        If element is a pair (tuple), the first entry is the package to be
        imported from and the second entry is either a string of the
        single name to be imported, or a list of names to be imported.

        If element is a string, the string is inserted directly (no
        newline is appended).

        Elements of any other type are ignored.

    Returns
    -------
    str
        String containing a series of imports
    """
    buffer = StringIO()
    for import_info in imports_info:
        if isinstance(import_info, tuple):
            from_pkg, class_name = import_info
            if isinstance(class_name, str):
                class_name_str = class_name
            else:
                # Several names: emit a parenthesized import list
                class_name_str = '(' + ', '.join(class_name) + ')'

            buffer.write(f"from {from_pkg} import {class_name_str}\n")
        elif isinstance(import_info, str):
            buffer.write(import_info)

    return buffer.getvalue()


def write_init_py(pkg_root, path_parts, import_pairs):
    """
    Build __init__.py source code and write to a file

    Parameters
    ----------
    pkg_root : str
        Root package directory. With empty path_parts, the __init__.py
        file is written directly in this directory.
    path_parts : tuple of str
        Tuple of sub-packages under pkg_root where the __init__.py
        file should be written
    import_pairs : list of (str, str or list of str)
        List of pairs where first entry is the package to be imported from.
        The second entry is either a string of the single name to be
        imported, or a list of names to be imported.

    Returns
    -------
    None
    """
    # Generate source code
    # --------------------
    init_source = build_from_imports_py(import_pairs)

    # Write file
    # ----------
    # Note: nothing is written when init_source is empty
    filepath = opath.join(pkg_root, *path_parts, '__init__.py')
    format_and_write_source_py(init_source, filepath)


# Constants
# =========
# Mapping from full property paths to custom validator classes
CUSTOM_VALIDATOR_DATATYPES = {
    'layout.image.source': '_plotly_utils.basevalidators.ImageUriValidator',
    'frame.data': 'plotly.validators.DataValidator',
    'frame.layout': 'plotly.validators.LayoutValidator'
}

# Add custom dash validators
CUSTOM_VALIDATOR_DATATYPES.update(
    {prop: '_plotly_utils.basevalidators.DashValidator'
     for prop in [
         'scatter.line.dash',
         'histogram2dcontour.line.dash',
         'scattergeo.line.dash',
         'scatterpolar.line.dash',
         'ohlc.line.dash',
         'ohlc.decreasing.line.dash',
         'ohlc.increasing.line.dash',
         'contourcarpet.line.dash',
         'contour.line.dash',
         'scatterternary.line.dash',
         'scattercarpet.line.dash']})

# Mapping from property string (as found in plot-schema.json) to a custom
# class name. If not included here, names are converted to TitleCase and
# underscores are removed.
OBJECT_NAME_TO_CLASS_NAME = {
    'angularaxis': 'AngularAxis',
    'colorbar': 'ColorBar',
    'error_x': 'ErrorX',
    'error_y': 'ErrorY',
    'error_z': 'ErrorZ',
    'histogram2d': 'Histogram2d',
    'histogram2dcontour': 'Histogram2dContour',
    'mesh3d': 'Mesh3d',
    'radialaxis': 'RadialAxis',
    'scatter3d': 'Scatter3d',
    'xaxis': 'XAxis',
    'xbins': 'XBins',
    'yaxis': 'YAxis',
    'ybins': 'YBins',
    'zaxis': 'ZAxis'
}

# Tuple of types to be considered dicts by PlotlyNode logic
dict_like = (dict, ChainMap)


# PlotlyNode classes
# ==================
class PlotlyNode:
    """
    Base class that represents a node in the plot-schema.json file
    """

    # Constructor
    # -----------
    def __init__(self, plotly_schema, node_path=(), parent=None):
        """
        Superclass constructor for all node types

        Parameters
        ----------
        plotly_schema : dict
            JSON-parsed version of the plot-schema.json file
        node_path : str or tuple
            Path from the 'root' node for the current trace type to the
            particular node that this instance represents
        parent : PlotlyNode
            Reference to the node's parent
        """
        # Save params
        # -----------
        self.plotly_schema = plotly_schema
        self._parent = parent

        # ### Process node path ###
        if isinstance(node_path, str):
            node_path = (node_path,)
        self.node_path = node_path

        # Compute children
        # ----------------
        # Note the node_data is a property that must be computed by the
        # subclass based on plotly_schema and node_path
        if isinstance(self.node_data, dict_like):
            # Keys that are empty or start with an underscore are skipped
            self._children = [
                self.__class__(self.plotly_schema,
                               node_path=self.node_path + (c,),
                               parent=self)
                for c in self.node_data if c and c[0] != '_']

            # Sort by plotly name
            self._children = sorted(self._children,
                                    key=lambda node: node.plotly_name)
        else:
            self._children = []

    # Magic methods
    # -------------
    def __repr__(self):
        return self.path_str

    # Abstract methods
    # ----------------
    @property
    def node_data(self):
        """
        Dictionary of the subtree of the plotly_schema that this node
        represents

        Returns
        -------
        dict
        """
        raise NotImplementedError()

    @property
    def description(self):
        """
        Description of the node

        Returns
        -------
        str or None
        """
        raise NotImplementedError()

    @property
    def name_base_datatype(self):
        """
        Superclass to use when generating a datatype class for this node

        Returns
        -------
        str
        """
        raise NotImplementedError()

    # Names
    # -----
    @property
    def root_name(self):
        """
        Name of the node with empty node_path

        Returns
        -------
        str
        """
        raise NotImplementedError()

    @property
    def plotly_name(self):
        """
        Name of the node. Either the root_name or the name directly out of
        the plotly_schema

        Returns
        -------
        str
        """
        if len(self.node_path) == 0:
            return self.root_name
        else:
            return self.node_path[-1]

    @property
    def name_datatype_class(self):
        """
        Name of the Python datatype class representing this node

        Names listed in OBJECT_NAME_TO_CLASS_NAME use the custom class name;
        all others are title-cased with underscores removed.

        Returns
        -------
        str
        """
        if self.plotly_name in OBJECT_NAME_TO_CLASS_NAME:
            return OBJECT_NAME_TO_CLASS_NAME[self.plotly_name]
        else:
            return self.plotly_name.title().replace('_', '')

    @property
    def name_undercase(self):
        """
        Name of node converted to undercase (all lowercase with
        underscores separating words)

        Returns
        -------
        str
        """
        if not self.plotly_name:
            # Empty plotly_name
            return self.plotly_name

        # Lowercase leading char
        # ----------------------
        name1 = self.plotly_name[0].lower() + self.plotly_name[1:]

        # Replace capital chars by underscore-lower
        # -----------------------------------------
        name2 = ''.join([('' if not c.isupper() else '_') + c.lower()
                         for c in name1])

        return name2

    @property
    def name_property(self):
        """
        Name of the Python property corresponding to this node.

        This is the node's `plotly_name`, with an 's' appended for
        array-element nodes.

        Returns
        -------
        str
        """
        return self.plotly_name + ('s' if self.is_array_element else '')

    @property
    def name_validator_class(self) -> str:
        """
        Name of the Python validator class representing this node

        Returns
        -------
        str
        """
        return (self.name_datatype_class +
                ('s' if self.is_array_element else '') +
                'Validator')

    @property
    def name_base_validator(self) -> str:
        """
        Superclass to use when generating a validator class for this node

        Returns
        -------
        str
        """
        if self.path_str in CUSTOM_VALIDATOR_DATATYPES:
            validator_base = f"{CUSTOM_VALIDATOR_DATATYPES[self.path_str]}"
        elif self.plotly_name.endswith('src') and self.datatype == 'string':
            validator_base = "_plotly_utils.basevalidators.SrcValidator"
        else:
            datatype_title_case = self.datatype.title().replace('_', '')
            validator_base = (f"_plotly_utils.basevalidators."
                              f"{datatype_title_case}Validator")

        return validator_base

    # Validators
    # ----------
    def get_validator_params(self):
        """
        Get kwargs to pass to the constructor of this node's validator
        superclass.

        Returns
        -------
        dict
            The keys are strings matching the names of the constructor
            params of this node's validator superclass. The values are
            repr-strings of the values to be passed to the constructor.
            These values are ready to be used to code generate calls to the
            constructor. The values should be evald before being passed to
            the constructor directly.
        """
        params = {'plotly_name': repr(self.name_property),
                  'parent_name': repr(self.parent_path_str)}

        if self.is_compound:
            params['data_class_str'] = repr(self.name_datatype_class)
            # Triple-quoted so the generated code holds a multi-line string
            params['data_docs'] = (
                '\"\"\"' +
                self.get_constructor_params_docstring() +
                '\"\"\"')
        else:
            assert self.is_simple

            # Exclude general properties
            excluded_props = ['valType', 'description', 'dflt']
            if self.datatype == 'subplotid':
                # Default is required for subplotid validator
                excluded_props.remove('dflt')

            attr_nodes = [n for n in self.simple_attrs
                          if n.plotly_name not in excluded_props]

            for node in attr_nodes:
                params[node.name_undercase] = repr(node.node_data)

            # Add extra properties
            if self.datatype == 'color' and self.parent:
                # Check for colorscale sibling. We use the presence of a
                # colorscale sibling to determine whether numeric color
                # values are permissible
                colorscale_node_list = [
                    node for node in self.parent.child_datatypes
                    if node.datatype == 'colorscale']
                if colorscale_node_list:
                    colorscale_path = colorscale_node_list[0].path_str
                    params['colorscale_path'] = repr(colorscale_path)
            elif self.datatype == 'literal':
                # NOTE(review): unlike every other entry, this value is
                # stored raw rather than as a repr-string -- confirm
                # whether that is intended before relying on it.
                params['val'] = self.node_data

        return params

    def get_validator_instance(self):
        """
        Return a constructed validator for this node

        Returns
        -------
        BaseValidator or None
            None when the base validator does not live in the
            `_plotly_utils` package.
        """
        # Evaluate validator params to convert repr strings into values
        # e.g. '2' -> 2
        # NOTE: eval is applied to strings derived from the schema; this
        # must only ever be used with a trusted plotly_schema.
        params = {prop: eval(repr_val)
                  for prop, repr_val in self.get_validator_params().items()}

        validator_parts = self.name_base_validator.split('.')
        if validator_parts[0] != '_plotly_utils':
            return None

        validator_class_str = validator_parts[-1]
        validator_module = '.'.join(validator_parts[:-1])

        validator_class = getattr(import_module(validator_module),
                                  validator_class_str)

        return validator_class(**params)

    # Datatypes
    # ---------
    @property
    def datatype(self) -> str:
        """
        Datatype string for this node. One of 'compound_array', 'compound',
        'literal', or the value of the 'valType' attribute

        Returns
        -------
        str
        """
        if self.is_array_element:
            return 'compound_array'
        elif self.is_compound:
            return 'compound'
        elif self.is_simple:
            return self.node_data.get('valType')
        else:
            return 'literal'

    @property
    def is_array_ok(self) -> bool:
        """
        Return true if arrays of datatype are acceptable

        Nodes whose node_data is not dict-like (e.g. literals) have no
        'arrayOk' attribute and are reported as False.

        Returns
        -------
        bool
        """
        node_data = self.node_data
        if not isinstance(node_data, dict_like):
            return False
        return node_data.get('arrayOk', False)

    @property
    def is_compound(self) -> bool:
        """
        Node has a compound (in contrast to simple) datatype.
        Note: All array and array_element types are also considered compound

        Returns
        -------
        bool
        """
        return (isinstance(self.node_data, dict_like) and
                not self.is_simple and
                self.plotly_name not in ('items',
                                         'impliedEdits',
                                         'transforms'))

    @property
    def is_literal(self) -> bool:
        """
        Node has a particular literal value (e.g. 'foo', or 23)

        Returns
        -------
        bool
        """
        return isinstance(self.node_data, (str, int, float))

    @property
    def is_simple(self) -> bool:
        """
        Node has a simple datatype (e.g. boolean, color, colorscale, etc.)

        Returns
        -------
        bool
        """
        return (isinstance(self.node_data, dict_like) and
                'valType' in self.node_data and
                self.plotly_name != 'items')

    @property
    def is_array(self) -> bool:
        """
        Node has an array datatype

        Returns
        -------
        bool
        """
        return (isinstance(self.node_data, dict_like) and
                self.node_data.get('role', '') == 'object' and
                'items' in self.node_data and
                self.name_property != 'transforms')

    @property
    def is_array_element(self):
        """
        Node has an array-element datatype, i.e. its grandparent is an
        array node

        Returns
        -------
        bool
        """
        if self.parent and self.parent.parent:
            return self.parent.parent.is_array
        else:
            return False

    @property
    def is_datatype(self) -> bool:
        """
        Node represents any kind of datatype

        Returns
        -------
        bool
        """
        return self.is_simple or self.is_compound or self.is_array

    # Node path
    # ---------
    def tidy_path_part(self, p):
        """
        Return a tidy version of raw path entry. This allows subclasses to
        adjust the raw property names in the plotly_schema

        Parameters
        ----------
        p : str
            Path element string

        Returns
        -------
        str
        """
        return p

    @property
    def path_parts(self):
        """
        Tuple of strings locating this node in the plotly_schema
        e.g. ('layout', 'images', 'opacity')

        Returns
        -------
        tuple of str
        """
        res = [self.root_name] if self.root_name else []
        last_index = len(self.node_path) - 1
        for i, p in enumerate(self.node_path):
            # Handle array datatypes: drop both the 'items' entry and the
            # entry immediately preceding it
            # e.g. [parcoords, dimensions, items, dimension] ->
            #      [parcoords, dimension]
            precedes_items = (i < last_index and
                              self.node_path[i + 1] == 'items')
            if p == 'items' or precedes_items:
                continue
            res.append(self.tidy_path_part(p))
        return tuple(res)

    # Node path strings
    # -----------------
    @property
    def path_str(self):
        """
        String containing path_parts joined on periods
        e.g. 'layout.images.opacity'

        Returns
        -------
        str
        """
        return '.'.join(self.path_parts)

    @property
    def dotpath_str(self):
        """
        path_str prefixed by a period if path_str is not empty, otherwise
        empty

        Returns
        -------
        str
        """
        return ''.join('.' + p for p in self.path_parts)

    @property
    def parent_path_parts(self):
        """
        Tuple of strings locating this node's parent in the plotly_schema

        Returns
        -------
        tuple of str
        """
        return self.path_parts[:-1]

    @property
    def parent_path_str(self):
        """
        String containing parent_path_parts joined on periods

        Returns
        -------
        str
        """
        return '.'.join(self.path_parts[:-1])

    @property
    def parent_dotpath_str(self):
        """
        parent_path_str prefixed by a period if parent_path_str is not
        empty, otherwise empty

        Returns
        -------
        str
        """
        return ''.join('.' + p for p in self.parent_path_parts)

    # Children
    # --------
    @property
    def parent(self):
        """
        Parent node

        Returns
        -------
        PlotlyNode
        """
        return self._parent

    @property
    def children(self):
        """
        List of all child nodes

        Returns
        -------
        list of PlotlyNode
        """
        return self._children

    @property
    def simple_attrs(self):
        """
        List of simple attribute child nodes
        (only valid when is_simple == True)

        Returns
        -------
        list of PlotlyNode

        Raises
        ------
        ValueError
            If this node is not a simple node
        """
        if not self.is_simple:
            raise ValueError(
                f"Cannot get simple attributes of the non-simple object "
                f"'{self.path_str}'")

        return [n for n in self.children
                if n.plotly_name not in ['valType', 'description']]

    @property
    def child_datatypes(self):
        """
        List of all datatype child nodes

        For array children, the array's element node is returned in place
        of the array node itself.

        Returns
        -------
        list of PlotlyNode
        """
        nodes = []
        for n in self.children:
            if n.is_array:
                # Locate the 'items' child by its raw schema key rather
                # than by position, so other keys that sort before 'items'
                # (e.g. 'description') cannot be picked by mistake.
                items_node = next(c for c in n.children
                                  if c.node_path[-1] == 'items')
                nodes.append(items_node.children[0])
            elif n.is_datatype:
                nodes.append(n)

        return nodes

    @property
    def child_compound_datatypes(self):
        """
        List of all compound datatype child nodes

        Returns
        -------
        list of PlotlyNode
        """
        return [n for n in self.child_datatypes if n.is_compound]

    @property
    def child_simple_datatypes(self) -> List['PlotlyNode']:
        """
        List of all simple datatype child nodes

        Returns
        -------
        list of PlotlyNode
        """
        return [n for n in self.child_datatypes if n.is_simple]

    @property
    def child_literals(self) -> List['PlotlyNode']:
        """
        List of all literal child nodes

        Returns
        -------
        list of PlotlyNode
        """
        return [n for n in self.children if n.is_literal]

    def get_constructor_params_docstring(self, indent=12):
        """
        Return a docstring-style string containing the names and
        descriptions of all of the node's child datatypes

        Parameters
        ----------
        indent : int
            Leading indent of the string

        Returns
        -------
        str
        """
        assert self.is_compound

        buffer = StringIO()
        description_indent = ' ' * (indent + 4)

        for subtype_node in self.child_datatypes:
            raw_description = subtype_node.description
            if raw_description:
                subtype_description = raw_description
            elif subtype_node.is_compound:
                # Undocumented compound child: describe it by its class
                class_name = (f'plotly.graph_objs'
                              f'{subtype_node.parent_dotpath_str}.'
                              f'{subtype_node.name_datatype_class}')

                subtype_description = (f'{class_name} instance or '
                                       'dict with compatible properties')
            else:
                subtype_description = ''

            subtype_description = '\n'.join(
                textwrap.wrap(subtype_description,
                              initial_indent=description_indent,
                              subsequent_indent=description_indent,
                              width=79 - (indent + 4)))

            buffer.write('\n' + ' ' * indent + subtype_node.name_property)
            buffer.write('\n' + subtype_description)

        return buffer.getvalue()

    # Static helpers
    # --------------
    @staticmethod
    def get_all_compound_datatype_nodes(plotly_schema, node_class):
        """
        Build a list of the entire hierarchy of compound datatype nodes for
        a given PlotlyNode subclass

        Parameters
        ----------
        plotly_schema : dict
            JSON-parsed version of the plot-schema.json file
        node_class
            PlotlyNode subclass

        Returns
        -------
        list of PlotlyNode
        """
        nodes = []
        nodes_to_process = [node_class(plotly_schema)]

        while nodes_to_process:
            node = nodes_to_process.pop()

            # Unnamed nodes (e.g. an unnamed root) and array nodes are
            # traversed but not reported
            if node.plotly_name and not node.is_array:
                nodes.append(node)

            nodes_to_process.extend(node.child_compound_datatypes)

        return nodes

    @staticmethod
    def get_all_datatype_nodes(plotly_schema, node_class):
        """
        Build a list of the entire hierarchy of datatype nodes for a given
        PlotlyNode subclass

        Parameters
        ----------
        plotly_schema : dict
            JSON-parsed version of the plot-schema.json file
        node_class
            PlotlyNode subclass

        Returns
        -------
        list of PlotlyNode
        """
        nodes = []
        nodes_to_process = [node_class(plotly_schema)]

        while nodes_to_process:
            node = nodes_to_process.pop()

            # Unnamed nodes (e.g. an unnamed root) and array nodes are
            # traversed but not reported
            if node.plotly_name and not node.is_array:
                nodes.append(node)

            nodes_to_process.extend(node.child_datatypes)

        return nodes


class TraceNode(PlotlyNode):
    """
    Class representing datatypes in the trace hierarchy
    """

    # Constructor
    # -----------
    def __init__(self, plotly_schema, node_path=(), parent=None):
        super().__init__(plotly_schema, node_path, parent)

    @property
    def name_base_datatype(self):
        # The root and the trace nodes themselves (path length <= 1) use the
        # trace base type; everything nested deeper uses the hierarchy type
        if len(self.node_path) <= 1:
            return 'BaseTraceType'
        else:
            return 'BaseTraceHierarchyType'

    @property
    def root_name(self):
        return ''

    # Raw data
    # --------
    @property
    def node_data(self) -> dict:
        if not self.node_path:
            node_data = self.plotly_schema['traces']
        else:
            # First path entry selects the trace; the rest index into the
            # trace's 'attributes' subtree
            trace_name = self.node_path[0]
            node_data = self.plotly_schema['traces'][trace_name]['attributes']
            for prop_name in self.node_path[1:]:
                node_data = node_data[prop_name]

        return node_data

    # Description
    # -----------
    @property
    def description(self) -> str:
        if len(self.node_path) == 0:
            desc = ""
        elif len(self.node_path) == 1:
            # Get trace descriptions
            trace_name = self.node_path[0]
            desc = (self.plotly_schema['traces'][trace_name]
                    ['meta'].get('description', ''))
        else:
            # Get datatype description
            desc = self.node_data.get('description', '')

        # Descriptions may be stored as a list of string fragments
        if isinstance(desc, list):
            desc = ''.join(desc)

        return desc


class LayoutNode(PlotlyNode):
    """
    Class representing datatypes in the layout hierarchy
    """

    # Constructor
    # -----------
    def __init__(self, plotly_schema, node_path=(), parent=None):
        # Get main layout properties
        layout = plotly_schema['layout']['layoutAttributes']

        # Get list of additional layout properties for each trace
        trace_layouts = [
            plotly_schema['traces'][trace].get('layoutAttributes', {})
            for trace in plotly_schema['traces']]

        # Chain together into layout_data
        # (must be set before the superclass constructor reads node_data)
        self.layout_data = ChainMap(layout, *trace_layouts)

        # Call superclass constructor
        super().__init__(plotly_schema, node_path, parent)

    @property
    def name_base_datatype(self):
        if len(self.node_path) == 0:
            return 'BaseLayoutType'
        else:
            return 'BaseLayoutHierarchyType'

    @property
    def root_name(self):
        return 'layout'

    @property
    def plotly_name(self) -> str:
        if len(self.node_path) == 0:
            return self.root_name
        else:
            return self.node_path[-1]

    # Description
    # -----------
    @property
    def description(self) -> str:
        desc = self.node_data.get('description', '')
        # Descriptions may be stored as a list of string fragments
        if isinstance(desc, list):
            desc = ''.join(desc)
        return desc

    # Raw data
    # --------
    @property
    def node_data(self) -> dict:
        node_data = self.layout_data
        for prop_name in self.node_path:
            node_data = node_data[prop_name]

        return node_data


class FrameNode(PlotlyNode):
    """
    Class representing datatypes in the frames hierarchy
    """

    # Constructor
    # -----------
    def __init__(self, plotly_schema, node_path=(), parent=None):
        super().__init__(plotly_schema, node_path, parent)

    @property
    def name_base_datatype(self):
        return 'BaseFrameHierarchyType'

    @property
    def root_name(self):
        return ''

    @property
    def plotly_name(self) -> str:
        if len(self.node_path) < 2:
            return self.root_name
        elif len(self.node_path) == 2:
            return 'frame'  # override 'frames_entry'
        else:
            return self.node_path[-1]

    def tidy_path_part(self, p):
        return 'frame' if p == 'frames_entry' else p

    # Description
    # -----------
    @property
    def description(self) -> str:
        desc = self.node_data.get('description', '')
        # Descriptions may be stored as a list of string fragments
        if isinstance(desc, list):
            desc = ''.join(desc)
        return desc

    # Raw data
    # --------
    @property
    def node_data(self) -> dict:
        node_data = self.plotly_schema['frames']
        for prop_name in self.node_path:
            node_data = node_data[prop_name]

        return node_data