diff --git a/CHANGELOG.md b/CHANGELOG.md index d135720aa1e..4753c8df0ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,12 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Added +- For `add_trace`, `add_shape`, `add_annotation` and `add_layout_image`, the `row` and/or `col` argument now also accept the string `"all"`. `row="all"` adds the object to all the subplot rows and `col="all"` adds the object to all the subplot columns. + +- Shapes that reference the plot axes in one dimension and the data in another dimension can be added with the new `add_hline`, `add_vline`, `add_hrect`, `add_vrect` functions, which also support the `row="all"` and `col="all"` arguments. + +- The `add_trace`, `add_shape`, `add_annotation`, `add_layout_image`, `add_hline`, `add_vline`, `add_hrect`, `add_vrect` functions accept an argument `exclude_empty_subplots` which if `True`, only adds the object to subplots already containing traces or layout objects. This is useful in conjunction with the `row="all"` and `col="all"` arguments. + - For all `go.Figure` functions accepting a selector argument (e.g., `select_traces`), this argument can now also be a function which is passed each relevant graph object (in the case of `select_traces`, it is passed every trace in the figure). For graph objects where this function returns true, the graph object is included in the selection. ### Updated diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py index 8cb0b07a05d..f2033cf359a 100644 --- a/packages/python/plotly/plotly/basedatatypes.py +++ b/packages/python/plotly/plotly/basedatatypes.py @@ -1302,7 +1302,9 @@ def _add_annotation_like( # if exclude_empty_subplots is True, check to see if subplot is # empty and return if it is if exclude_empty_subplots and ( - not self._subplot_contains_trace(xref, yref) + not self._subplot_not_empty( + xref, yref, selector=bool(exclude_empty_subplots) + ) ): return self # in case the user specified they wanted an axis to refer to the @@ -1993,8 +1995,8 @@ def add_traces( if exclude_empty_subplots: data = list( filter( - lambda trace: self._subplot_contains_trace( - trace["xaxis"], trace["yaxis"] + lambda trace: self._subplot_not_empty( + trace["xaxis"], trace["yaxis"], bool(exclude_empty_subplots) ), data, ) @@ -3873,19 +3875,56 @@ def _has_subplots(self): single plot and so this returns False. """ return self._grid_ref is not None - def _subplot_contains_trace(self, xref, yref): - return any( - t == (xref, yref) - for t in [ - # if a trace exists but has no xaxis or yaxis keys, then it - # is plotted with xaxis 'x' and yaxis 'y' - ( - "x" if d["xaxis"] is None else d["xaxis"], - "y" if d["yaxis"] is None else d["yaxis"], + def _subplot_not_empty(self, xref, yref, selector="all"): + """ + xref: string representing the axis. Objects in the plot will be checked + for this xref (for layout objects) or xaxis (for traces) to + determine if they lie in a certain subplot. + yref: string representing the axis. Objects in the plot will be checked + for this yref (for layout objects) or yaxis (for traces) to + determine if they lie in a certain subplot. + selector: can be "all" or an iterable containing some combination of + "traces", "shapes", "annotations", "images". Only the presence + of objects specified in selector will be checked. So if + ["traces","shapes"] is passed then a plot we be considered + non-empty if it contains traces or shapes. If + bool(selector) returns False, no checking is performed and + this function returns True. If selector is True, it is + converted to "all". + """ + if not selector: + # If nothing to select was specified then a subplot is always deemed non-empty + return True + if selector == True: + selector = "all" + if selector == "all": + selector = ["traces", "shapes", "annotations", "images"] + ret = False + for s in selector: + if s == "traces": + obj = self.data + xaxiskw = "xaxis" + yaxiskw = "yaxis" + elif s in ["shapes", "annotations", "images"]: + obj = self.layout[s] + xaxiskw = "xref" + yaxiskw = "yref" + else: + obj = None + if obj: + ret |= any( + t == (xref, yref) + for t in [ + # if a object exists but has no xaxis or yaxis keys, then it + # is plotted with xaxis/xref 'x' and yaxis/yref 'y' + ( + "x" if d[xaxiskw] is None else d[xaxiskw], + "y" if d[yaxiskw] is None else d[yaxiskw], + ) + for d in obj + ] ) - for d in self.data - ] - ) + return ret class BasePlotlyType(object): diff --git a/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py b/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py index 5563a419732..63f379bbcd8 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py +++ b/packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py @@ -120,7 +120,43 @@ def test_add_trace_no_exclude_empty_subplots(): fig.add_trace(go.Scatter(x=[1, 2, 3], y=[5, 1, 2]), row=1, col=1) fig.add_trace(go.Scatter(x=[1, 2, 3], y=[2, 1, -7]), row=2, col=2) # Add traces with exclude_empty_subplots set to true and make sure this - # doesn't add to traces that don't already have data + # even adds to traces that don't already have data + fig.add_trace(go.Scatter(x=[1, 2, 3], y=[0, 1, -1]), row="all", col="all") + assert len(fig.data) == 6 + assert fig.data[2]["xaxis"] == "x" and fig.data[2]["yaxis"] == "y" + assert fig.data[3]["xaxis"] == "x2" and fig.data[3]["yaxis"] == "y2" + assert fig.data[4]["xaxis"] == "x3" and fig.data[4]["yaxis"] == "y3" + assert fig.data[5]["xaxis"] == "x4" and fig.data[5]["yaxis"] == "y4" + + +def test_add_trace_exclude_totally_empty_subplots(): + # Add traces + fig = make_subplots(2, 2) + fig.add_trace(go.Scatter(x=[1, 2, 3], y=[5, 1, 2]), row=1, col=1) + fig.add_trace(go.Scatter(x=[1, 2, 3], y=[2, 1, -7]), row=2, col=2) + fig.add_shape(dict(type="rect", x0=0, x1=1, y0=0, y1=1), row=1, col=2) + # Add traces with exclude_empty_subplots set to true and make sure this + # doesn't add to traces that don't already have data or layout objects + fig.add_trace( + go.Scatter(x=[1, 2, 3], y=[0, 1, -1]), + row="all", + col="all", + exclude_empty_subplots=["anything", "truthy"], + ) + assert len(fig.data) == 5 + assert fig.data[2]["xaxis"] == "x" and fig.data[2]["yaxis"] == "y" + assert fig.data[3]["xaxis"] == "x2" and fig.data[3]["yaxis"] == "y2" + assert fig.data[4]["xaxis"] == "x4" and fig.data[4]["yaxis"] == "y4" + + +def test_add_trace_no_exclude_totally_empty_subplots(): + # Add traces + fig = make_subplots(2, 2) + fig.add_trace(go.Scatter(x=[1, 2, 3], y=[5, 1, 2]), row=1, col=1) + fig.add_trace(go.Scatter(x=[1, 2, 3], y=[2, 1, -7]), row=2, col=2) + fig.add_shape(dict(type="rect", x0=0, x1=1, y0=0, y1=1), row=1, col=2) + # Add traces with exclude_empty_subplots set to true and make sure this + # even adds to traces that don't already have data or layout objects fig.add_trace(go.Scatter(x=[1, 2, 3], y=[0, 1, -1]), row="all", col="all") assert len(fig.data) == 6 assert fig.data[2]["xaxis"] == "x" and fig.data[2]["yaxis"] == "y" diff --git a/packages/python/plotly/plotly/tests/test_core/test_subplots/test_find_nonempty_subplots.py b/packages/python/plotly/plotly/tests/test_core/test_subplots/test_find_nonempty_subplots.py new file mode 100644 index 00000000000..27a66b4feea --- /dev/null +++ b/packages/python/plotly/plotly/tests/test_core/test_subplots/test_find_nonempty_subplots.py @@ -0,0 +1,61 @@ +import pytest +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from itertools import combinations, product +from functools import reduce + + +def all_combos(it): + return list( + reduce( + lambda a, b: a + b, + [list(combinations(it, r)) for r in range(1, len(it))], + [], + ) + ) + + +def translate_layout_keys(t): + xr, yr = t + xr = xr.replace("axis", "") + yr = yr.replace("axis", "") + return (xr, yr) + + +def get_non_empty_subplots(fig, selector): + gr = fig._validate_get_grid_ref() + nrows = len(gr) + ncols = len(gr[0]) + sp_addresses = product(range(nrows), range(ncols)) + # assign a number similar to plotly's xref/yref (e.g, xref=x2) to each + # subplot address (xref=x -> 1, but xref=x3 -> 3) + # sp_ax_numbers=range(1,len(sp_addresses)+1) + # Get those subplot numbers which contain something + ret = list( + filter( + lambda sp: fig._subplot_not_empty( + *translate_layout_keys(sp.layout_keys), selector=selector + ), + [gr[r][c][0] for r, c in sp_addresses], + ) + ) + return ret + + +def test_choose_correct_non_empty_subplots(): + # This checks to see that the correct subplots are selected for all + # combinations of selectors + fig = make_subplots(2, 2) + fig.add_trace(go.Scatter(x=[1, 2], y=[3, 4]), row=1, col=1) + fig.add_shape(dict(type="rect", x0=1, x1=2, y0=3, y1=4), row=1, col=2) + fig.add_annotation(dict(text="A", x=1, y=2), row=2, col=1) + fig.add_layout_image( + dict(source="test", x=1, y=2, sizex=0.5, sizey=0.5), row=2, col=2 + ) + all_subplots = get_non_empty_subplots(fig, "all") + selectors = all_combos(["traces", "shapes", "annotations", "images"]) + subplot_combos = all_combos(all_subplots) + assert len(selectors) == len(subplot_combos) + for s, spc in zip(selectors, subplot_combos): + sps = tuple(get_non_empty_subplots(fig, s)) + assert sps == spc diff --git a/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_annotations.py b/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_annotations.py index c08129bb31b..e7ac974d68a 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_annotations.py +++ b/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_annotations.py @@ -270,51 +270,78 @@ def test_image_attributes(self): def test_exclude_empty_subplots(): - for k, fun, d in [ + for k, fun, d, fun2, d2 in [ ( "shapes", go.Figure.add_shape, dict(type="rect", x0=1.5, x1=2.5, y0=3.5, y1=4.5), + # add a different type to make the check easier (otherwise we might + # mix up the objects added before and after fun was run) + go.Figure.add_annotation, + dict(x=1, y=2, text="A"), + ), + ( + "annotations", + go.Figure.add_annotation, + dict(x=1, y=2, text="A"), + go.Figure.add_layout_image, + dict(x=3, y=4, sizex=2, sizey=3, source="test"), ), - ("annotations", go.Figure.add_annotation, dict(x=1, y=2, text="A")), ( "images", go.Figure.add_layout_image, dict(x=3, y=4, sizex=2, sizey=3, source="test"), + go.Figure.add_shape, + dict(type="rect", x0=1.5, x1=2.5, y0=3.5, y1=4.5), ), ]: # make a figure where not all the subplots are populated fig = make_subplots(2, 2) fig.add_trace(go.Scatter(x=[1, 2, 3], y=[5, 1, 2]), row=1, col=1) fig.add_trace(go.Scatter(x=[1, 2, 3], y=[2, 1, -7]), row=2, col=2) + fun2(fig, d2, row=1, col=2) # add a thing to all subplots but make sure it only goes on the - # plots without data - fun(fig, d, row="all", col="all", exclude_empty_subplots=True) - assert len(fig.layout[k]) == 2 + # plots without data or layout objects + fun(fig, d, row="all", col="all", exclude_empty_subplots="anything_truthy") + assert len(fig.layout[k]) == 3 assert fig.layout[k][0]["xref"] == "x" and fig.layout[k][0]["yref"] == "y" - assert fig.layout[k][1]["xref"] == "x4" and fig.layout[k][1]["yref"] == "y4" + assert fig.layout[k][1]["xref"] == "x2" and fig.layout[k][1]["yref"] == "y2" + assert fig.layout[k][2]["xref"] == "x4" and fig.layout[k][2]["yref"] == "y4" def test_no_exclude_empty_subplots(): - for k, fun, d in [ + for k, fun, d, fun2, d2 in [ ( "shapes", go.Figure.add_shape, dict(type="rect", x0=1.5, x1=2.5, y0=3.5, y1=4.5), + # add a different type to make the check easier (otherwise we might + # mix up the objects added before and after fun was run) + go.Figure.add_annotation, + dict(x=1, y=2, text="A"), + ), + ( + "annotations", + go.Figure.add_annotation, + dict(x=1, y=2, text="A"), + go.Figure.add_layout_image, + dict(x=3, y=4, sizex=2, sizey=3, source="test"), ), - ("annotations", go.Figure.add_annotation, dict(x=1, y=2, text="A")), ( "images", go.Figure.add_layout_image, dict(x=3, y=4, sizex=2, sizey=3, source="test"), + go.Figure.add_shape, + dict(type="rect", x0=1.5, x1=2.5, y0=3.5, y1=4.5), ), ]: # make a figure where not all the subplots are populated fig = make_subplots(2, 2) fig.add_trace(go.Scatter(x=[1, 2, 3], y=[5, 1, 2]), row=1, col=1) fig.add_trace(go.Scatter(x=[1, 2, 3], y=[2, 1, -7]), row=2, col=2) - # add a thing to all subplots and make sure it even goes on the - # plots without data + fun2(fig, d2, row=1, col=2) + # add a thing to all subplots but make sure it only goes on the + # plots without data or layout objects fun(fig, d, row="all", col="all", exclude_empty_subplots=False) assert len(fig.layout[k]) == 4 assert fig.layout[k][0]["xref"] == "x" and fig.layout[k][0]["yref"] == "y"