diff --git a/codegen/figure.py b/codegen/figure.py index f148482a674..494f11980e9 100644 --- a/codegen/figure.py +++ b/codegen/figure.py @@ -68,7 +68,7 @@ class {fig_classname}({base_classname}):\n""") buffer.write(f""" def __init__(self, data=None, layout=None, - frames=None, skip_invalid=False): + frames=None, skip_invalid=False, **kwargs): \"\"\" Create a new {fig_classname} instance @@ -95,7 +95,8 @@ def __init__(self, data=None, layout=None, is invalid AND skip_invalid is False \"\"\" super({fig_classname} ,self).__init__(data, layout, - frames, skip_invalid) + frames, skip_invalid, + **kwargs) """) # ### add_trace methods for each trace type ### diff --git a/plotly/basedatatypes.py b/plotly/basedatatypes.py index 5f24d95e4a6..65a2c2857c3 100644 --- a/plotly/basedatatypes.py +++ b/plotly/basedatatypes.py @@ -9,7 +9,11 @@ from contextlib import contextmanager from copy import deepcopy, copy -from plotly.subplots import _set_trace_grid_reference, _get_grid_subplot +from plotly.subplots import ( + _set_trace_grid_reference, + _get_grid_subplot, + _get_subplot_ref_for_trace, + _validate_v4_subplots) from .optional_imports import get_module from _plotly_utils.basevalidators import ( @@ -33,13 +37,25 @@ class BaseFigure(object): """ _bracket_re = re.compile('^(.*)\[(\d+)\]$') + _valid_underscore_properties = { + 'error_x': 'error-x', + 'error_y': 'error-y', + 'error_z': 'error-z', + 'copy_xstyle': 'copy-xstyle', + 'copy_ystyle': 'copy-ystyle', + 'copy_zstyle': 'copy-zstyle', + 'paper_bgcolor': 'paper-bgcolor', + 'plot_bgcolor': 'plot-bgcolor' + } + # Constructor # ----------- def __init__(self, data=None, layout_plotly=None, frames=None, - skip_invalid=False): + skip_invalid=False, + **kwargs): """ Construct a BaseFigure object @@ -247,6 +263,14 @@ class is a subclass of both BaseFigure and widgets.DOMWidget. # ### Check for default template ### self._initialize_layout_template() + # Process kwargs + # -------------- + for k, v in kwargs.items(): + if k in self: + self[k] = v + elif not skip_invalid: + raise TypeError('invalid Figure property: {}'.format(k)) + # Magic Methods # ------------- def __reduce__(self): @@ -356,7 +380,13 @@ def __iter__(self): return iter(('data', 'layout', 'frames')) def __contains__(self, prop): - return prop in ('data', 'layout', 'frames') + prop = BaseFigure._str_to_dict_path(prop) + if prop[0] not in ('data', 'layout', 'frames'): + return False + elif len(prop) == 1: + return True + else: + return prop[1:] in self[prop[0]] def __eq__(self, other): if not isinstance(other, BaseFigure): @@ -447,7 +477,8 @@ def update(self, dict1=None, **kwargs): for d in [dict1, kwargs]: if d: for k, v in d.items(): - if self[k] == (): + update_target = self[k] + if update_target == (): # existing data or frames property is empty # In this case we accept the v as is. if k == 'data': @@ -455,9 +486,12 @@ def update(self, dict1=None, **kwargs): else: # Accept v self[k] = v - else: + elif (isinstance(update_target, BasePlotlyType) or + (isinstance(update_target, tuple) and + isinstance(update_target[0], BasePlotlyType))): BaseFigure._perform_update(self[k], v) - + else: + self[k] = v return self # Data @@ -604,6 +638,140 @@ def data(self, new_data): for trace_ind, trace in enumerate(self._data_objs): trace._trace_ind = trace_ind + def select_traces(self, selector=None, row=None, col=None): + """ + Select traces from a particular subplot cell and/or traces + that satisfy custom selection criteria. + + Parameters + ---------- + selector: dict or None (default None) + Dict to use as selection criteria. + Traces will be selected if they contain properties corresponding + to all of the dictionary's keys, with values that exactly match + the supplied values. If None (the default), all traces are + selected. + row, col: int or None (default None) + Subplot row and column index of traces to select. + To select traces by row and column, the Figure must have been + created using plotly.subplots.make_subplots. If None + (the default), all traces are selected. + + Returns + ------- + generator + Generator that iterates through all of the traces that satisfy + all of the specified selection criteria + """ + if not selector: + selector = {} + + if row is not None and col is not None: + _validate_v4_subplots('select_traces') + grid_ref = self._validate_get_grid_ref() + grid_subplot_ref = grid_ref[row-1][col-1] + filter_by_subplot = True + else: + filter_by_subplot = False + grid_subplot_ref = None + + return self._perform_select_traces( + filter_by_subplot, grid_subplot_ref, selector) + + def _perform_select_traces( + self, filter_by_subplot, grid_subplot_ref, selector): + + def select_eq(obj1, obj2): + try: + obj1 = obj1.to_plotly_json() + except Exception: + pass + try: + obj2 = obj2.to_plotly_json() + except Exception: + pass + + return BasePlotlyType._vals_equal(obj1, obj2) + + for trace in self.data: + # Filter by subplot + if filter_by_subplot: + trace_subplot_ref = _get_subplot_ref_for_trace(trace) + if grid_subplot_ref != trace_subplot_ref: + continue + + # Filter by selector + if not all( + k in trace and select_eq(trace[k], selector[k]) + for k in selector): + continue + + yield trace + + def for_each_trace(self, fn, selector=None, row=None, col=None): + """ + Apply a function to all traces that satisfy the specified selection + criteria + + Parameters + ---------- + fn: + Function that inputs a single trace object. + selector: dict or None (default None) + Dict to use as selection criteria. + Traces will be selected if they contain properties corresponding + to all of the dictionary's keys, with values that exactly match + the supplied values. If None (the default), all traces are + selected. + row, col: int or None (default None) + Subplot row and column index of traces to select. + To select traces by row and column, the Figure must have been + created using plotly.subplots.make_subplots. If None + (the default), all traces are selected. + + Returns + ------- + self + Returns the Figure object that the method was called on + """ + for trace in self.select_traces(selector=selector, row=row, col=col): + fn(trace) + + return self + + def update_traces(self, patch, selector=None, row=None, col=None): + """ + Perform a property update operation on all traces that satisfy the + specified selection criteria + + Parameters + ---------- + patch: dict + Dictionary of property updates to be applied to all traces that + satisfy the selection criteria. + fn: + Function that inputs a single trace object. + selector: dict or None (default None) + Dict to use as selection criteria. + Traces will be selected if they contain properties corresponding + to all of the dictionary's keys, with values that exactly match + the supplied values. If None (the default), all traces are + selected. + row, col: int or None (default None) + Subplot row and column index of traces to select. + To select traces by row and column, the Figure must have been + created using plotly.subplots.make_subplots. If None + (the default), all traces are selected. + + Returns + ------- + self + Returns the Figure object that the method was called on + """ + for trace in self.select_traces(selector=selector, row=row, col=col): + trace.update(patch) + return self + # Restyle # ------- def plotly_restyle(self, restyle_data, trace_indexes=None, **kwargs): @@ -822,18 +990,20 @@ def _str_to_dict_path(key_path_str): """ if isinstance(key_path_str, string_types) and \ '.' not in key_path_str and \ - '[' not in key_path_str: + '[' not in key_path_str and \ + '_' not in key_path_str: # Fast path for common case that avoids regular expressions return (key_path_str,) elif isinstance(key_path_str, tuple): # Nothing to do return key_path_str else: - # Split string on periods. e.g. 'foo.bar[1]' -> ['foo', 'bar[1]'] + # Split string on periods. + # e.g. 'foo.bar_baz[1]' -> ['foo', 'bar_baz[1]'] key_path = key_path_str.split('.') # Split out bracket indexes. - # e.g. ['foo', 'bar[1]'] -> ['foo', 'bar', '1'] + # e.g. ['foo', 'bar_baz[1]'] -> ['foo', 'bar_baz', '1'] key_path2 = [] for key in key_path: match = BaseFigure._bracket_re.match(key) @@ -842,15 +1012,39 @@ def _str_to_dict_path(key_path_str): else: key_path2.append(key) + # Split out underscore + # e.g. ['foo', 'bar_baz', '1'] -> ['foo', 'bar', 'baz', '1'] + key_path3 = [] + underscore_props = BaseFigure._valid_underscore_properties + for key in key_path2: + if '_' in key[1:]: + # For valid properties that contain underscores (error_x) + # replace the underscores with hyphens to protect them + # from being split up + for under_prop, hyphen_prop in underscore_props.items(): + key = key.replace(under_prop, hyphen_prop) + + # Split key on underscores + key = key.split('_') + + # Replace hyphens with underscores to restore properties + # that include underscores + for i in range(len(key)): + key[i] = key[i].replace('-', '_') + + key_path3.extend(key) + else: + key_path3.append(key) + # Convert elements to ints if possible. # e.g. ['foo', 'bar', '0'] -> ['foo', 'bar', 0] - for i in range(len(key_path2)): + for i in range(len(key_path3)): try: - key_path2[i] = int(key_path2[i]) + key_path3[i] = int(key_path3[i]) except ValueError as _: pass - return tuple(key_path2) + return tuple(key_path3) @staticmethod def _set_in(d, key_path_str, v): @@ -1235,13 +1429,8 @@ def append_trace(self, trace, row, col): self.add_trace(trace=trace, row=row, col=col) def _set_trace_grid_position(self, trace, row, col): - try: - grid_ref = self._grid_ref - except AttributeError: - raise Exception("In order to reference traces by row and column, " - "you must first use " - "plotly.tools.make_subplots " - "to create the figure with a subplot grid.") + grid_ref = self._validate_get_grid_ref() + from _plotly_future_ import _future_flags if 'v4_subplots' in _future_flags: return _set_trace_grid_reference( @@ -1277,6 +1466,18 @@ def _set_trace_grid_position(self, trace, row, col): trace['xaxis'] = ref[0] trace['yaxis'] = ref[1] + def _validate_get_grid_ref(self): + try: + grid_ref = self._grid_ref + if grid_ref is None: + raise AttributeError('_grid_ref') + except AttributeError: + raise Exception("In order to reference traces by row and column, " + "you must first use " + "plotly.tools.make_subplots " + "to create the figure with a subplot grid.") + return grid_ref + def get_subplot(self, row, col): """ Return an object representing the subplot at the specified row @@ -2429,8 +2630,16 @@ def _process_kwargs(self, **kwargs): """ Process any extra kwargs that are not predefined as constructor params """ - if not self._skip_invalid: - self._raise_on_invalid_property_error(*kwargs.keys()) + invalid_kwargs = {} + for k, v in kwargs.items(): + if k in self: + # e.g. underscore kwargs like marker_line_color + self[k] = v + else: + invalid_kwargs[k] = v + + if invalid_kwargs and not self._skip_invalid: + self._raise_on_invalid_property_error(*invalid_kwargs.keys()) @property def plotly_name(self): @@ -2675,7 +2884,9 @@ def _get_prop_validator(self, prop): plotly_obj = self[prop_path[:-1]] prop = prop_path[-1] else: - plotly_obj = self + prop_path = BaseFigure._str_to_dict_path(prop) + plotly_obj = self[prop_path[:-1]] + prop = prop_path[-1] # Return validator # ---------------- diff --git a/plotly/basewidget.py b/plotly/basewidget.py index a99bd7636c9..e0685870596 100644 --- a/plotly/basewidget.py +++ b/plotly/basewidget.py @@ -119,7 +119,8 @@ def __init__(self, data=None, layout=None, frames=None, - skip_invalid=False): + skip_invalid=False, + **kwargs): # Call superclass constructors # ---------------------------- @@ -129,7 +130,8 @@ def __init__(self, super(BaseFigureWidget, self).__init__(data=data, layout_plotly=layout, frames=frames, - skip_invalid=skip_invalid) + skip_invalid=skip_invalid, + **kwargs) # Validate Frames # --------------- diff --git a/plotly/graph_objs/_figure.py b/plotly/graph_objs/_figure.py index d7df69ece91..387485f26a3 100644 --- a/plotly/graph_objs/_figure.py +++ b/plotly/graph_objs/_figure.py @@ -12,7 +12,12 @@ class Figure(BaseFigure): def __init__( - self, data=None, layout=None, frames=None, skip_invalid=False + self, + data=None, + layout=None, + frames=None, + skip_invalid=False, + **kwargs ): """ Create a new Figure instance @@ -500,7 +505,8 @@ def __init__( if a property in the specification of data, layout, or frames is invalid AND skip_invalid is False """ - super(Figure, self).__init__(data, layout, frames, skip_invalid) + super(Figure, + self).__init__(data, layout, frames, skip_invalid, **kwargs) def add_area( self, diff --git a/plotly/graph_objs/_figurewidget.py b/plotly/graph_objs/_figurewidget.py index e33abafcf1f..800571ecca1 100644 --- a/plotly/graph_objs/_figurewidget.py +++ b/plotly/graph_objs/_figurewidget.py @@ -12,7 +12,12 @@ class FigureWidget(BaseFigureWidget): def __init__( - self, data=None, layout=None, frames=None, skip_invalid=False + self, + data=None, + layout=None, + frames=None, + skip_invalid=False, + **kwargs ): """ Create a new FigureWidget instance @@ -500,7 +505,8 @@ def __init__( if a property in the specification of data, layout, or frames is invalid AND skip_invalid is False """ - super(FigureWidget, self).__init__(data, layout, frames, skip_invalid) + super(FigureWidget, + self).__init__(data, layout, frames, skip_invalid, **kwargs) def add_area( self, diff --git a/plotly/subplots.py b/plotly/subplots.py index 31900b0311e..ba489422789 100644 --- a/plotly/subplots.py +++ b/plotly/subplots.py @@ -306,17 +306,7 @@ def make_subplots( ... cols=[1, 2, 1, 2]) """ - # Make sure we're in future subplots mode - from _plotly_future_ import _future_flags - if 'v4_subplots' not in _future_flags: - raise ValueError(""" -plotly.subplots.make_subplots may only be used in the -v4_subplots _plotly_future_ mode. To try it out, run - ->>> from _plotly_future_ import v4_subplots - -before importing plotly. -""") + _validate_v4_subplots('plotly.subplots.make_subplots') import plotly.graph_objs as go @@ -624,11 +614,10 @@ def _checks(item, defaults): subplot_type = spec['type'] grid_ref_element = _init_subplot( layout, subplot_type, x_domain, y_domain, max_subplot_ids) - grid_ref_element['spec'] = spec grid_ref[r][c] = grid_ref_element - _configure_shared_axes(layout, grid_ref, 'x', shared_xaxes, row_dir) - _configure_shared_axes(layout, grid_ref, 'y', shared_yaxes, row_dir) + _configure_shared_axes(layout, grid_ref, specs, 'x', shared_xaxes, row_dir) + _configure_shared_axes(layout, grid_ref, specs, 'y', shared_yaxes, row_dir) # Build inset reference # --------------------- @@ -769,7 +758,21 @@ def _checks(item, defaults): return fig -def _configure_shared_axes(layout, grid_ref, x_or_y, shared, row_dir): +def _validate_v4_subplots(method_name): + # Make sure we're in future subplots mode + from _plotly_future_ import _future_flags + if 'v4_subplots' not in _future_flags: + raise ValueError(""" +{method_name} may only be used in the +v4_subplots _plotly_future_ mode. To try it out, run + +>>> from _plotly_future_ import v4_subplots + +before importing plotly. +""".format(method_name=method_name)) + + +def _configure_shared_axes(layout, grid_ref, specs, x_or_y, shared, row_dir): rows = len(grid_ref) cols = len(grid_ref[0]) @@ -780,14 +783,14 @@ def _configure_shared_axes(layout, grid_ref, x_or_y, shared, row_dir): else: rows_iter = range(rows) - def update_axis_matches(first_axis_id, ref, remove_label): + def update_axis_matches(first_axis_id, ref, spec, remove_label): if ref is None: return first_axis_id if x_or_y == 'x': - span = ref['spec']['colspan'] + span = spec['colspan'] else: - span = ref['spec']['rowspan'] + span = spec['rowspan'] if ref['subplot_type'] == 'xy' and span == 1: if first_axis_id is None: @@ -808,8 +811,9 @@ def update_axis_matches(first_axis_id, ref, remove_label): ok_to_remove_label = x_or_y == 'x' for r in rows_iter: ref = grid_ref[r][c] + spec = specs[r][c] first_axis_id = update_axis_matches( - first_axis_id, ref, ok_to_remove_label) + first_axis_id, ref, spec, ok_to_remove_label) elif shared == 'rows' or (x_or_y == 'y' and shared is True): for r in rows_iter: @@ -817,14 +821,16 @@ def update_axis_matches(first_axis_id, ref, remove_label): ok_to_remove_label = x_or_y == 'y' for c in range(cols): ref = grid_ref[r][c] + spec = specs[r][c] first_axis_id = update_axis_matches( - first_axis_id, ref, ok_to_remove_label) + first_axis_id, ref, spec, ok_to_remove_label) elif shared == 'all': first_axis_id = None for c in range(cols): for ri, r in enumerate(rows_iter): ref = grid_ref[r][c] + spec = specs[r][c] if x_or_y == 'y': ok_to_remove_label = c > 0 @@ -832,7 +838,7 @@ def update_axis_matches(first_axis_id, ref, remove_label): ok_to_remove_label = ri > 0 if row_dir > 0 else r < rows - 1 first_axis_id = update_axis_matches( - first_axis_id, ref, ok_to_remove_label) + first_axis_id, ref, spec, ok_to_remove_label) def _init_subplot_xy( @@ -846,15 +852,15 @@ def _init_subplot_xy( y_cnt = max_subplot_ids['yaxis'] + 1 # Compute x/y labels (the values of trace.xaxis/trace.yaxis - x_label = "x{cnt}".format(cnt=x_cnt) - y_label = "y{cnt}".format(cnt=y_cnt) + x_label = "x{cnt}".format(cnt=x_cnt if x_cnt > 1 else '') + y_label = "y{cnt}".format(cnt=y_cnt if y_cnt > 1 else '') # Anchor x and y axes to each other x_anchor, y_anchor = y_label, x_label # Build layout.xaxis/layout.yaxis containers - xaxis_name = 'xaxis{cnt}'.format(cnt=x_cnt) - yaxis_name = 'yaxis{cnt}'.format(cnt=y_cnt) + xaxis_name = 'xaxis{cnt}'.format(cnt=x_cnt if x_cnt > 1 else '') + yaxis_name = 'yaxis{cnt}'.format(cnt=y_cnt if y_cnt > 1 else '') x_axis = {'domain': x_domain, 'anchor': x_anchor} y_axis = {'domain': y_domain, 'anchor': y_anchor} @@ -882,7 +888,9 @@ def _init_subplot_single( # Add scene to layout cnt = max_subplot_ids[subplot_type] + 1 - label = '{subplot_type}{cnt}'.format(subplot_type=subplot_type, cnt=cnt) + label = '{subplot_type}{cnt}'.format( + subplot_type=subplot_type, + cnt=cnt if cnt > 1 else '') scene = dict(domain={'x': x_domain, 'y': y_domain}) layout[label] = scene @@ -906,12 +914,13 @@ def _init_subplot_domain(x_domain, y_domain): ref_element = { 'subplot_type': 'domain', 'layout_keys': (), - 'trace_kwargs': {'domain': {'x': x_domain, 'y': y_domain}}} + 'trace_kwargs': { + 'domain': {'x': tuple(x_domain), 'y': tuple(y_domain)}}} return ref_element -def _subplot_type_for_trace(trace_type): +def _subplot_type_for_trace_type(trace_type): from plotly.validators import DataValidator trace_validator = DataValidator() if trace_type in trace_validator.class_strs_map: @@ -947,7 +956,7 @@ def _validate_coerce_subplot_type(subplot_type): return subplot_type # Try to determine subplot type for trace - subplot_type = _subplot_type_for_trace(subplot_type) + subplot_type = _subplot_type_for_trace_type(subplot_type) if subplot_type is None: raise ValueError('Unsupported subplot type: {}' @@ -1306,3 +1315,47 @@ def _get_grid_subplot(fig, row, col): else: raise ValueError(""" Unexpected subplot type with layout_keys of {}""".format(layout_keys)) + + +def _get_subplot_ref_for_trace(trace): + + if 'domain' in trace: + return { + 'subplot_type': 'domain', + 'layout_keys': (), + 'trace_kwargs': { + 'domain': {'x': trace.domain.x, + 'y': trace.domain.y}}} + + elif 'xaxis' in trace and 'yaxis' in trace: + xaxis_name = 'xaxis' + trace.xaxis[1:] if trace.xaxis else 'xaxis' + yaxis_name = 'yaxis' + trace.yaxis[1:] if trace.yaxis else 'yaxis' + + return { + 'subplot_type': 'xy', + 'layout_keys': (xaxis_name, yaxis_name), + 'trace_kwargs': {'xaxis': trace.xaxis, 'yaxis': trace.yaxis} + } + elif 'geo' in trace: + return { + 'subplot_type': 'geo', + 'layout_keys': (trace.geo,), + 'trace_kwargs': {'geo': trace.geo}} + elif 'scene' in trace: + return { + 'subplot_type': 'scene', + 'layout_keys': (trace.scene,), + 'trace_kwargs': {'scene': trace.scene}} + elif 'subplot' in trace: + for t in _subplot_prop_named_subplot: + try: + validator = trace._get_prop_validator('subplot') + validator.validate_coerce(t) + return { + 'subplot_type': t, + 'layout_keys': (trace.subplot,), + 'trace_kwargs': {'subplot': trace.subplot}} + except ValueError: + pass + + return None diff --git a/plotly/tests/test_core/test_graph_objs/test_constructor.py b/plotly/tests/test_core/test_graph_objs/test_constructor.py index 2bf7703f7ce..ddb30bc8a6b 100644 --- a/plotly/tests/test_core/test_graph_objs/test_constructor.py +++ b/plotly/tests/test_core/test_graph_objs/test_constructor.py @@ -15,6 +15,11 @@ def test_valid_arg_dict(self): self.assertEqual(m.to_plotly_json(), {'color': 'green'}) + def test_valid_underscore_kwarg(self): + m = go.scatter.Marker(line_color='green') + self.assertEqual(m.to_plotly_json(), + {'line': {'color': 'green'}}) + def test_valid_arg_obj(self): m = go.scatter.Marker( go.scatter.Marker(color='green')) diff --git a/plotly/tests/test_core/test_graph_objs/test_figure.py b/plotly/tests/test_core/test_graph_objs/test_figure.py index d07dd71ba02..8e39160ff4e 100644 --- a/plotly/tests/test_core/test_graph_objs/test_figure.py +++ b/plotly/tests/test_core/test_graph_objs/test_figure.py @@ -55,6 +55,7 @@ def test_skip_invalid_property_name(self): 'data': [{'type': 'bar', 'bogus': 123}], 'layout': {'bogus': 23, 'title': 'Figure title'}, }], + bogus=123, skip_invalid=True) fig_dict = fig.to_dict() @@ -93,7 +94,8 @@ def test_skip_invalid_property_value(self): 'data': [{'type': 'bar', 'showlegend': 'bad_value'}], 'layout': {'bgcolor': 'bad_color', 'title': 'Figure title'}, }], - skip_invalid=True) + skip_invalid=True, + ) fig_dict = fig.to_dict() @@ -110,4 +112,31 @@ def test_skip_invalid_property_value(self): 'data': [{'type': 'bar'}], 'layout': {'title': {'text': 'Figure title'}} - }]) \ No newline at end of file + }]) + + def test_raises_invalid_toplevel_kwarg(self): + with self.assertRaises(TypeError): + go.Figure( + data=[{'type': 'bar'}], + layout={'title': 'Figure title'}, + frames=[{ + 'data': [{'type': 'bar'}], + 'layout': {'title': 'Figure title'}, + }], + bogus=123 + ) + + def test_toplevel_underscore_kwarg(self): + fig = go.Figure( + data=[{'type': 'bar'}], + layout_title_text='Hello, Figure title!' + ) + + self.assertEqual(fig.layout.title.text, 'Hello, Figure title!') + + def test_add_trace_underscore_kwarg(self): + fig = go.Figure() + + fig.add_scatter(y=[2, 1, 3], marker_line_color='green') + + self.assertEqual(fig.data[0].marker.line.color, 'green') diff --git a/plotly/tests/test_core/test_graph_objs/test_figure_properties.py b/plotly/tests/test_core/test_graph_objs/test_figure_properties.py index 02aef463040..4df64690b5f 100644 --- a/plotly/tests/test_core/test_graph_objs/test_figure_properties.py +++ b/plotly/tests/test_core/test_graph_objs/test_figure_properties.py @@ -122,6 +122,42 @@ def test_update_data(self): self.figure.update({'data': {0: {'marker': {'color': 'yellow'}}}}) self.assertEqual(self.figure.data[0].marker.color, 'yellow') + def test_update_data_dots(self): + # Check initial marker color + self.assertEqual(self.figure.data[0].marker.color, 'green') + + # Update with dict kwarg + self.figure.update(data={0: {'marker.color': 'blue'}}) + self.assertEqual(self.figure.data[0].marker.color, 'blue') + + # Update with list kwarg + self.figure.update(data=[{'marker.color': 'red'}]) + self.assertEqual(self.figure.data[0].marker.color, 'red') + + # Update with dict + self.figure.update({'data[0].marker.color': 'yellow'}) + self.assertEqual(self.figure.data[0].marker.color, 'yellow') + + def test_update_data_underscores(self): + # Check initial marker color + self.assertEqual(self.figure.data[0].marker.color, 'green') + + # Update with dict kwarg + self.figure.update(data={0: {'marker_color': 'blue'}}) + self.assertEqual(self.figure.data[0].marker.color, 'blue') + + # Update with list kwarg + self.figure.update(data=[{'marker_color': 'red'}]) + self.assertEqual(self.figure.data[0].marker.color, 'red') + + # Update with dict + self.figure.update({'data_0_marker_color': 'yellow'}) + self.assertEqual(self.figure.data[0].marker.color, 'yellow') + + # Update with kwarg + self.figure.update(data_0_marker_color='yellow') + self.assertEqual(self.figure.data[0].marker.color, 'yellow') + def test_update_data_empty(self): # Create figure with empty data (no traces) figure = go.Figure(layout={'width': 1000}) @@ -240,4 +276,4 @@ def test_plotly_update_validate_property_trace(self): @raises(ValueError) def test_plotly_update_validate_property_layout(self): - self.figure.plotly_update(relayout_data={'xaxis.bogus': [1, 3]}) \ No newline at end of file + self.figure.plotly_update(relayout_data={'xaxis.bogus': [1, 3]}) diff --git a/plotly/tests/test_core/test_graph_objs/test_property_assignment.py b/plotly/tests/test_core/test_graph_objs/test_property_assignment.py index fc76355b338..84669af8ad4 100644 --- a/plotly/tests/test_core/test_graph_objs/test_property_assignment.py +++ b/plotly/tests/test_core/test_graph_objs/test_property_assignment.py @@ -30,6 +30,11 @@ def setUp(self): 'marker': {'colorbar': { 'title': {'font': {'family': 'courier'}}}}} + self.expected_nested_error_x = { + 'type': 'scatter', + 'name': 'scatter A', + 'error_x': {'type': 'percent'}} + def test_toplevel_attr(self): assert self.scatter.fillcolor is None self.scatter.fillcolor = 'green' @@ -86,6 +91,22 @@ def test_nested_update(self): d1, d2 = strip_dict_params(self.scatter, self.expected_nested) assert d1 == d2 + def test_nested_update_dots(self): + assert self.scatter['marker.colorbar.title.font.family'] is None + self.scatter.update({'marker.colorbar.title.font.family': 'courier'}) + + assert self.scatter['marker.colorbar.title.font.family'] == 'courier' + d1, d2 = strip_dict_params(self.scatter, self.expected_nested) + assert d1 == d2 + + def test_nested_update_underscores(self): + assert self.scatter['error_x.type'] is None + self.scatter.update({'error_x_type': 'percent'}) + + assert self.scatter['error_x_type'] == 'percent' + d1, d2 = strip_dict_params(self.scatter, self.expected_nested_error_x) + assert d1 == d2 + class TestAssignmentCompound(TestCase): @@ -453,4 +474,40 @@ def test_assign_double_nested_update_array(self): d1, d2 = strip_dict_params(self.layout, self.expected_layout2) assert d1 == d2 + def test_update_double_nested_dot(self): + self.assertEqual(self.layout.updatemenus, ()) + + # Initialize empty updatemenus + self.layout['updatemenus'] = [{}, {}] + # Initialize empty buttons in updatemenu[1] + self.layout['updatemenus.1.buttons'] = [{}, {}, {}] + + # Update + self.layout.update({'updatemenus[1].buttons[2].method': 'restyle'}) + + # Check + self.assertEqual( + self.layout['updatemenus[1].buttons[2].method'], + 'restyle') + d1, d2 = strip_dict_params(self.layout, self.expected_layout2) + assert d1 == d2 + + def test_update_double_nested_underscore(self): + self.assertEqual(self.layout.updatemenus, ()) + + # Initialize empty updatemenus + self.layout['updatemenus'] = [{}, {}] + + # Initialize empty buttons in updatemenu[1] + self.layout['updatemenus_1_buttons'] = [{}, {}, {}] + + # Update + self.layout.update({'updatemenus_1_buttons_2_method': 'restyle'}) + + # Check + self.assertEqual( + self.layout['updatemenus[1].buttons[2].method'], + 'restyle') + d1, d2 = strip_dict_params(self.layout, self.expected_layout2) + assert d1 == d2 diff --git a/plotly/tests/test_core/test_update_traces/__init__.py b/plotly/tests/test_core/test_update_traces/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/plotly/tests/test_core/test_update_traces/test_update_traces.py b/plotly/tests/test_core/test_update_traces/test_update_traces.py new file mode 100644 index 00000000000..875002bd425 --- /dev/null +++ b/plotly/tests/test_core/test_update_traces/test_update_traces.py @@ -0,0 +1,324 @@ +from __future__ import absolute_import +from unittest import TestCase +import inspect +import copy + +import plotly.graph_objs as go +from plotly.subplots import make_subplots +from _plotly_future_ import _future_flags + + +class TestSelectForEachUpdateTraces(TestCase): + + def setUp(self): + _future_flags.add('v4_subplots') + fig = make_subplots( + rows=3, + cols=2, + specs=[[{}, {'type': 'scene'}], + [{}, {'type': 'polar'}], + [{'type': 'domain', 'colspan': 2}, None]] + ).update(layout={'height': 800}) + + # data[0], (1, 1) + fig.add_scatter( + mode='markers', + y=[2, 3, 1], + name='A', + marker={'color': 'green', 'size': 10}, + row=1, col=1) + + # data[1], (1, 1) + fig.add_bar(y=[2, 3, 1], row=1, col=1, name='B') + + # data[2], (2, 1) + fig.add_scatter( + mode='lines', + y=[1, 2, 0], + line={'color': 'purple'}, + name='C', + row=2, + col=1, + ) + + # data[3], (2, 1) + fig.add_heatmap( + z=[[2, 3, 1], [2, 1, 3], [3, 2, 1]], + row=2, + col=1, + name='D', + ) + + # data[4], (1, 2) + fig.add_scatter3d( + x=[0, 0, 0], + y=[0, 0, 0], + z=[0, 1, 2], + mode='markers', + marker={'color': 'green', 'size': 10}, + name='E', + row=1, + col=2 + ) + + # data[5], (1, 2) + fig.add_scatter3d( + x=[0, 0, -1], + y=[-1, 0, 0], + z=[0, 1, 2], + mode='lines', + line={'color': 'purple', 'width': 4}, + name='F', + row=1, + col=2 + ) + + # data[6], (2, 2) + fig.add_scatterpolar( + mode='markers', + r=[0, 3, 2], + theta=[0, 20, 87], + marker={'color': 'green', 'size': 8}, + name='G', + row=2, + col=2 + ) + + # data[7], (2, 2) + fig.add_scatterpolar( + mode='lines', + r=[0, 3, 2], + theta=[20, 87, 111], + name='H', + row=2, + col=2 + ) + + # data[8], (3, 1) + fig.add_parcoords( + dimensions=[{'values': [1, 2, 3, 2, 1]}, + {'values': [3, 2, 1, 3, 2, 1]}], + line={'color': 'purple'}, + name='I', + row=3, + col=1 + ) + + self.fig = fig + self.fig_no_grid = go.Figure(self.fig.to_dict()) + + def tearDown(self): + _future_flags.remove('v4_subplots') + + # select_traces and for_each_trace + # -------------------------------- + def assert_select_traces(self, expected_inds, selector=None, row=None, col=None, test_no_grid=False): + # Select traces on figure initialized with make_subplots + trace_generator = self.fig.select_traces( + selector=selector, row=row, col=col) + self.assertTrue(inspect.isgenerator(trace_generator)) + + trace_list = list(trace_generator) + self.assertEqual(trace_list, [self.fig.data[i] for i in expected_inds]) + + # Select traces on figure not containing subplot info + if test_no_grid: + trace_generator = self.fig_no_grid.select_traces( + selector=selector, row=row, col=col) + trace_list = list(trace_generator) + self.assertEqual(trace_list, [self.fig_no_grid.data[i] for i in expected_inds]) + + # Test for each trace + trace_list = [] + for_each_res = self.fig.for_each_trace( + lambda t: trace_list.append(t), + selector=selector, + row=row, + col=col, + ) + self.assertIs(for_each_res, self.fig) + + self.assertEqual( + trace_list, [self.fig.data[i] for i in expected_inds]) + + def test_select_by_type(self): + self.assert_select_traces( + [0, 2], selector={'type': 'scatter'}, test_no_grid=True) + self.assert_select_traces( + [1], selector={'type': 'bar'}, test_no_grid=True) + self.assert_select_traces( + [3], selector={'type': 'heatmap'}, test_no_grid=True) + self.assert_select_traces( + [4, 5], selector={'type': 'scatter3d'}, test_no_grid=True) + self.assert_select_traces( + [6, 7], selector={'type': 'scatterpolar'}, test_no_grid=True) + self.assert_select_traces( + [8], selector={'type': 'parcoords'}, test_no_grid=True) + self.assert_select_traces( + [], selector={'type': 'pie'}, test_no_grid=True) + + def test_select_by_grid(self): + self.assert_select_traces([0, 1], row=1, col=1) + self.assert_select_traces([2, 3], row=2, col=1) + self.assert_select_traces([4, 5], row=1, col=2) + self.assert_select_traces([6, 7], row=2, col=2) + self.assert_select_traces([8], row=3, col=1) + + def test_select_by_property_across_trace_types(self): + self.assert_select_traces( + [0, 4, 6], selector={'mode': 'markers'}, test_no_grid=True) + self.assert_select_traces( + [2, 5, 7], selector={'mode': 'lines'}, test_no_grid=True) + self.assert_select_traces( + [0, 4], + selector={'marker': {'color': 'green', 'size': 10}}, + test_no_grid=True) + + # Several traces have 'marker.color' == 'green', but they all have + # additional marker properties so there should be no exact match. + self.assert_select_traces( + [], selector={'marker': {'color': 'green'}}, test_no_grid=True) + self.assert_select_traces( + [0, 4, 6], selector={'marker.color': 'green'}, test_no_grid=True) + self.assert_select_traces( + [2, 5, 8], selector={'line.color': 'purple'}, test_no_grid=True) + + def test_select_property_and_grid(self): + # (1, 1) + self.assert_select_traces( + [0], selector={'mode': 'markers'}, row=1, col=1) + self.assert_select_traces( + [1], selector={'type': 'bar'}, row=1, col=1) + + # (2, 1) + self.assert_select_traces( + [2], selector={'mode': 'lines'}, row=2, col=1) + + # (1, 2) + self.assert_select_traces( + [4], selector={'marker.color': 'green'}, row=1, col=2) + + # Valid row/col and valid selector but the intersection is empty + self.assert_select_traces( + [], selector={'type': 'markers'}, row=3, col=1) + + def test_for_each_trace_lowercase_names(self): + # Names are all uppercase to start + original_names = [t.name for t in self.fig.data] + self.assertTrue([str.isupper(n) for n in original_names]) + + # Lower case names + result_fig = self.fig.for_each_trace( + lambda t: t.update(name=t.name.lower()) + ) + + # Check chaning + self.assertIs(result_fig, self.fig) + + # Check that names were altered + self.assertTrue( + all([t.name == n.lower() + for t, n in zip(result_fig.data, original_names)])) + + # test update_traces + # ------------------ + def assert_update_traces( + self, patch, expected_inds, selector=None, row=None, col=None + ): + # Save off original figure + fig_orig = copy.deepcopy(self.fig) + for trace1, trace2 in zip(fig_orig.data, self.fig.data): + trace1.uid = trace2.uid + + # Perform update + update_res = self.fig.update_traces( + patch, selector=selector, row=row, col=col + ) + + # Check chaining support + self.assertIs(update_res, self.fig) + + # Check resulting traces + for i, (t_orig, t) in enumerate(zip(fig_orig.data, self.fig.data)): + if i in expected_inds: + # Check that traces are initially equal + self.assertNotEqual(t_orig, t) + + # Check that traces are equal after update + t_orig.update(patch) + + # Check that traces are equal + self.assertEqual(t_orig, t) + + def test_update_traces_by_type(self): + self.assert_update_traces( + {'visible': 'legendonly'}, + [0, 2], + selector={'type': 'scatter'} + ) + + self.assert_update_traces( + {'visible': 'legendonly'}, + [1], + selector={'type': 'bar'}, + ) + + self.assert_update_traces( + {'colorscale': 'Viridis'}, + [3], + selector={'type': 'heatmap'} + ) + + # Nest dictionaries + self.assert_update_traces( + {'marker': {'line': {'color': 'yellow'}}}, + [4, 5], + selector={'type': 'scatter3d'} + ) + + # dot syntax + self.assert_update_traces( + {'marker.line.color': 'cyan'}, + [4, 5], + selector={'type': 'scatter3d'} + ) + + # underscore syntax + self.assert_update_traces( + dict(marker_line_color='pink'), + [4, 5], + selector={'type': 'scatter3d'} + ) + + self.assert_update_traces( + {'line': {'dash': 'dot'}}, + [6, 7], + selector={'type': 'scatterpolar'} + ) + + # Nested dictionaries + self.assert_update_traces( + {'dimensions': {1: {'label': 'Dimension 1'}}}, + [8], + selector={'type': 'parcoords'} + ) + + # Dot syntax + self.assert_update_traces( + {'dimensions[1].label': 'Dimension A'}, + [8], + selector={'type': 'parcoords'} + ) + + # underscore syntax + # Dot syntax + self.assert_update_traces( + dict(dimensions_1_label='Dimension X'), + [8], + selector={'type': 'parcoords'} + ) + + self.assert_update_traces( + {'hoverinfo': 'label+percent'}, + [], selector={'type': 'pie'} + )