diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py index 3ec382d0f3d..c825233aca2 100644 --- a/packages/python/plotly/plotly/express/__init__.py +++ b/packages/python/plotly/plotly/express/__init__.py @@ -39,6 +39,11 @@ choropleth, density_contour, density_heatmap, + pie, + sunburst, + treemap, + funnel, + funnel_area, ) from ._imshow import imshow @@ -77,6 +82,11 @@ "strip", "histogram", "choropleth", + "pie", + "sunburst", + "treemap", + "funnel", + "funnel_area", "imshow", "data", "colors", diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index cbf2fff85cd..8cbd4d85b65 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -1115,3 +1115,208 @@ def parallel_categories( parallel_categories.__doc__ = make_docstring(parallel_categories) + + +def pie( + data_frame=None, + names=None, + values=None, + color=None, + color_discrete_sequence=None, + color_discrete_map={}, + hover_name=None, + hover_data=None, + custom_data=None, + labels={}, + title=None, + template=None, + width=None, + height=None, + opacity=None, + hole=None, +): + """ + In a pie plot, each row of `data_frame` is represented as a sector of a pie. + """ + if color_discrete_sequence is not None: + layout_patch = {"piecolorway": color_discrete_sequence} + else: + layout_patch = {} + return make_figure( + args=locals(), + constructor=go.Pie, + trace_patch=dict(showlegend=(names is not None), hole=hole), + layout_patch=layout_patch, + ) + + +pie.__doc__ = make_docstring( + pie, + override_dict=dict( + hole=[ + "float", + "Sets the fraction of the radius to cut out of the pie." + "Use this to make a donut chart.", + ], + ), +) + + +def sunburst( + data_frame=None, + names=None, + values=None, + parents=None, + ids=None, + color=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + color_discrete_sequence=None, + color_discrete_map={}, + hover_name=None, + hover_data=None, + custom_data=None, + labels={}, + title=None, + template=None, + width=None, + height=None, + branchvalues=None, + maxdepth=None, +): + """ + A sunburst plot represents hierarchial data as sectors laid out over + several levels of concentric rings. + """ + if color_discrete_sequence is not None: + layout_patch = {"sunburstcolorway": color_discrete_sequence} + else: + layout_patch = {} + return make_figure( + args=locals(), + constructor=go.Sunburst, + trace_patch=dict(branchvalues=branchvalues, maxdepth=maxdepth), + layout_patch=layout_patch, + ) + + +sunburst.__doc__ = make_docstring(sunburst) + + +def treemap( + data_frame=None, + names=None, + values=None, + parents=None, + ids=None, + color=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + color_discrete_sequence=None, + color_discrete_map={}, + hover_name=None, + hover_data=None, + custom_data=None, + labels={}, + title=None, + template=None, + width=None, + height=None, + branchvalues=None, + maxdepth=None, +): + """ + A treemap plot represents hierarchial data as nested rectangular sectors. + """ + if color_discrete_sequence is not None: + layout_patch = {"treemapcolorway": color_discrete_sequence} + else: + layout_patch = {} + return make_figure( + args=locals(), + constructor=go.Treemap, + trace_patch=dict(branchvalues=branchvalues, maxdepth=maxdepth), + layout_patch=layout_patch, + ) + + +treemap.__doc__ = make_docstring(treemap) + + +def funnel( + data_frame=None, + x=None, + y=None, + color=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + hover_name=None, + hover_data=None, + custom_data=None, + text=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + opacity=None, + orientation="h", + log_x=False, + log_y=False, + range_x=None, + range_y=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a funnel plot, each row of `data_frame` is represented as a rectangular sector of a funnel. + """ + return make_figure( + args=locals(), + constructor=go.Funnel, + trace_patch=dict(opacity=opacity, orientation=orientation), + ) + + +funnel.__doc__ = make_docstring(funnel) + + +def funnel_area( + data_frame=None, + names=None, + values=None, + color=None, + color_discrete_sequence=None, + color_discrete_map={}, + hover_name=None, + hover_data=None, + custom_data=None, + labels={}, + title=None, + template=None, + width=None, + height=None, + opacity=None, +): + """ + In a funnel area plot, each row of `data_frame` is represented as a trapezoidal sector of a funnel. + """ + if color_discrete_sequence is not None: + layout_patch = {"funnelareacolorway": color_discrete_sequence} + else: + layout_patch = {} + return make_figure( + args=locals(), + constructor=go.Funnelarea, + trace_patch=dict(showlegend=(names is not None)), + layout_patch=layout_patch, + ) + + +funnel_area.__doc__ = make_docstring(funnel_area) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 503edaf7b9c..923e6ea3dfc 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -291,6 +291,28 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): result["z"] = g[v] result["coloraxis"] = "coloraxis1" mapping_labels[v_label] = "%{z}" + elif trace_spec.constructor in [ + go.Sunburst, + go.Treemap, + go.Pie, + go.Funnelarea, + ]: + if "marker" not in result: + result["marker"] = dict() + + if args.get("color_is_continuous"): + result["marker"]["colors"] = g[v] + result["marker"]["coloraxis"] = "coloraxis1" + mapping_labels[v_label] = "%{color}" + else: + result["marker"]["colors"] = [] + mapping = {} + for cat in g[v]: + if mapping.get(cat) is None: + mapping[cat] = args["color_discrete_sequence"][ + len(mapping) % len(args["color_discrete_sequence"]) + ] + result["marker"]["colors"].append(mapping[cat]) else: colorable = "marker" if trace_spec.constructor in [go.Parcats, go.Parcoords]: @@ -305,11 +327,38 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): elif k == "locations": result[k] = g[v] mapping_labels[v_label] = "%{location}" + elif k == "values": + result[k] = g[v] + _label = "value" if v_label == "values" else v_label + mapping_labels[_label] = "%{value}" + elif k == "parents": + result[k] = g[v] + _label = "parent" if v_label == "parents" else v_label + mapping_labels[_label] = "%{parent}" + elif k == "ids": + result[k] = g[v] + _label = "id" if v_label == "ids" else v_label + mapping_labels[_label] = "%{id}" + elif k == "names": + if trace_spec.constructor in [ + go.Sunburst, + go.Treemap, + go.Pie, + go.Funnelarea, + ]: + result["labels"] = g[v] + _label = "label" if v_label == "names" else v_label + mapping_labels[_label] = "%{label}" + else: + result[k] = g[v] else: if v: result[k] = g[v] mapping_labels[v_label] = "%%{%s}" % k - if trace_spec.constructor not in [go.Parcoords, go.Parcats]: + if trace_spec.constructor not in [ + go.Parcoords, + go.Parcats, + ]: hover_lines = [k + "=" + v for k, v in mapping_labels.items()] result["hovertemplate"] = hover_header + "
".join(hover_lines) return result, fit_results @@ -674,6 +723,7 @@ def one_group(x): def apply_default_cascade(args): # first we apply px.defaults to unspecified args + for param in ( ["color_discrete_sequence", "color_continuous_scale"] + ["symbol_sequence", "line_dash_sequence", "template"] @@ -956,6 +1006,7 @@ def infer_config(args, constructor, trace_patch): attrables = ( ["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"] + ["custom_data", "hover_name", "hover_data", "text"] + + ["names", "values", "parents", "ids"] + ["error_x", "error_x_minus"] + ["error_y", "error_y_minus", "error_z", "error_z_minus"] + ["lat", "lon", "locations", "animation_group"] @@ -989,14 +1040,34 @@ def infer_config(args, constructor, trace_patch): and args["data_frame"][args["color"]].dtype.kind in "bifc" ): attrs.append("color") + args["color_is_continuous"] = True + elif constructor in [go.Sunburst, go.Treemap]: + attrs.append("color") + args["color_is_continuous"] = False else: grouped_attrs.append("marker.color") elif "line_group" in args or constructor == go.Histogram2dContour: grouped_attrs.append("line.color") + elif constructor in [go.Pie, go.Funnelarea]: + attrs.append("color") + if args["color"]: + if args["hover_data"] is None: + args["hover_data"] = [] + args["hover_data"].append(args["color"]) else: grouped_attrs.append("marker.color") - show_colorbar = bool("color" in attrs and args["color"]) + show_colorbar = bool( + "color" in attrs + and args["color"] + and constructor not in [go.Pie, go.Funnelarea] + and ( + constructor not in [go.Treemap, go.Sunburst] + or args.get("color_is_continuous") + ) + ) + else: + show_colorbar = False # Compute line_dash grouping attribute if "line_dash" in args: @@ -1148,6 +1219,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): go.Parcoords, go.Choropleth, go.Histogram2d, + go.Sunburst, + go.Treemap, ]: trace.update( legendgroup=trace_name, diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 950a3953be7..8948e6b321d 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -66,6 +66,21 @@ colref_desc, "Values from this column or array_like are used to position marks along the angular axis in polar coordinates.", ], + values=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to set values associated to sectors.", + ], + parents=[ + colref_type, + colref_desc, + "Values from this column or array_like are used as parents in sunburst and treemap charts.", + ], + ids=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to set ids of sectors", + ], lat=[ colref_type, colref_desc, @@ -168,6 +183,11 @@ colref_desc, "Values from this column or array_like appear in the figure as text labels.", ], + names=[ + colref_type, + colref_desc, + "Values from this column or array_like are used as labels for sectors.", + ], locationmode=[ "str", "One of 'ISO-3', 'USA-states', or 'country names'", @@ -442,21 +462,41 @@ nbins=["int", "Positive integer.", "Sets the number of bins."], nbinsx=["int", "Positive integer.", "Sets the number of bins along the x axis."], nbinsy=["int", "Positive integer.", "Sets the number of bins along the y axis."], + branchvalues=[ + "str", + "'total' or 'remainder'", + "Determines how the items in `values` are summed. When" + "set to 'total', items in `values` are taken to be value" + "of all its descendants. When set to 'remainder', items" + "in `values` corresponding to the root and the branches" + ":sectors are taken to be the extra part not part of the" + "sum of the values at their leaves.", + ], + maxdepth=[ + "int", + "Positive integer", + "Sets the number of rendered sectors from any given `level`. Set `maxdepth` to -1 to render all the" + "levels in the hierarchy.", + ], ) -def make_docstring(fn): +def make_docstring(fn, override_dict={}): tw = TextWrapper(width=77, initial_indent=" ", subsequent_indent=" ") result = (fn.__doc__ or "") + "\nParameters\n----------\n" for param in inspect.getargspec(fn)[0]: - param_desc_list = docs[param][1:] + if override_dict.get(param): + param_doc = override_dict[param] + else: + param_doc = docs[param] + param_desc_list = param_doc[1:] param_desc = ( tw.fill(" ".join(param_desc_list or "")) if param in docs else "(documentation missing from map)" ) - param_type = docs[param][0] + param_type = param_doc[0] result += "%s: %s\n%s\n" % (param, param_type, param_desc) result += "\nReturns\n-------\n" result += " A `Figure` object." diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py new file mode 100644 index 00000000000..339accf9d57 --- /dev/null +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -0,0 +1,141 @@ +import plotly.express as px +import plotly.graph_objects as go +from numpy.testing import assert_array_equal +import numpy as np + + +def _compare_figures(go_trace, px_fig): + """Compare a figure created with a go trace and a figure created with + a px function call. Check that all values inside the go Figure are the + same in the px figure (which sets more parameters). + """ + go_fig = go.Figure(go_trace) + go_fig = go_fig.to_plotly_json() + px_fig = px_fig.to_plotly_json() + del go_fig["layout"]["template"] + del px_fig["layout"]["template"] + for key in go_fig["data"][0]: + assert_array_equal(go_fig["data"][0][key], px_fig["data"][0][key]) + for key in go_fig["layout"]: + assert go_fig["layout"][key] == px_fig["layout"][key] + + +def test_pie_like_px(): + # Pie + labels = ["Oxygen", "Hydrogen", "Carbon_Dioxide", "Nitrogen"] + values = [4500, 2500, 1053, 500] + + fig = px.pie(names=labels, values=values) + trace = go.Pie(labels=labels, values=values) + _compare_figures(trace, fig) + + labels = ["Eve", "Cain", "Seth", "Enos", "Noam", "Abel", "Awan", "Enoch", "Azura"] + parents = ["", "Eve", "Eve", "Seth", "Seth", "Eve", "Eve", "Awan", "Eve"] + values = [10, 14, 12, 10, 2, 6, 6, 4, 4] + # Sunburst + fig = px.sunburst(names=labels, parents=parents, values=values) + trace = go.Sunburst(labels=labels, parents=parents, values=values) + _compare_figures(trace, fig) + # Treemap + fig = px.treemap(names=labels, parents=parents, values=values) + trace = go.Treemap(labels=labels, parents=parents, values=values) + _compare_figures(trace, fig) + + # Funnel + x = ["A", "B", "C"] + y = [3, 2, 1] + fig = px.funnel(y=y, x=x) + trace = go.Funnel(y=y, x=x) + _compare_figures(trace, fig) + # Funnelarea + fig = px.funnel_area(values=y, names=x) + trace = go.Funnelarea(values=y, labels=x) + _compare_figures(trace, fig) + + +def test_sunburst_treemap_colorscales(): + labels = ["Eve", "Cain", "Seth", "Enos", "Noam", "Abel", "Awan", "Enoch", "Azura"] + parents = ["", "Eve", "Eve", "Seth", "Seth", "Eve", "Eve", "Awan", "Eve"] + values = [10, 14, 12, 10, 2, 6, 6, 4, 4] + for func, colorway in zip( + [px.sunburst, px.treemap], ["sunburstcolorway", "treemapcolorway"] + ): + # Continuous colorscale + fig = func( + names=labels, + parents=parents, + values=values, + color=values, + color_continuous_scale="Viridis", + range_color=(5, 15), + ) + assert fig.layout.coloraxis.cmin, fig.layout.coloraxis.cmax == (5, 15) + # Discrete colorscale, color arg passed + color_seq = px.colors.sequential.Reds + fig = func( + names=labels, + parents=parents, + values=values, + color=labels, + color_discrete_sequence=color_seq, + ) + assert np.all([col in color_seq for col in fig.data[0].marker.colors]) + # Numerical color arg passed, fall back to continuous + fig = func(names=labels, parents=parents, values=values, color=values,) + assert [ + el[0] == px.colors.sequential.Viridis + for i, el in enumerate(fig.layout.coloraxis.colorscale) + ] + # Numerical color arg passed, continuous colorscale + # even if color_discrete_sequence if passed + fig = func( + names=labels, + parents=parents, + values=values, + color=values, + color_discrete_sequence=color_seq, + ) + assert [ + el[0] == px.colors.sequential.Viridis + for i, el in enumerate(fig.layout.coloraxis.colorscale) + ] + + # Discrete colorscale, no color arg passed + color_seq = px.colors.sequential.Reds + fig = func( + names=labels, + parents=parents, + values=values, + color_discrete_sequence=color_seq, + ) + assert list(fig.layout[colorway]) == color_seq + + +def test_pie_funnelarea_colorscale(): + labels = ["A", "B", "C", "D"] + values = [3, 2, 1, 4] + for func, colorway in zip( + [px.sunburst, px.treemap], ["sunburstcolorway", "treemapcolorway"] + ): + # Discrete colorscale, no color arg passed + color_seq = px.colors.sequential.Reds + fig = func(names=labels, values=values, color_discrete_sequence=color_seq,) + assert list(fig.layout[colorway]) == color_seq + # Discrete colorscale, color arg passed + color_seq = px.colors.sequential.Reds + fig = func( + names=labels, + values=values, + color=labels, + color_discrete_sequence=color_seq, + ) + assert np.all([col in color_seq for col in fig.data[0].marker.colors]) + + +def test_funnel(): + fig = px.funnel( + x=[5, 4, 3, 3, 2, 1], + y=["A", "B", "C", "A", "B", "C"], + color=["0", "0", "0", "1", "1", "1"], + ) + assert len(fig.data) == 2