From 7c73aefd9f69b63e5016732de42c15399cf4b3d2 Mon Sep 17 00:00:00 2001 From: Nicholas Esterer Date: Thu, 15 Oct 2020 17:35:23 -0400 Subject: [PATCH] select_traces selector argument can now be a function --- .../python/plotly/plotly/basedatatypes.py | 17 +++++++-- .../test_update_objects/test_update_traces.py | 37 +++++++++++++++++++ 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py index 7bc0e66f56d..369ea0d9723 100644 --- a/packages/python/plotly/plotly/basedatatypes.py +++ b/packages/python/plotly/plotly/basedatatypes.py @@ -804,10 +804,19 @@ def _perform_select_traces(self, filter_by_subplot, grid_subplot_refs, selector) continue # Filter by selector - if not self._selector_matches(trace, selector): - continue - - yield trace + # If selector is a dict, call self._selector_matches + if type(selector) == type(dict()): + trace_matches = self._selector_matches(trace, selector) + # If selector is a function, call it with the trace as the argument + elif type(selector) == type(lambda x: True): + trace_matches = selector(trace) + else: + raise TypeError( + "selector must be dict or a function " + "accepting a trace returning a boolean." + ) + if trace_matches: + yield trace @staticmethod def _selector_matches(obj, selector): diff --git a/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_traces.py b/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_traces.py index b1c8487ceae..ccfe50382e0 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_traces.py +++ b/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_traces.py @@ -225,6 +225,43 @@ def test_select_property_and_grid(self): # Valid row/col and valid selector but the intersection is empty self.assert_select_traces([], selector={"type": "markers"}, row=3, col=1) + def test_select_with_function(self): + def _check_trace_key(k, v): + def f(t): + try: + return t[k] == v + except LookupError: + return False + + return f + + # (1, 1) + self.assert_select_traces( + [0], selector=_check_trace_key("mode", "markers"), row=1, col=1 + ) + self.assert_select_traces( + [1], selector=_check_trace_key("type", "bar"), row=1, col=1 + ) + + # (2, 1) + self.assert_select_traces( + [2, 9], selector=_check_trace_key("mode", "lines"), row=2, col=1 + ) + + # (1, 2) + self.assert_select_traces( + [4], selector=_check_trace_key("marker.color", "green"), row=1, col=2 + ) + + # Valid row/col and valid selector but the intersection is empty + self.assert_select_traces( + [], selector=_check_trace_key("type", "markers"), row=3, col=1 + ) + + def test_select_traces_type_error(self): + with self.assertRaises(TypeError): + self.assert_select_traces([0], selector=123, row=1, col=1) + def test_for_each_trace_lowercase_names(self): # Names are all uppercase to start original_names = [t.name for t in self.fig.data]