Skip to content

Commit 6ced2d9

Browse files
authored
traces selection/update methods and "magic underscore" support (#1534)
* Added select_traces, for_each_trace, and update_traces methods * Add tests * Add support for specifying .update paths with dot strings. E.g. self.layout.update({'updatemenus[1].buttons[2].method': 'restyle'}) * Added "magic underscore" support for property assignment and .update() * Added "magic underscore" support for graph_objs constructors * Added "magic underscore" support for Figure constructors * Test "magic underscore" support for Figure add_* trace builders
1 parent 8162927 commit 6ced2d9

File tree

12 files changed

+792
-62
lines changed

12 files changed

+792
-62
lines changed

Diff for: codegen/figure.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class {fig_classname}({base_classname}):\n""")
6868

6969
buffer.write(f"""
7070
def __init__(self, data=None, layout=None,
71-
frames=None, skip_invalid=False):
71+
frames=None, skip_invalid=False, **kwargs):
7272
\"\"\"
7373
Create a new {fig_classname} instance
7474
@@ -95,7 +95,8 @@ def __init__(self, data=None, layout=None,
9595
is invalid AND skip_invalid is False
9696
\"\"\"
9797
super({fig_classname} ,self).__init__(data, layout,
98-
frames, skip_invalid)
98+
frames, skip_invalid,
99+
**kwargs)
99100
""")
100101

101102
# ### add_trace methods for each trace type ###

Diff for: plotly/basedatatypes.py

+233-22
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
from contextlib import contextmanager
1010
from copy import deepcopy, copy
1111

12-
from plotly.subplots import _set_trace_grid_reference, _get_grid_subplot
12+
from plotly.subplots import (
13+
_set_trace_grid_reference,
14+
_get_grid_subplot,
15+
_get_subplot_ref_for_trace,
16+
_validate_v4_subplots)
1317
from .optional_imports import get_module
1418

1519
from _plotly_utils.basevalidators import (
@@ -33,13 +37,25 @@ class BaseFigure(object):
3337
"""
3438
_bracket_re = re.compile('^(.*)\[(\d+)\]$')
3539

40+
_valid_underscore_properties = {
41+
'error_x': 'error-x',
42+
'error_y': 'error-y',
43+
'error_z': 'error-z',
44+
'copy_xstyle': 'copy-xstyle',
45+
'copy_ystyle': 'copy-ystyle',
46+
'copy_zstyle': 'copy-zstyle',
47+
'paper_bgcolor': 'paper-bgcolor',
48+
'plot_bgcolor': 'plot-bgcolor'
49+
}
50+
3651
# Constructor
3752
# -----------
3853
def __init__(self,
3954
data=None,
4055
layout_plotly=None,
4156
frames=None,
42-
skip_invalid=False):
57+
skip_invalid=False,
58+
**kwargs):
4359
"""
4460
Construct a BaseFigure object
4561
@@ -247,6 +263,14 @@ class is a subclass of both BaseFigure and widgets.DOMWidget.
247263
# ### Check for default template ###
248264
self._initialize_layout_template()
249265

266+
# Process kwargs
267+
# --------------
268+
for k, v in kwargs.items():
269+
if k in self:
270+
self[k] = v
271+
elif not skip_invalid:
272+
raise TypeError('invalid Figure property: {}'.format(k))
273+
250274
# Magic Methods
251275
# -------------
252276
def __reduce__(self):
@@ -356,7 +380,13 @@ def __iter__(self):
356380
return iter(('data', 'layout', 'frames'))
357381

358382
def __contains__(self, prop):
359-
return prop in ('data', 'layout', 'frames')
383+
prop = BaseFigure._str_to_dict_path(prop)
384+
if prop[0] not in ('data', 'layout', 'frames'):
385+
return False
386+
elif len(prop) == 1:
387+
return True
388+
else:
389+
return prop[1:] in self[prop[0]]
360390

361391
def __eq__(self, other):
362392
if not isinstance(other, BaseFigure):
@@ -447,17 +477,21 @@ def update(self, dict1=None, **kwargs):
447477
for d in [dict1, kwargs]:
448478
if d:
449479
for k, v in d.items():
450-
if self[k] == ():
480+
update_target = self[k]
481+
if update_target == ():
451482
# existing data or frames property is empty
452483
# In this case we accept the v as is.
453484
if k == 'data':
454485
self.add_traces(v)
455486
else:
456487
# Accept v
457488
self[k] = v
458-
else:
489+
elif (isinstance(update_target, BasePlotlyType) or
490+
(isinstance(update_target, tuple) and
491+
isinstance(update_target[0], BasePlotlyType))):
459492
BaseFigure._perform_update(self[k], v)
460-
493+
else:
494+
self[k] = v
461495
return self
462496

463497
# Data
@@ -604,6 +638,140 @@ def data(self, new_data):
604638
for trace_ind, trace in enumerate(self._data_objs):
605639
trace._trace_ind = trace_ind
606640

641+
def select_traces(self, selector=None, row=None, col=None):
642+
"""
643+
Select traces from a particular subplot cell and/or traces
644+
that satisfy custom selection criteria.
645+
646+
Parameters
647+
----------
648+
selector: dict or None (default None)
649+
Dict to use as selection criteria.
650+
Traces will be selected if they contain properties corresponding
651+
to all of the dictionary's keys, with values that exactly match
652+
the supplied values. If None (the default), all traces are
653+
selected.
654+
row, col: int or None (default None)
655+
Subplot row and column index of traces to select.
656+
To select traces by row and column, the Figure must have been
657+
created using plotly.subplots.make_subplots. If None
658+
(the default), all traces are selected.
659+
660+
Returns
661+
-------
662+
generator
663+
Generator that iterates through all of the traces that satisfy
664+
all of the specified selection criteria
665+
"""
666+
if not selector:
667+
selector = {}
668+
669+
if row is not None and col is not None:
670+
_validate_v4_subplots('select_traces')
671+
grid_ref = self._validate_get_grid_ref()
672+
grid_subplot_ref = grid_ref[row-1][col-1]
673+
filter_by_subplot = True
674+
else:
675+
filter_by_subplot = False
676+
grid_subplot_ref = None
677+
678+
return self._perform_select_traces(
679+
filter_by_subplot, grid_subplot_ref, selector)
680+
681+
def _perform_select_traces(
682+
self, filter_by_subplot, grid_subplot_ref, selector):
683+
684+
def select_eq(obj1, obj2):
685+
try:
686+
obj1 = obj1.to_plotly_json()
687+
except Exception:
688+
pass
689+
try:
690+
obj2 = obj2.to_plotly_json()
691+
except Exception:
692+
pass
693+
694+
return BasePlotlyType._vals_equal(obj1, obj2)
695+
696+
for trace in self.data:
697+
# Filter by subplot
698+
if filter_by_subplot:
699+
trace_subplot_ref = _get_subplot_ref_for_trace(trace)
700+
if grid_subplot_ref != trace_subplot_ref:
701+
continue
702+
703+
# Filter by selector
704+
if not all(
705+
k in trace and select_eq(trace[k], selector[k])
706+
for k in selector):
707+
continue
708+
709+
yield trace
710+
711+
def for_each_trace(self, fn, selector=None, row=None, col=None):
712+
"""
713+
Apply a function to all traces that satisfy the specified selection
714+
criteria
715+
716+
Parameters
717+
----------
718+
fn:
719+
Function that inputs a single trace object.
720+
selector: dict or None (default None)
721+
Dict to use as selection criteria.
722+
Traces will be selected if they contain properties corresponding
723+
to all of the dictionary's keys, with values that exactly match
724+
the supplied values. If None (the default), all traces are
725+
selected.
726+
row, col: int or None (default None)
727+
Subplot row and column index of traces to select.
728+
To select traces by row and column, the Figure must have been
729+
created using plotly.subplots.make_subplots. If None
730+
(the default), all traces are selected.
731+
732+
Returns
733+
-------
734+
self
735+
Returns the Figure object that the method was called on
736+
"""
737+
for trace in self.select_traces(selector=selector, row=row, col=col):
738+
fn(trace)
739+
740+
return self
741+
742+
def update_traces(self, patch, selector=None, row=None, col=None):
743+
"""
744+
Perform a property update operation on all traces that satisfy the
745+
specified selection criteria
746+
747+
Parameters
748+
----------
749+
patch: dict
750+
Dictionary of property updates to be applied to all traces that
751+
satisfy the selection criteria.
752+
fn:
753+
Function that inputs a single trace object.
754+
selector: dict or None (default None)
755+
Dict to use as selection criteria.
756+
Traces will be selected if they contain properties corresponding
757+
to all of the dictionary's keys, with values that exactly match
758+
the supplied values. If None (the default), all traces are
759+
selected.
760+
row, col: int or None (default None)
761+
Subplot row and column index of traces to select.
762+
To select traces by row and column, the Figure must have been
763+
created using plotly.subplots.make_subplots. If None
764+
(the default), all traces are selected.
765+
766+
Returns
767+
-------
768+
self
769+
Returns the Figure object that the method was called on
770+
"""
771+
for trace in self.select_traces(selector=selector, row=row, col=col):
772+
trace.update(patch)
773+
return self
774+
607775
# Restyle
608776
# -------
609777
def plotly_restyle(self, restyle_data, trace_indexes=None, **kwargs):
@@ -822,18 +990,20 @@ def _str_to_dict_path(key_path_str):
822990
"""
823991
if isinstance(key_path_str, string_types) and \
824992
'.' not in key_path_str and \
825-
'[' not in key_path_str:
993+
'[' not in key_path_str and \
994+
'_' not in key_path_str:
826995
# Fast path for common case that avoids regular expressions
827996
return (key_path_str,)
828997
elif isinstance(key_path_str, tuple):
829998
# Nothing to do
830999
return key_path_str
8311000
else:
832-
# Split string on periods. e.g. 'foo.bar[1]' -> ['foo', 'bar[1]']
1001+
# Split string on periods.
1002+
# e.g. 'foo.bar_baz[1]' -> ['foo', 'bar_baz[1]']
8331003
key_path = key_path_str.split('.')
8341004

8351005
# Split out bracket indexes.
836-
# e.g. ['foo', 'bar[1]'] -> ['foo', 'bar', '1']
1006+
# e.g. ['foo', 'bar_baz[1]'] -> ['foo', 'bar_baz', '1']
8371007
key_path2 = []
8381008
for key in key_path:
8391009
match = BaseFigure._bracket_re.match(key)
@@ -842,15 +1012,39 @@ def _str_to_dict_path(key_path_str):
8421012
else:
8431013
key_path2.append(key)
8441014

1015+
# Split out underscore
1016+
# e.g. ['foo', 'bar_baz', '1'] -> ['foo', 'bar', 'baz', '1']
1017+
key_path3 = []
1018+
underscore_props = BaseFigure._valid_underscore_properties
1019+
for key in key_path2:
1020+
if '_' in key[1:]:
1021+
# For valid properties that contain underscores (error_x)
1022+
# replace the underscores with hyphens to protect them
1023+
# from being split up
1024+
for under_prop, hyphen_prop in underscore_props.items():
1025+
key = key.replace(under_prop, hyphen_prop)
1026+
1027+
# Split key on underscores
1028+
key = key.split('_')
1029+
1030+
# Replace hyphens with underscores to restore properties
1031+
# that include underscores
1032+
for i in range(len(key)):
1033+
key[i] = key[i].replace('-', '_')
1034+
1035+
key_path3.extend(key)
1036+
else:
1037+
key_path3.append(key)
1038+
8451039
# Convert elements to ints if possible.
8461040
# e.g. ['foo', 'bar', '0'] -> ['foo', 'bar', 0]
847-
for i in range(len(key_path2)):
1041+
for i in range(len(key_path3)):
8481042
try:
849-
key_path2[i] = int(key_path2[i])
1043+
key_path3[i] = int(key_path3[i])
8501044
except ValueError as _:
8511045
pass
8521046

853-
return tuple(key_path2)
1047+
return tuple(key_path3)
8541048

8551049
@staticmethod
8561050
def _set_in(d, key_path_str, v):
@@ -1235,13 +1429,8 @@ def append_trace(self, trace, row, col):
12351429
self.add_trace(trace=trace, row=row, col=col)
12361430

12371431
def _set_trace_grid_position(self, trace, row, col):
1238-
try:
1239-
grid_ref = self._grid_ref
1240-
except AttributeError:
1241-
raise Exception("In order to reference traces by row and column, "
1242-
"you must first use "
1243-
"plotly.tools.make_subplots "
1244-
"to create the figure with a subplot grid.")
1432+
grid_ref = self._validate_get_grid_ref()
1433+
12451434
from _plotly_future_ import _future_flags
12461435
if 'v4_subplots' in _future_flags:
12471436
return _set_trace_grid_reference(
@@ -1277,6 +1466,18 @@ def _set_trace_grid_position(self, trace, row, col):
12771466
trace['xaxis'] = ref[0]
12781467
trace['yaxis'] = ref[1]
12791468

1469+
def _validate_get_grid_ref(self):
1470+
try:
1471+
grid_ref = self._grid_ref
1472+
if grid_ref is None:
1473+
raise AttributeError('_grid_ref')
1474+
except AttributeError:
1475+
raise Exception("In order to reference traces by row and column, "
1476+
"you must first use "
1477+
"plotly.tools.make_subplots "
1478+
"to create the figure with a subplot grid.")
1479+
return grid_ref
1480+
12801481
def get_subplot(self, row, col):
12811482
"""
12821483
Return an object representing the subplot at the specified row
@@ -2429,8 +2630,16 @@ def _process_kwargs(self, **kwargs):
24292630
"""
24302631
Process any extra kwargs that are not predefined as constructor params
24312632
"""
2432-
if not self._skip_invalid:
2433-
self._raise_on_invalid_property_error(*kwargs.keys())
2633+
invalid_kwargs = {}
2634+
for k, v in kwargs.items():
2635+
if k in self:
2636+
# e.g. underscore kwargs like marker_line_color
2637+
self[k] = v
2638+
else:
2639+
invalid_kwargs[k] = v
2640+
2641+
if invalid_kwargs and not self._skip_invalid:
2642+
self._raise_on_invalid_property_error(*invalid_kwargs.keys())
24342643

24352644
@property
24362645
def plotly_name(self):
@@ -2675,7 +2884,9 @@ def _get_prop_validator(self, prop):
26752884
plotly_obj = self[prop_path[:-1]]
26762885
prop = prop_path[-1]
26772886
else:
2678-
plotly_obj = self
2887+
prop_path = BaseFigure._str_to_dict_path(prop)
2888+
plotly_obj = self[prop_path[:-1]]
2889+
prop = prop_path[-1]
26792890

26802891
# Return validator
26812892
# ----------------

0 commit comments

Comments
 (0)