Skip to content

Commit 7c73aef

Browse files
select_traces selector argument can now be a function
1 parent 9eded54 commit 7c73aef

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

Diff for: packages/python/plotly/plotly/basedatatypes.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -804,10 +804,19 @@ def _perform_select_traces(self, filter_by_subplot, grid_subplot_refs, selector)
804804
continue
805805

806806
# Filter by selector
807-
if not self._selector_matches(trace, selector):
808-
continue
809-
810-
yield trace
807+
# If selector is a dict, call self._selector_matches
808+
if type(selector) == type(dict()):
809+
trace_matches = self._selector_matches(trace, selector)
810+
# If selector is a function, call it with the trace as the argument
811+
elif type(selector) == type(lambda x: True):
812+
trace_matches = selector(trace)
813+
else:
814+
raise TypeError(
815+
"selector must be dict or a function "
816+
"accepting a trace returning a boolean."
817+
)
818+
if trace_matches:
819+
yield trace
811820

812821
@staticmethod
813822
def _selector_matches(obj, selector):

Diff for: packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_traces.py

+37
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,43 @@ def test_select_property_and_grid(self):
225225
# Valid row/col and valid selector but the intersection is empty
226226
self.assert_select_traces([], selector={"type": "markers"}, row=3, col=1)
227227

228+
def test_select_with_function(self):
229+
def _check_trace_key(k, v):
230+
def f(t):
231+
try:
232+
return t[k] == v
233+
except LookupError:
234+
return False
235+
236+
return f
237+
238+
# (1, 1)
239+
self.assert_select_traces(
240+
[0], selector=_check_trace_key("mode", "markers"), row=1, col=1
241+
)
242+
self.assert_select_traces(
243+
[1], selector=_check_trace_key("type", "bar"), row=1, col=1
244+
)
245+
246+
# (2, 1)
247+
self.assert_select_traces(
248+
[2, 9], selector=_check_trace_key("mode", "lines"), row=2, col=1
249+
)
250+
251+
# (1, 2)
252+
self.assert_select_traces(
253+
[4], selector=_check_trace_key("marker.color", "green"), row=1, col=2
254+
)
255+
256+
# Valid row/col and valid selector but the intersection is empty
257+
self.assert_select_traces(
258+
[], selector=_check_trace_key("type", "markers"), row=3, col=1
259+
)
260+
261+
def test_select_traces_type_error(self):
262+
with self.assertRaises(TypeError):
263+
self.assert_select_traces([0], selector=123, row=1, col=1)
264+
228265
def test_for_each_trace_lowercase_names(self):
229266
# Names are all uppercase to start
230267
original_names = [t.name for t in self.fig.data]

0 commit comments

Comments
 (0)