Skip to content

traces selection/update methods and "magic underscore" support #1534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions codegen/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class {fig_classname}({base_classname}):\n""")

buffer.write(f"""
def __init__(self, data=None, layout=None,
frames=None, skip_invalid=False):
frames=None, skip_invalid=False, **kwargs):
\"\"\"
Create a new {fig_classname} instance

Expand All @@ -95,7 +95,8 @@ def __init__(self, data=None, layout=None,
is invalid AND skip_invalid is False
\"\"\"
super({fig_classname} ,self).__init__(data, layout,
frames, skip_invalid)
frames, skip_invalid,
**kwargs)
""")

# ### add_trace methods for each trace type ###
Expand Down
255 changes: 233 additions & 22 deletions plotly/basedatatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from contextlib import contextmanager
from copy import deepcopy, copy

from plotly.subplots import _set_trace_grid_reference, _get_grid_subplot
from plotly.subplots import (
_set_trace_grid_reference,
_get_grid_subplot,
_get_subplot_ref_for_trace,
_validate_v4_subplots)
from .optional_imports import get_module

from _plotly_utils.basevalidators import (
Expand All @@ -33,13 +37,25 @@ class BaseFigure(object):
"""
_bracket_re = re.compile('^(.*)\[(\d+)\]$')

_valid_underscore_properties = {
'error_x': 'error-x',
'error_y': 'error-y',
'error_z': 'error-z',
'copy_xstyle': 'copy-xstyle',
'copy_ystyle': 'copy-ystyle',
'copy_zstyle': 'copy-zstyle',
'paper_bgcolor': 'paper-bgcolor',
'plot_bgcolor': 'plot-bgcolor'
}

# Constructor
# -----------
def __init__(self,
data=None,
layout_plotly=None,
frames=None,
skip_invalid=False):
skip_invalid=False,
**kwargs):
"""
Construct a BaseFigure object

Expand Down Expand Up @@ -247,6 +263,14 @@ class is a subclass of both BaseFigure and widgets.DOMWidget.
# ### Check for default template ###
self._initialize_layout_template()

# Process kwargs
# --------------
for k, v in kwargs.items():
if k in self:
self[k] = v
elif not skip_invalid:
raise TypeError('invalid Figure property: {}'.format(k))

# Magic Methods
# -------------
def __reduce__(self):
Expand Down Expand Up @@ -356,7 +380,13 @@ def __iter__(self):
return iter(('data', 'layout', 'frames'))

def __contains__(self, prop):
return prop in ('data', 'layout', 'frames')
prop = BaseFigure._str_to_dict_path(prop)
if prop[0] not in ('data', 'layout', 'frames'):
return False
elif len(prop) == 1:
return True
else:
return prop[1:] in self[prop[0]]

def __eq__(self, other):
if not isinstance(other, BaseFigure):
Expand Down Expand Up @@ -447,17 +477,21 @@ def update(self, dict1=None, **kwargs):
for d in [dict1, kwargs]:
if d:
for k, v in d.items():
if self[k] == ():
update_target = self[k]
if update_target == ():
# existing data or frames property is empty
# In this case we accept the v as is.
if k == 'data':
self.add_traces(v)
else:
# Accept v
self[k] = v
else:
elif (isinstance(update_target, BasePlotlyType) or
(isinstance(update_target, tuple) and
isinstance(update_target[0], BasePlotlyType))):
BaseFigure._perform_update(self[k], v)

else:
self[k] = v
return self

# Data
Expand Down Expand Up @@ -604,6 +638,140 @@ def data(self, new_data):
for trace_ind, trace in enumerate(self._data_objs):
trace._trace_ind = trace_ind

def select_traces(self, selector=None, row=None, col=None):
"""
Select traces from a particular subplot cell and/or traces
that satisfy custom selection criteria.

Parameters
----------
selector: dict or None (default None)
Dict to use as selection criteria.
Traces will be selected if they contain properties corresponding
to all of the dictionary's keys, with values that exactly match
the supplied values. If None (the default), all traces are
selected.
row, col: int or None (default None)
Subplot row and column index of traces to select.
To select traces by row and column, the Figure must have been
created using plotly.subplots.make_subplots. If None
(the default), all traces are selected.

Returns
-------
generator
Generator that iterates through all of the traces that satisfy
all of the specified selection criteria
"""
if not selector:
selector = {}

if row is not None and col is not None:
_validate_v4_subplots('select_traces')
grid_ref = self._validate_get_grid_ref()
grid_subplot_ref = grid_ref[row-1][col-1]
filter_by_subplot = True
else:
filter_by_subplot = False
grid_subplot_ref = None

return self._perform_select_traces(
filter_by_subplot, grid_subplot_ref, selector)

def _perform_select_traces(
self, filter_by_subplot, grid_subplot_ref, selector):

def select_eq(obj1, obj2):
try:
obj1 = obj1.to_plotly_json()
except Exception:
pass
try:
obj2 = obj2.to_plotly_json()
except Exception:
pass

return BasePlotlyType._vals_equal(obj1, obj2)

for trace in self.data:
# Filter by subplot
if filter_by_subplot:
trace_subplot_ref = _get_subplot_ref_for_trace(trace)
if grid_subplot_ref != trace_subplot_ref:
continue

# Filter by selector
if not all(
k in trace and select_eq(trace[k], selector[k])
for k in selector):
continue

yield trace

def for_each_trace(self, fn, selector=None, row=None, col=None):
"""
Apply a function to all traces that satisfy the specified selection
criteria

Parameters
----------
fn:
Function that inputs a single trace object.
selector: dict or None (default None)
Dict to use as selection criteria.
Traces will be selected if they contain properties corresponding
to all of the dictionary's keys, with values that exactly match
the supplied values. If None (the default), all traces are
selected.
row, col: int or None (default None)
Subplot row and column index of traces to select.
To select traces by row and column, the Figure must have been
created using plotly.subplots.make_subplots. If None
(the default), all traces are selected.

Returns
-------
self
Returns the Figure object that the method was called on
"""
for trace in self.select_traces(selector=selector, row=row, col=col):
fn(trace)

return self

def update_traces(self, patch, selector=None, row=None, col=None):
"""
Perform a property update operation on all traces that satisfy the
specified selection criteria

Parameters
----------
patch: dict
Dictionary of property updates to be applied to all traces that
satisfy the selection criteria.
fn:
Function that inputs a single trace object.
selector: dict or None (default None)
Dict to use as selection criteria.
Traces will be selected if they contain properties corresponding
to all of the dictionary's keys, with values that exactly match
the supplied values. If None (the default), all traces are
selected.
row, col: int or None (default None)
Subplot row and column index of traces to select.
To select traces by row and column, the Figure must have been
created using plotly.subplots.make_subplots. If None
(the default), all traces are selected.

Returns
-------
self
Returns the Figure object that the method was called on
"""
for trace in self.select_traces(selector=selector, row=row, col=col):
trace.update(patch)
return self

# Restyle
# -------
def plotly_restyle(self, restyle_data, trace_indexes=None, **kwargs):
Expand Down Expand Up @@ -822,18 +990,20 @@ def _str_to_dict_path(key_path_str):
"""
if isinstance(key_path_str, string_types) and \
'.' not in key_path_str and \
'[' not in key_path_str:
'[' not in key_path_str and \
'_' not in key_path_str:
# Fast path for common case that avoids regular expressions
return (key_path_str,)
elif isinstance(key_path_str, tuple):
# Nothing to do
return key_path_str
else:
# Split string on periods. e.g. 'foo.bar[1]' -> ['foo', 'bar[1]']
# Split string on periods.
# e.g. 'foo.bar_baz[1]' -> ['foo', 'bar_baz[1]']
key_path = key_path_str.split('.')

# Split out bracket indexes.
# e.g. ['foo', 'bar[1]'] -> ['foo', 'bar', '1']
# e.g. ['foo', 'bar_baz[1]'] -> ['foo', 'bar_baz', '1']
key_path2 = []
for key in key_path:
match = BaseFigure._bracket_re.match(key)
Expand All @@ -842,15 +1012,39 @@ def _str_to_dict_path(key_path_str):
else:
key_path2.append(key)

# Split out underscore
# e.g. ['foo', 'bar_baz', '1'] -> ['foo', 'bar', 'baz', '1']
key_path3 = []
underscore_props = BaseFigure._valid_underscore_properties
for key in key_path2:
if '_' in key[1:]:
# For valid properties that contain underscores (error_x)
# replace the underscores with hyphens to protect them
# from being split up
for under_prop, hyphen_prop in underscore_props.items():
key = key.replace(under_prop, hyphen_prop)

# Split key on underscores
key = key.split('_')

# Replace hyphens with underscores to restore properties
# that include underscores
for i in range(len(key)):
key[i] = key[i].replace('-', '_')

key_path3.extend(key)
else:
key_path3.append(key)

# Convert elements to ints if possible.
# e.g. ['foo', 'bar', '0'] -> ['foo', 'bar', 0]
for i in range(len(key_path2)):
for i in range(len(key_path3)):
try:
key_path2[i] = int(key_path2[i])
key_path3[i] = int(key_path3[i])
except ValueError as _:
pass

return tuple(key_path2)
return tuple(key_path3)

@staticmethod
def _set_in(d, key_path_str, v):
Expand Down Expand Up @@ -1235,13 +1429,8 @@ def append_trace(self, trace, row, col):
self.add_trace(trace=trace, row=row, col=col)

def _set_trace_grid_position(self, trace, row, col):
try:
grid_ref = self._grid_ref
except AttributeError:
raise Exception("In order to reference traces by row and column, "
"you must first use "
"plotly.tools.make_subplots "
"to create the figure with a subplot grid.")
grid_ref = self._validate_get_grid_ref()

from _plotly_future_ import _future_flags
if 'v4_subplots' in _future_flags:
return _set_trace_grid_reference(
Expand Down Expand Up @@ -1277,6 +1466,18 @@ def _set_trace_grid_position(self, trace, row, col):
trace['xaxis'] = ref[0]
trace['yaxis'] = ref[1]

def _validate_get_grid_ref(self):
try:
grid_ref = self._grid_ref
if grid_ref is None:
raise AttributeError('_grid_ref')
except AttributeError:
raise Exception("In order to reference traces by row and column, "
"you must first use "
"plotly.tools.make_subplots "
"to create the figure with a subplot grid.")
return grid_ref

def get_subplot(self, row, col):
"""
Return an object representing the subplot at the specified row
Expand Down Expand Up @@ -2429,8 +2630,16 @@ def _process_kwargs(self, **kwargs):
"""
Process any extra kwargs that are not predefined as constructor params
"""
if not self._skip_invalid:
self._raise_on_invalid_property_error(*kwargs.keys())
invalid_kwargs = {}
for k, v in kwargs.items():
if k in self:
# e.g. underscore kwargs like marker_line_color
self[k] = v
else:
invalid_kwargs[k] = v

if invalid_kwargs and not self._skip_invalid:
self._raise_on_invalid_property_error(*invalid_kwargs.keys())

@property
def plotly_name(self):
Expand Down Expand Up @@ -2675,7 +2884,9 @@ def _get_prop_validator(self, prop):
plotly_obj = self[prop_path[:-1]]
prop = prop_path[-1]
else:
plotly_obj = self
prop_path = BaseFigure._str_to_dict_path(prop)
plotly_obj = self[prop_path[:-1]]
prop = prop_path[-1]

# Return validator
# ----------------
Expand Down
Loading