From 4ac1efb7d07fbd0e13bdcc9d815b002dbe0ab38c Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 13 Dec 2019 14:14:41 -0500 Subject: [PATCH 01/33] proof of concept --- .../plotly/plotly/express/_chart_types.py | 26 +++++++++++++++++++ packages/python/plotly/plotly/express/_doc.py | 4 +++ 2 files changed, 30 insertions(+) diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 9be0a02c035..1ff0fb559a4 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -1,5 +1,6 @@ from ._core import make_figure from ._doc import make_docstring +from .preprocess import preprocess_sunburst_treemap import plotly.graph_objs as go @@ -1252,6 +1253,7 @@ def sunburst( names=None, values=None, parents=None, + path=None, ids=None, color=None, color_continuous_scale=None, @@ -1278,6 +1280,17 @@ def sunburst( layout_patch = {"sunburstcolorway": color_discrete_sequence} else: layout_patch = {} + if path is not None and (ids is not None or parents is not None): + raise ValueError( + "Either `path` should be provided, or `ids` and `parents`." + "These parameters are mutually exclusive and cannot be passed together." + ) + if path is not None: + data_frame = preprocess_sunburst_treemap(data_frame, path, values) + path = None + ids = 'labels' + names = 'labels' + parents = 'parent' return make_figure( args=locals(), constructor=go.Sunburst, @@ -1295,6 +1308,7 @@ def treemap( values=None, parents=None, ids=None, + path=None, color=None, color_continuous_scale=None, range_color=None, @@ -1320,6 +1334,18 @@ def treemap( layout_patch = {"treemapcolorway": color_discrete_sequence} else: layout_patch = {} + if path is not None and (ids is not None or parents is not None): + raise ValueError( + "Either `path` should be provided, or `ids` and `parents`." + "These parameters are mutually exclusive and cannot be passed together." 
+ ) + if path is not None: + data_frame = preprocess_sunburst_treemap(data_frame, path, values) + path = None + ids = 'labels' + names = 'labels' + parents = 'parent' + return make_figure( args=locals(), constructor=go.Treemap, diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 3ee3df4e5a7..a5f0232bdad 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -82,6 +82,10 @@ colref_desc, "Values from this column or array_like are used to set ids of sectors", ], + path=[ + colref_type, + colref_desc + ], lat=[ colref_type, colref_desc, From c619e2858c74216644ecdb3c36d983635c04b7a8 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 16 Dec 2019 14:12:24 -0500 Subject: [PATCH 02/33] first version --- .../plotly/plotly/express/_chart_types.py | 16 ++++-- .../python/plotly/plotly/express/_core.py | 57 ++++++++++++++++++- .../plotly/plotly/express/preprocess.py | 35 ++++++++++++ 3 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 packages/python/plotly/plotly/express/preprocess.py diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 1ff0fb559a4..841927d643a 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -1285,12 +1285,16 @@ def sunburst( "Either `path` should be provided, or `ids` and `parents`." "These parameters are mutually exclusive and cannot be passed together." 
) + """ if path is not None: - data_frame = preprocess_sunburst_treemap(data_frame, path, values) + data_frame = preprocess_sunburst_treemap(data_frame, path, values, + #color, + other_columns=hover_data) path = None - ids = 'labels' + ids = 'id' names = 'labels' parents = 'parent' + """ return make_figure( args=locals(), constructor=go.Sunburst, @@ -1339,12 +1343,16 @@ def treemap( "Either `path` should be provided, or `ids` and `parents`." "These parameters are mutually exclusive and cannot be passed together." ) + """ if path is not None: - data_frame = preprocess_sunburst_treemap(data_frame, path, values) + data_frame = preprocess_sunburst_treemap(data_frame, path, values, + #color, + other_columns=hover_data) path = None - ids = 'labels' + ids = 'id' names = 'labels' parents = 'parent' + """ return make_figure( args=locals(), diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 3911641f06e..cb74e2fc0e1 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -890,6 +890,8 @@ def build_dataframe(args, attrables, array_attrables): else: df_output[df_input.columns] = df_input[df_input.columns] + if 'path' in args and args['path'] is not None: + df_output[args['path']] = df_input[args['path']] # Loop over possible arguments for field_name in attrables: # Massaging variables @@ -1007,6 +1009,57 @@ def build_dataframe(args, attrables, array_attrables): return args +def process_dataframe_hierarchy(args): + """ + Build dataframe for sunburst or treemap when the path argument is provided. + """ + df = args['data_frame'] + path = args['path'] + # Other columns (for color, hover_data, custom_data etc.) 
+ cols = list(set(df.columns).difference(path)) + df_all_trees = pd.DataFrame(columns=['labels', 'parent', 'id'] + cols) + for col in cols: + df_all_trees[col] = df_all_trees[col].astype(df[col].dtype) + for i, level in enumerate(path): + df_tree = pd.DataFrame(columns=df_all_trees.columns) + dfg = df.groupby(path[i:]).sum(numerical_only=True) + dfg = dfg.reset_index() + df_tree['labels'] = dfg[level].copy().astype(str) + df_tree['parent'] = '' + df_tree['id'] = dfg[level].copy().astype(str) + if i < len(path) - 1: + j = i + 1 + while j < len(path): + df_tree['parent'] += dfg[path[j]].copy().astype(str) + df_tree['id'] += dfg[path[j]].copy().astype(str) + j += 1 + else: + df_tree['parent'] = 'total' + + if i == 0 and cols: + df_tree[cols] = dfg[cols] + elif cols: + for col in cols: + df_tree[col] = np.nan + df_tree[args['values']] = dfg[args['values']] + df_all_trees = df_all_trees.append(df_tree, ignore_index=True) + total_dict = {'labels': 'total', 'id': 'total', 'parent': '', + args['values']:df[args['values']].sum(), + } + for col in cols: + if not col == args['values']: + total_dict[col] = np.nan + total = pd.Series(total_dict) + df_all_trees = df_all_trees.append(total, ignore_index=True) + args['data_frame'] = df_all_trees + args['path'] = None + args['ids'] = 'id' + args['names'] = 'labels' + args['parents'] = 'parent' + return args + + + def infer_config(args, constructor, trace_patch): # Declare all supported attributes, across all plot types attrables = ( @@ -1017,7 +1070,7 @@ def infer_config(args, constructor, trace_patch): + ["error_y", "error_y_minus", "error_z", "error_z_minus"] + ["lat", "lon", "locations", "animation_group"] ) - array_attrables = ["dimensions", "custom_data", "hover_data"] + array_attrables = ["dimensions", "custom_data", "hover_data", "path"] group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] all_attrables = attrables + group_attrables + ["color"] group_attrs = ["symbol", "line_dash"] @@ -1026,6 +1079,8 
@@ def infer_config(args, constructor, trace_patch): all_attrables += [group_attr] args = build_dataframe(args, all_attrables, array_attrables) + if constructor in [go.Treemap, go.Sunburst] and args['path'] is not None: + args = process_dataframe_hierarchy(args) attrs = [k for k in attrables if k in args] grouped_attrs = [] diff --git a/packages/python/plotly/plotly/express/preprocess.py b/packages/python/plotly/plotly/express/preprocess.py new file mode 100644 index 00000000000..fbd07116763 --- /dev/null +++ b/packages/python/plotly/plotly/express/preprocess.py @@ -0,0 +1,35 @@ +import pandas as pd + +def preprocess_sunburst_treemap(df, path, value_column, other_columns=None): + df_all_trees = pd.DataFrame(columns=['labels', 'parent', 'id']) + if isinstance(path, list): + for i, level in enumerate(path): + df_tree = pd.DataFrame(columns=['labels', 'parent', 'id']) + dfg = df.groupby(path[i:]).sum(numerical_only=True) + dfg = dfg.reset_index() + df_tree['labels'] = dfg[level].copy().astype(str) + df_tree['parent'] = '' + df_tree['id'] = dfg[level].copy().astype(str) + if i < len(path) - 1: + j = i + 1 + while j < len(path): + df_tree['parent'] += dfg[path[j]].copy().astype(str) + df_tree['id'] += dfg[path[j]].copy().astype(str) + j += 1 + else: + df_tree['parent'] = 'total' + + if i == 0 and other_columns: + df_tree[other_columns] = dfg[other_columns] + elif other_columns: + for col in other_columns: + df_tree[col] = '' + df_tree[value_column] = dfg[value_column] + #df_tree[color_column] = dfg[color_column] + df_all_trees = df_all_trees.append(df_tree, ignore_index=True) + total = pd.Series({'labels': 'total', 'id': 'total', 'parent': '', + value_column:df[value_column].sum(), + #color_column:df[color_column].sum(), + }) + df_all_trees = df_all_trees.append(total, ignore_index=True) + return df_all_trees From 10668b6774d29a46f17f947368043d3cec808f25 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 16 Dec 2019 15:34:16 -0500 Subject: [PATCH 03/33] 
tests --- .../plotly/plotly/express/_chart_types.py | 23 ++---------- .../python/plotly/plotly/express/_core.py | 16 ++++++--- .../plotly/plotly/express/preprocess.py | 35 ------------------- .../test_core/test_px/test_px_functions.py | 25 +++++++++++++ 4 files changed, 39 insertions(+), 60 deletions(-) delete mode 100644 packages/python/plotly/plotly/express/preprocess.py diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 841927d643a..311e05d4c16 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -1,6 +1,5 @@ from ._core import make_figure from ._doc import make_docstring -from .preprocess import preprocess_sunburst_treemap import plotly.graph_objs as go @@ -1285,16 +1284,8 @@ def sunburst( "Either `path` should be provided, or `ids` and `parents`." "These parameters are mutually exclusive and cannot be passed together." ) - """ - if path is not None: - data_frame = preprocess_sunburst_treemap(data_frame, path, values, - #color, - other_columns=hover_data) - path = None - ids = 'id' - names = 'labels' - parents = 'parent' - """ + if path is not None and branchvalues is None: + branchvalues='total' return make_figure( args=locals(), constructor=go.Sunburst, @@ -1343,16 +1334,6 @@ def treemap( "Either `path` should be provided, or `ids` and `parents`." "These parameters are mutually exclusive and cannot be passed together." 
) - """ - if path is not None: - data_frame = preprocess_sunburst_treemap(data_frame, path, values, - #color, - other_columns=hover_data) - path = None - ids = 'id' - names = 'labels' - parents = 'parent' - """ return make_figure( args=locals(), diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index cb74e2fc0e1..b7b71886c23 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -890,6 +890,7 @@ def build_dataframe(args, attrables, array_attrables): else: df_output[df_input.columns] = df_input[df_input.columns] + # This should be improved + tested - HACK if 'path' in args and args['path'] is not None: df_output[args['path']] = df_input[args['path']] # Loop over possible arguments @@ -1018,6 +1019,7 @@ def process_dataframe_hierarchy(args): # Other columns (for color, hover_data, custom_data etc.) cols = list(set(df.columns).difference(path)) df_all_trees = pd.DataFrame(columns=['labels', 'parent', 'id'] + cols) + # Set column type here (useful for continuous vs discrete colorscale) for col in cols: df_all_trees[col] = df_all_trees[col].astype(df[col].dtype) for i, level in enumerate(path): @@ -1040,17 +1042,23 @@ def process_dataframe_hierarchy(args): df_tree[cols] = dfg[cols] elif cols: for col in cols: - df_tree[col] = np.nan - df_tree[args['values']] = dfg[args['values']] + df_tree[col] = 'n/a' + if args['values']: + df_tree[args['values']] = dfg[args['values']] df_all_trees = df_all_trees.append(df_tree, ignore_index=True) + + # Root node total_dict = {'labels': 'total', 'id': 'total', 'parent': '', - args['values']:df[args['values']].sum(), } for col in cols: if not col == args['values']: - total_dict[col] = np.nan + total_dict[col] = 'n/a' + if col == args['values']: + total_dict[col] = df[col].sum() total = pd.Series(total_dict) + df_all_trees = df_all_trees.append(total, ignore_index=True) + # Now modify arguments args['data_frame'] = 
df_all_trees args['path'] = None args['ids'] = 'id' diff --git a/packages/python/plotly/plotly/express/preprocess.py b/packages/python/plotly/plotly/express/preprocess.py deleted file mode 100644 index fbd07116763..00000000000 --- a/packages/python/plotly/plotly/express/preprocess.py +++ /dev/null @@ -1,35 +0,0 @@ -import pandas as pd - -def preprocess_sunburst_treemap(df, path, value_column, other_columns=None): - df_all_trees = pd.DataFrame(columns=['labels', 'parent', 'id']) - if isinstance(path, list): - for i, level in enumerate(path): - df_tree = pd.DataFrame(columns=['labels', 'parent', 'id']) - dfg = df.groupby(path[i:]).sum(numerical_only=True) - dfg = dfg.reset_index() - df_tree['labels'] = dfg[level].copy().astype(str) - df_tree['parent'] = '' - df_tree['id'] = dfg[level].copy().astype(str) - if i < len(path) - 1: - j = i + 1 - while j < len(path): - df_tree['parent'] += dfg[path[j]].copy().astype(str) - df_tree['id'] += dfg[path[j]].copy().astype(str) - j += 1 - else: - df_tree['parent'] = 'total' - - if i == 0 and other_columns: - df_tree[other_columns] = dfg[other_columns] - elif other_columns: - for col in other_columns: - df_tree[col] = '' - df_tree[value_column] = dfg[value_column] - #df_tree[color_column] = dfg[color_column] - df_all_trees = df_all_trees.append(df_tree, ignore_index=True) - total = pd.Series({'labels': 'total', 'id': 'total', 'parent': '', - value_column:df[value_column].sum(), - #color_column:df[color_column].sum(), - }) - df_all_trees = df_all_trees.append(total, ignore_index=True) - return df_all_trees diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 339accf9d57..0a3efbb9972 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -2,6 +2,7 @@ import plotly.graph_objects as go from numpy.testing 
import assert_array_equal import numpy as np +import pandas as pd def _compare_figures(go_trace, px_fig): @@ -111,6 +112,30 @@ def test_sunburst_treemap_colorscales(): assert list(fig.layout[colorway]) == color_seq +def test_sunburst_treemap_with_path(): + vendors = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'] + sectors = ['Tech', 'Tech', 'Finance', 'Finance', 'Tech', 'Tech', 'Finance', 'Finance'] + regions = ['North', 'North', 'North', 'North', 'South', 'South', 'South', 'South'] + values = [1, 3, 2, 4, 2, 2, 1, 4] + df = pd.DataFrame(dict(vendors=vendors, sectors=sectors, regions=regions, values=values)) + # No values + fig = px.sunburst(df, path=['vendors', 'sectors', 'regions']) + assert fig.data[0].branchvalues == 'total' + # Values passed + fig = px.sunburst(df, path=['vendors', 'sectors', 'regions'], values='values') + assert fig.data[0].branchvalues == 'total' + assert fig.data[0].values[-1] == np.sum(values) + # Values passed + fig = px.sunburst(df, path=['vendors', 'sectors', 'regions'], + values='values') + assert fig.data[0].branchvalues == 'total' + assert fig.data[0].values[-1] == np.sum(values) + # Continuous colorscale + fig = px.sunburst(df, path=['vendors', 'sectors', 'regions'], + values='values', color='values') + assert 'coloraxis' in fig.data[0].marker + assert np.all(np.array(fig.data[0].marker.colors) == np.array(fig.data[0].values)) + def test_pie_funnelarea_colorscale(): labels = ["A", "B", "C", "D"] values = [3, 2, 1, 4] From edfcced8d6c004c26198fd3eb00df50f23852819 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 16 Dec 2019 15:36:18 -0500 Subject: [PATCH 04/33] black --- .../plotly/plotly/express/_chart_types.py | 14 ++--- .../python/plotly/plotly/express/_core.py | 54 ++++++++++--------- packages/python/plotly/plotly/express/_doc.py | 5 +- .../test_core/test_px/test_px_functions.py | 40 +++++++++----- 4 files changed, 62 insertions(+), 51 deletions(-) diff --git a/packages/python/plotly/plotly/express/_chart_types.py 
b/packages/python/plotly/plotly/express/_chart_types.py index 311e05d4c16..e47eb7682da 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -1281,11 +1281,11 @@ def sunburst( layout_patch = {} if path is not None and (ids is not None or parents is not None): raise ValueError( - "Either `path` should be provided, or `ids` and `parents`." - "These parameters are mutually exclusive and cannot be passed together." - ) + "Either `path` should be provided, or `ids` and `parents`." + "These parameters are mutually exclusive and cannot be passed together." + ) if path is not None and branchvalues is None: - branchvalues='total' + branchvalues = "total" return make_figure( args=locals(), constructor=go.Sunburst, @@ -1331,9 +1331,9 @@ def treemap( layout_patch = {} if path is not None and (ids is not None or parents is not None): raise ValueError( - "Either `path` should be provided, or `ids` and `parents`." - "These parameters are mutually exclusive and cannot be passed together." - ) + "Either `path` should be provided, or `ids` and `parents`." + "These parameters are mutually exclusive and cannot be passed together." 
+ ) return make_figure( args=locals(), diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index b7b71886c23..d09e8821dbe 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -891,8 +891,8 @@ def build_dataframe(args, attrables, array_attrables): df_output[df_input.columns] = df_input[df_input.columns] # This should be improved + tested - HACK - if 'path' in args and args['path'] is not None: - df_output[args['path']] = df_input[args['path']] + if "path" in args and args["path"] is not None: + df_output[args["path"]] = df_input[args["path"]] # Loop over possible arguments for field_name in attrables: # Massaging variables @@ -1014,11 +1014,11 @@ def process_dataframe_hierarchy(args): """ Build dataframe for sunburst or treemap when the path argument is provided. """ - df = args['data_frame'] - path = args['path'] + df = args["data_frame"] + path = args["path"] # Other columns (for color, hover_data, custom_data etc.) 
cols = list(set(df.columns).difference(path)) - df_all_trees = pd.DataFrame(columns=['labels', 'parent', 'id'] + cols) + df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols) # Set column type here (useful for continuous vs discrete colorscale) for col in cols: df_all_trees[col] = df_all_trees[col].astype(df[col].dtype) @@ -1026,48 +1026,50 @@ def process_dataframe_hierarchy(args): df_tree = pd.DataFrame(columns=df_all_trees.columns) dfg = df.groupby(path[i:]).sum(numerical_only=True) dfg = dfg.reset_index() - df_tree['labels'] = dfg[level].copy().astype(str) - df_tree['parent'] = '' - df_tree['id'] = dfg[level].copy().astype(str) + df_tree["labels"] = dfg[level].copy().astype(str) + df_tree["parent"] = "" + df_tree["id"] = dfg[level].copy().astype(str) if i < len(path) - 1: j = i + 1 while j < len(path): - df_tree['parent'] += dfg[path[j]].copy().astype(str) - df_tree['id'] += dfg[path[j]].copy().astype(str) + df_tree["parent"] += dfg[path[j]].copy().astype(str) + df_tree["id"] += dfg[path[j]].copy().astype(str) j += 1 else: - df_tree['parent'] = 'total' + df_tree["parent"] = "total" if i == 0 and cols: df_tree[cols] = dfg[cols] elif cols: for col in cols: - df_tree[col] = 'n/a' - if args['values']: - df_tree[args['values']] = dfg[args['values']] + df_tree[col] = "n/a" + if args["values"]: + df_tree[args["values"]] = dfg[args["values"]] df_all_trees = df_all_trees.append(df_tree, ignore_index=True) # Root node - total_dict = {'labels': 'total', 'id': 'total', 'parent': '', - } + total_dict = { + "labels": "total", + "id": "total", + "parent": "", + } for col in cols: - if not col == args['values']: - total_dict[col] = 'n/a' - if col == args['values']: + if not col == args["values"]: + total_dict[col] = "n/a" + if col == args["values"]: total_dict[col] = df[col].sum() total = pd.Series(total_dict) df_all_trees = df_all_trees.append(total, ignore_index=True) # Now modify arguments - args['data_frame'] = df_all_trees - args['path'] = None - 
args['ids'] = 'id' - args['names'] = 'labels' - args['parents'] = 'parent' + args["data_frame"] = df_all_trees + args["path"] = None + args["ids"] = "id" + args["names"] = "labels" + args["parents"] = "parent" return args - def infer_config(args, constructor, trace_patch): # Declare all supported attributes, across all plot types attrables = ( @@ -1087,7 +1089,7 @@ def infer_config(args, constructor, trace_patch): all_attrables += [group_attr] args = build_dataframe(args, all_attrables, array_attrables) - if constructor in [go.Treemap, go.Sunburst] and args['path'] is not None: + if constructor in [go.Treemap, go.Sunburst] and args["path"] is not None: args = process_dataframe_hierarchy(args) attrs = [k for k in attrables if k in args] diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index a5f0232bdad..0f8ee635e1d 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -82,10 +82,7 @@ colref_desc, "Values from this column or array_like are used to set ids of sectors", ], - path=[ - colref_type, - colref_desc - ], + path=[colref_type, colref_desc], lat=[ colref_type, colref_desc, diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 0a3efbb9972..5e64fdf952c 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -113,29 +113,41 @@ def test_sunburst_treemap_colorscales(): def test_sunburst_treemap_with_path(): - vendors = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'] - sectors = ['Tech', 'Tech', 'Finance', 'Finance', 'Tech', 'Tech', 'Finance', 'Finance'] - regions = ['North', 'North', 'North', 'North', 'South', 'South', 'South', 'South'] + vendors = ["A", "B", "C", "D", "E", "F", "G", "H"] + sectors = [ + "Tech", + "Tech", + 
"Finance", + "Finance", + "Tech", + "Tech", + "Finance", + "Finance", + ] + regions = ["North", "North", "North", "North", "South", "South", "South", "South"] values = [1, 3, 2, 4, 2, 2, 1, 4] - df = pd.DataFrame(dict(vendors=vendors, sectors=sectors, regions=regions, values=values)) + df = pd.DataFrame( + dict(vendors=vendors, sectors=sectors, regions=regions, values=values) + ) # No values - fig = px.sunburst(df, path=['vendors', 'sectors', 'regions']) - assert fig.data[0].branchvalues == 'total' + fig = px.sunburst(df, path=["vendors", "sectors", "regions"]) + assert fig.data[0].branchvalues == "total" # Values passed - fig = px.sunburst(df, path=['vendors', 'sectors', 'regions'], values='values') - assert fig.data[0].branchvalues == 'total' + fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values") + assert fig.data[0].branchvalues == "total" assert fig.data[0].values[-1] == np.sum(values) # Values passed - fig = px.sunburst(df, path=['vendors', 'sectors', 'regions'], - values='values') - assert fig.data[0].branchvalues == 'total' + fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values") + assert fig.data[0].branchvalues == "total" assert fig.data[0].values[-1] == np.sum(values) # Continuous colorscale - fig = px.sunburst(df, path=['vendors', 'sectors', 'regions'], - values='values', color='values') - assert 'coloraxis' in fig.data[0].marker + fig = px.sunburst( + df, path=["vendors", "sectors", "regions"], values="values", color="values" + ) + assert "coloraxis" in fig.data[0].marker assert np.all(np.array(fig.data[0].marker.colors) == np.array(fig.data[0].values)) + def test_pie_funnelarea_colorscale(): labels = ["A", "B", "C", "D"] values = [3, 2, 1, 4] From 1f3b8daa0f58d7b1fb7197b38c08fd465cc27863 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 18 Dec 2019 13:56:07 -0500 Subject: [PATCH 05/33] added test with missing values --- .../test_core/test_px/test_px_functions.py | 26 +++++++++++++++++++ 1 
file changed, 26 insertions(+) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 5e64fdf952c..8a58e2591a9 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -146,6 +146,32 @@ def test_sunburst_treemap_with_path(): ) assert "coloraxis" in fig.data[0].marker assert np.all(np.array(fig.data[0].marker.colors) == np.array(fig.data[0].values)) + # Values columns passed as object dtype + df['values'] = df['values'].astype(object) + fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values") + + +def test_sunburst_treemap_with_path_non_rectangular(): + vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None] + sectors = [ + "Tech", + "Tech", + "Finance", + "Finance", + None, + "Tech", + "Tech", + "Finance", + "Finance", + "Finance", + ] + regions = ["North", "North", "North", "North", "North", "South", "South", "South", "South", "South"] + values = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1] + df = pd.DataFrame( + dict(vendors=vendors, sectors=sectors, regions=regions, values=values) + ) + fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values") + assert fig.data[0].values[-1] == np.sum(values) def test_pie_funnelarea_colorscale(): From 8cb9d999a942cfc3ef8ab83986ef63294f40bd79 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 18 Dec 2019 14:46:16 -0500 Subject: [PATCH 06/33] examples for sunburst tutorial --- doc/python/sunburst-charts.md | 32 +++++++++++++++++++ .../python/plotly/plotly/express/_core.py | 3 +- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/doc/python/sunburst-charts.md b/doc/python/sunburst-charts.md index a80d530fc08..81b57086aca 100644 --- a/doc/python/sunburst-charts.md +++ b/doc/python/sunburst-charts.md @@ -62,6 +62,38 @@ fig =px.sunburst( fig.show() 
``` +### Sunburst of a rectangular DataFrame with plotly.express + +Hierarchical data are often stored as a rectangular dataframe, with different columns corresponding to different levels of the hierarchy. `px.sunburst` can take a `path` parameter corresponding to a list of columns. Note that `ids` and `parents` should not be provided if `path` is given. + +```python +import plotly.express as px +df = px.data.tips() +fig = px.sunburst(df, path=['sex', 'time', 'day'], values='total_bill') +fig.show() +``` + +### Rectangular data with missing values + +If the dataset is not fully rectangular, missing values should be supplied as `None`. + +```python +import plotly.express as px +import pandas as pd +vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None] +sectors = ["Tech", "Tech", "Finance", "Finance", None, + "Tech", "Tech", "Finance", "Finance", "Finance"] +regions = ["North", "North", "North", "North", "North", + "South", "South", "South", "South", "South"] +sales = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1] +df = pd.DataFrame( + dict(vendors=vendors, sectors=sectors, regions=regions, sales=sales) +) +print(df) +fig = px.sunburst(df, path=['vendors', 'sectors', 'regions'], values='sales') +fig.show() +``` + ### Basic Sunburst Plot with go.Sunburst If Plotly Express does not provide a good starting point, it is also possible to use the more generic `go.Sunburst` function from `plotly.graph_objects`. 
diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d09e8821dbe..0879f44d724 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1044,7 +1044,8 @@ def process_dataframe_hierarchy(args): for col in cols: df_tree[col] = "n/a" if args["values"]: - df_tree[args["values"]] = dfg[args["values"]] + # EPS hack, to be removed + df_tree[args["values"]] = dfg[args["values"]] - 1.e-10 df_all_trees = df_all_trees.append(df_tree, ignore_index=True) # Root node From cd500a57abcebc14c985f67ebeefe7bcd2e77cf5 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 18 Dec 2019 15:10:44 -0500 Subject: [PATCH 07/33] added type check and corresponding test --- .../python/plotly/plotly/express/_core.py | 10 ++++++++- .../test_core/test_px/test_px_functions.py | 22 +++++++++++++++---- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 0879f44d724..21a54068471 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1016,6 +1016,14 @@ def process_dataframe_hierarchy(args): """ df = args["data_frame"] path = args["path"] + if args["values"]: + try: + df["values"] = pd.to_numeric(df["values"]) + except ValueError: + raise ValueError( + "Column `%s` of `df` could not be converted to a numerical data type." + % args["values"] + ) # Other columns (for color, hover_data, custom_data etc.) 
cols = list(set(df.columns).difference(path)) df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols) @@ -1045,7 +1053,7 @@ def process_dataframe_hierarchy(args): df_tree[col] = "n/a" if args["values"]: # EPS hack, to be removed - df_tree[args["values"]] = dfg[args["values"]] - 1.e-10 + df_tree[args["values"]] = dfg[args["values"]] - 1.0e-10 df_all_trees = df_all_trees.append(df_tree, ignore_index=True) # Root node diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 8a58e2591a9..749da4e572c 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -3,6 +3,7 @@ from numpy.testing import assert_array_equal import numpy as np import pandas as pd +import pytest def _compare_figures(go_trace, px_fig): @@ -146,9 +147,11 @@ def test_sunburst_treemap_with_path(): ) assert "coloraxis" in fig.data[0].marker assert np.all(np.array(fig.data[0].marker.colors) == np.array(fig.data[0].values)) - # Values columns passed as object dtype - df['values'] = df['values'].astype(object) - fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values") + # Error when values cannot be converted to numerical data type + df["values"] = ["1 000", "3 000", "2", "4", "2", "2", "1 000", "4 000"] + msg = "Column `values` of `df` could not be converted to a numerical data type." 
+ with pytest.raises(ValueError, match=msg): + fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values") def test_sunburst_treemap_with_path_non_rectangular(): @@ -165,7 +168,18 @@ def test_sunburst_treemap_with_path_non_rectangular(): "Finance", "Finance", ] - regions = ["North", "North", "North", "North", "North", "South", "South", "South", "South", "South"] + regions = [ + "North", + "North", + "North", + "North", + "North", + "South", + "South", + "South", + "South", + "South", + ] values = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1] df = pd.DataFrame( dict(vendors=vendors, sectors=sectors, regions=regions, values=values) From c2332200cab2735ee437e867e7273f2efdfa4388 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 18 Dec 2019 15:25:12 -0500 Subject: [PATCH 08/33] corrected bug --- packages/python/plotly/plotly/express/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 21a54068471..892933fa28c 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1018,7 +1018,7 @@ def process_dataframe_hierarchy(args): path = args["path"] if args["values"]: try: - df["values"] = pd.to_numeric(df["values"]) + df[args["values"]] = pd.to_numeric(df[args["values"]]) except ValueError: raise ValueError( "Column `%s` of `df` could not be converted to a numerical data type." 
From edefabf2f1ce1b2dd0e1f45213624aa3fbe0fdc2 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 18 Dec 2019 15:30:20 -0500 Subject: [PATCH 09/33] treemap branchvalues --- packages/python/plotly/plotly/express/_chart_types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index e47eb7682da..be982e404db 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -1334,7 +1334,8 @@ def treemap( "Either `path` should be provided, or `ids` and `parents`." "These parameters are mutually exclusive and cannot be passed together." ) - + if path is not None and branchvalues is None: + branchvalues = "total" return make_figure( args=locals(), constructor=go.Treemap, From 2952fe608973e8db30858be9089384c2c4b3171d Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 17 Jan 2020 15:52:25 -0500 Subject: [PATCH 10/33] path is now from root to leaves --- doc/python/sunburst-charts.md | 4 ++-- packages/python/plotly/plotly/express/_core.py | 2 +- .../tests/test_core/test_px/test_px_functions.py | 16 ++++++++-------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/doc/python/sunburst-charts.md b/doc/python/sunburst-charts.md index d4fab3adbaa..02e64cb557e 100644 --- a/doc/python/sunburst-charts.md +++ b/doc/python/sunburst-charts.md @@ -69,7 +69,7 @@ Hierarchical data are often stored as a rectangular dataframe, with different co ```python import plotly.express as px df = px.data.tips() -fig = px.sunburst(df, path=['sex', 'time', 'day'], values='total_bill') +fig = px.sunburst(df, path=['day', 'time', 'sex'], values='total_bill') fig.show() ``` @@ -90,7 +90,7 @@ df = pd.DataFrame( dict(vendors=vendors, sectors=sectors, regions=regions, sales=sales) ) print(df) -fig = px.sunburst(df, path=['vendors', 'sectors', 'regions'], values='sales') +fig = 
px.sunburst(df, path=['regions', 'sectors', 'vendors'], values='sales') fig.show() ``` diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index f214ffe9c3c..43e26af4415 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1015,7 +1015,7 @@ def process_dataframe_hierarchy(args): Build dataframe for sunburst or treemap when the path argument is provided. """ df = args["data_frame"] - path = args["path"] + path = args["path"][::-1] if args["values"]: try: df[args["values"]] = pd.to_numeric(df[args["values"]]) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 749da4e572c..ee4f41d6a53 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -130,28 +130,27 @@ def test_sunburst_treemap_with_path(): df = pd.DataFrame( dict(vendors=vendors, sectors=sectors, regions=regions, values=values) ) + path = ["regions", "sectors", "vendors"] # No values - fig = px.sunburst(df, path=["vendors", "sectors", "regions"]) + fig = px.sunburst(df, path=path) assert fig.data[0].branchvalues == "total" # Values passed - fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values") + fig = px.sunburst(df, path=path, values="values") assert fig.data[0].branchvalues == "total" assert fig.data[0].values[-1] == np.sum(values) # Values passed - fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values") + fig = px.sunburst(df, path=path, values="values") assert fig.data[0].branchvalues == "total" assert fig.data[0].values[-1] == np.sum(values) # Continuous colorscale - fig = px.sunburst( - df, path=["vendors", "sectors", "regions"], values="values", color="values" - ) + fig = px.sunburst(df, path=path, 
values="values", color="values") assert "coloraxis" in fig.data[0].marker assert np.all(np.array(fig.data[0].marker.colors) == np.array(fig.data[0].values)) # Error when values cannot be converted to numerical data type df["values"] = ["1 000", "3 000", "2", "4", "2", "2", "1 000", "4 000"] msg = "Column `values` of `df` could not be converted to a numerical data type." with pytest.raises(ValueError, match=msg): - fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values") + fig = px.sunburst(df, path=path, values="values") def test_sunburst_treemap_with_path_non_rectangular(): @@ -184,7 +183,8 @@ def test_sunburst_treemap_with_path_non_rectangular(): df = pd.DataFrame( dict(vendors=vendors, sectors=sectors, regions=regions, values=values) ) - fig = px.sunburst(df, path=["vendors", "sectors", "regions"], values="values") + path = ["regions", "sectors", "vendors"] + fig = px.sunburst(df, path=path, values="values") assert fig.data[0].values[-1] == np.sum(values) From c6b7243d2024325cdecd81e4711c7503b58afdc3 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Sat, 18 Jan 2020 18:57:56 -0500 Subject: [PATCH 11/33] removed EPS hack --- packages/python/plotly/plotly/express/_core.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 43e26af4415..a9f9e97a76d 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1051,9 +1051,6 @@ def process_dataframe_hierarchy(args): elif cols: for col in cols: df_tree[col] = "n/a" - if args["values"]: - # EPS hack, to be removed - df_tree[args["values"]] = dfg[args["values"]] - 1.0e-10 df_all_trees = df_all_trees.append(df_tree, ignore_index=True) # Root node From be3b622ecaaf5e8828f2cc6a77c66f3dbe2e9988 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 20 Jan 2020 14:09:11 -0500 Subject: [PATCH 12/33] working version for continuous 
color --- .../python/plotly/plotly/express/_core.py | 52 ++++++++++++------- .../test_core/test_px/test_px_functions.py | 22 ++++++-- 2 files changed, 50 insertions(+), 24 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index a9f9e97a76d..829d0c20b21 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1017,6 +1017,17 @@ def process_dataframe_hierarchy(args): df = args["data_frame"] path = args["path"][::-1] if args["values"]: + # Define weighted mean lambda, using value column + lambda_wm = lambda x: np.average(x, weights=df.loc[x.index, args["values"]]) + # Define aggregation function to be used on groupby objects + if args["color"]: + aggfunc_color = "sum" if args["color"] == args["values"] else lambda_wm + agg_f = { + args["values"]: pd.NamedAgg(column=args["values"], aggfunc="sum"), + args["color"]: pd.NamedAgg(column=args["color"], aggfunc=aggfunc_color), + } + else: + agg_f = {args["values"]: pd.NamedAgg(column=args["values"], aggfunc="sum")} try: df[args["values"]] = pd.to_numeric(df[args["values"]]) except ValueError: @@ -1024,6 +1035,20 @@ def process_dataframe_hierarchy(args): "Column `%s` of `df` could not be converted to a numerical data type." % args["values"] ) + else: + if args["color"]: # color passed but not value + # we need a count column for the weighted mean of color + # trick to be sure the col name is unused: take the sum of existing names + count_colname = "".join([str(el) for el in list(df.columns)]) + # we can modify df because it's a copy of the px argument + df[count_colname] = 1 + lambda_wm = lambda x: np.average(x, weights=df.loc[x.index, count_colname]) + agg_f = { + args["color"]: pd.NamedAgg(column=args["color"], aggfunc=lambda_wm), + count_colname: pd.NamedAgg(column=count_colname, aggfunc="sum"), + } + else: + agg_f = {} # Other columns (for color, hover_data, custom_data etc.) 
cols = list(set(df.columns).difference(path)) df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols) @@ -1032,7 +1057,10 @@ def process_dataframe_hierarchy(args): df_all_trees[col] = df_all_trees[col].astype(df[col].dtype) for i, level in enumerate(path): df_tree = pd.DataFrame(columns=df_all_trees.columns) - dfg = df.groupby(path[i:]).sum(numerical_only=True) + if not agg_f: + dfg = df.groupby(path[i:]).sum(numerical_only=True) + else: + dfg = df.groupby(path[i:]).agg(**agg_f) dfg = dfg.reset_index() df_tree["labels"] = dfg[level].copy().astype(str) df_tree["parent"] = "" @@ -1044,29 +1072,13 @@ def process_dataframe_hierarchy(args): df_tree["id"] += dfg[path[j]].copy().astype(str) j += 1 else: - df_tree["parent"] = "total" + df_tree["parent"] = "" - if i == 0 and cols: + if cols: df_tree[cols] = dfg[cols] - elif cols: - for col in cols: - df_tree[col] = "n/a" df_all_trees = df_all_trees.append(df_tree, ignore_index=True) - # Root node - total_dict = { - "labels": "total", - "id": "total", - "parent": "", - } - for col in cols: - if not col == args["values"]: - total_dict[col] = "n/a" - if col == args["values"]: - total_dict[col] = df[col].sum() - total = pd.Series(total_dict) - - df_all_trees = df_all_trees.append(total, ignore_index=True) + print(df_all_trees) # Now modify arguments args["data_frame"] = df_all_trees args["path"] = None diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index ee4f41d6a53..c83216b2ded 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -127,10 +127,17 @@ def test_sunburst_treemap_with_path(): ] regions = ["North", "North", "North", "North", "South", "South", "South", "South"] values = [1, 3, 2, 4, 2, 2, 1, 4] + total = ["total",] * 8 df = pd.DataFrame( - dict(vendors=vendors, 
sectors=sectors, regions=regions, values=values) + dict( + vendors=vendors, + sectors=sectors, + regions=regions, + values=values, + total=total, + ) ) - path = ["regions", "sectors", "vendors"] + path = ["total", "regions", "sectors", "vendors"] # No values fig = px.sunburst(df, path=path) assert fig.data[0].branchvalues == "total" @@ -180,10 +187,17 @@ def test_sunburst_treemap_with_path_non_rectangular(): "South", ] values = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1] + total = ["total",] * 10 df = pd.DataFrame( - dict(vendors=vendors, sectors=sectors, regions=regions, values=values) + dict( + vendors=vendors, + sectors=sectors, + regions=regions, + values=values, + total=total, + ) ) - path = ["regions", "sectors", "vendors"] + path = ["total", "regions", "sectors", "vendors"] fig = px.sunburst(df, path=path, values="values") assert fig.data[0].values[-1] == np.sum(values) From 7f2920ba1db5a8e989be9753419e82ab44f355dc Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 20 Jan 2020 16:50:53 -0500 Subject: [PATCH 13/33] new tests and more readable code, also added hover support --- .../python/plotly/plotly/express/_core.py | 55 ++++++++++--------- packages/python/plotly/plotly/express/_doc.py | 7 ++- .../test_core/test_px/test_px_functions.py | 46 ++++++++++++++++ 3 files changed, 82 insertions(+), 26 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 829d0c20b21..d4680fb6df9 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -890,9 +890,6 @@ def build_dataframe(args, attrables, array_attrables): else: df_output[df_input.columns] = df_input[df_input.columns] - # This should be improved + tested - HACK - if "path" in args and args["path"] is not None: - df_output[args["path"]] = df_input[args["path"]] # Loop over possible arguments for field_name in attrables: # Massaging variables @@ -1016,18 +1013,12 @@ def 
process_dataframe_hierarchy(args): """ df = args["data_frame"] path = args["path"][::-1] + + # ------------ Define aggregation functions -------------------------------- + lambda_discrete = lambda x: x[0] if len(x) == 1 else "" + agg_f = {} + aggfunc_color = None if args["values"]: - # Define weighted mean lambda, using value column - lambda_wm = lambda x: np.average(x, weights=df.loc[x.index, args["values"]]) - # Define aggregation function to be used on groupby objects - if args["color"]: - aggfunc_color = "sum" if args["color"] == args["values"] else lambda_wm - agg_f = { - args["values"]: pd.NamedAgg(column=args["values"], aggfunc="sum"), - args["color"]: pd.NamedAgg(column=args["color"], aggfunc=aggfunc_color), - } - else: - agg_f = {args["values"]: pd.NamedAgg(column=args["values"], aggfunc="sum")} try: df[args["values"]] = pd.to_numeric(df[args["values"]]) except ValueError: @@ -1035,6 +1026,11 @@ def process_dataframe_hierarchy(args): "Column `%s` of `df` could not be converted to a numerical data type." % args["values"] ) + + if args["color"]: + if args["color"] == args["values"]: + aggfunc_color = "sum" + count_colname = args["values"] else: if args["color"]: # color passed but not value # we need a count column for the weighted mean of color @@ -1042,17 +1038,27 @@ def process_dataframe_hierarchy(args): count_colname = "".join([str(el) for el in list(df.columns)]) # we can modify df because it's a copy of the px argument df[count_colname] = 1 - lambda_wm = lambda x: np.average(x, weights=df.loc[x.index, count_colname]) - agg_f = { - args["color"]: pd.NamedAgg(column=args["color"], aggfunc=lambda_wm), - count_colname: pd.NamedAgg(column=count_colname, aggfunc="sum"), - } - else: - agg_f = {} - # Other columns (for color, hover_data, custom_data etc.) 
+ + if args["color"]: + if df[args["color"]].dtype.kind not in "bifc": + aggfunc_color = lambda_discrete + elif not aggfunc_color: + aggfunc_color = lambda x: np.average( + x, weights=df.loc[x.index, count_colname] + ) + agg_f[args["color"]] = pd.NamedAgg(column=args["color"], aggfunc=aggfunc_color) + if args["color"] or args["values"]: + agg_f[count_colname] = pd.NamedAgg(column=count_colname, aggfunc="sum") + + # Other columns (for color, hover_data, custom_data etc.) cols = list(set(df.columns).difference(path)) + for col in cols: # for hover_data, custom_data etc. + if col not in agg_f: + agg_f[col] = pd.NamedAgg(column=col, aggfunc=lambda_discrete) + # ---------------------------------------------------------------------------- + df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols) - # Set column type here (useful for continuous vs discrete colorscale) + # Set column type here (useful for continuous vs discrete colorscale) for col in cols: df_all_trees[col] = df_all_trees[col].astype(df[col].dtype) for i, level in enumerate(path): @@ -1078,7 +1084,6 @@ def process_dataframe_hierarchy(args): df_tree[cols] = dfg[cols] df_all_trees = df_all_trees.append(df_tree, ignore_index=True) - print(df_all_trees) # Now modify arguments args["data_frame"] = df_all_trees args["path"] = None @@ -1096,7 +1101,7 @@ def infer_config(args, constructor, trace_patch): + ["names", "values", "parents", "ids"] + ["error_x", "error_x_minus"] + ["error_y", "error_y_minus", "error_z", "error_z_minus"] - + ["lat", "lon", "locations", "animation_group"] + + ["lat", "lon", "locations", "animation_group", "path"] ) array_attrables = ["dimensions", "custom_data", "hover_data", "path"] group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 2a93b3c61d9..b3c6d39dc7a 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ 
b/packages/python/plotly/plotly/express/_doc.py @@ -86,7 +86,12 @@ colref_desc, "Values from this column or array_like are used to set ids of sectors", ], - path=[colref_type, colref_desc], + path=[ + colref_list_type, + colref_list_desc, + "List of columns names or columns of a rectangular dataframe defining the hierarchy of sectors, from root to leaves.", + "An error is raised if path AND ids or parents is passed", + ], lat=[ colref_type, colref_desc, diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index c83216b2ded..24000ac3e3f 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -158,6 +158,52 @@ def test_sunburst_treemap_with_path(): msg = "Column `values` of `df` could not be converted to a numerical data type." with pytest.raises(ValueError, match=msg): fig = px.sunburst(df, path=path, values="values") + # path is a mixture of column names and array-like + path = [df.total, "regions", df.sectors, "vendors"] + fig = px.sunburst(df, path=path) + assert fig.data[0].branchvalues == "total" + + +def test_sunburst_treemap_with_path_color(): + vendors = ["A", "B", "C", "D", "E", "F", "G", "H"] + sectors = [ + "Tech", + "Tech", + "Finance", + "Finance", + "Tech", + "Tech", + "Finance", + "Finance", + ] + regions = ["North", "North", "North", "North", "South", "South", "South", "South"] + values = [1, 3, 2, 4, 2, 2, 1, 4] + calls = [8, 2, 1, 3, 2, 2, 4, 1] + total = ["total",] * 8 + df = pd.DataFrame( + dict( + vendors=vendors, + sectors=sectors, + regions=regions, + values=values, + total=total, + calls=calls, + ) + ) + path = ["total", "regions", "sectors", "vendors"] + fig = px.sunburst(df, path=path, values="values", color="calls") + colors = fig.data[0].marker.colors + assert np.all(np.array(colors[:8]) == np.array(calls)) + fig = 
px.sunburst(df, path=path, color="calls") + colors = fig.data[0].marker.colors + assert np.all(np.array(colors[:8]) == np.array(calls)) + + # Hover info + df["hover"] = [el.lower() for el in vendors] + fig = px.sunburst(df, path=path, color="calls", hover_data=["hover"]) + custom = fig.data[0].customdata.ravel() + assert np.all(custom[:8] == df["hover"]) + assert np.all(custom[8:] == "") def test_sunburst_treemap_with_path_non_rectangular(): From 85193026b84bc3941c97cce42f5fd75d3300cb63 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 20 Jan 2020 17:13:16 -0500 Subject: [PATCH 14/33] updated docs --- doc/python/sunburst-charts.md | 15 +++++++++++ doc/python/treemaps.md | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/doc/python/sunburst-charts.md b/doc/python/sunburst-charts.md index 02e64cb557e..522d82ac9b9 100644 --- a/doc/python/sunburst-charts.md +++ b/doc/python/sunburst-charts.md @@ -73,6 +73,21 @@ fig = px.sunburst(df, path=['day', 'time', 'sex'], values='total_bill') fig.show() ``` +### Sunburst of a rectangular DataFrame with continuous color argument in px.sunburst + +If a `color` argument is passed, the color of a node is computed as the average of the color values of its children, weighted by their values. + +```python +import plotly.express as px +import numpy as np +df = px.data.gapminder().query("year == 2007") +fig = px.sunburst(df, path=['continent', 'country'], values='pop', + color='lifeExp', hover_data=['iso_alpha'], + color_continuous_scale='RdBu', + color_continuous_midpoint=np.average(df['lifeExp'], weights=df['pop'])) +fig.show() +``` + ### Rectangular data with missing values If the dataset is not fully rectangular, missing values should be supplied as `None`. 
diff --git a/doc/python/treemaps.md b/doc/python/treemaps.md index 02167e3cdb8..5f6590e5981 100644 --- a/doc/python/treemaps.md +++ b/doc/python/treemaps.md @@ -51,6 +51,55 @@ fig = px.treemap( fig.show() ``` +### Treemap of a rectangular DataFrame with plotly.express + +Hierarchical data are often stored as a rectangular dataframe, with different columns corresponding to different levels of the hierarchy. `px.treemap` can take a `path` parameter corresponding to a list of columns. Note that `ids` and `parents` should not be provided if `path` is given. + +```python +import plotly.express as px +df = px.data.tips() +fig = px.treemap(df, path=['day', 'time', 'sex'], values='total_bill') +fig.show() +``` + +### Treemap of a rectangular DataFrame with continuous color argument in px.treemap + +If a `color` argument is passed, the color of a node is computed as the average of the color values of its children, weighted by their values. + +```python +import plotly.express as px +import numpy as np +df = px.data.gapminder().query("year == 2007") +fig = px.treemap(df, path=['continent', 'country'], values='pop', + color='lifeExp', hover_data=['iso_alpha'], + color_continuous_scale='RdBu', + color_continuous_midpoint=np.average(df['lifeExp'], weights=df['pop'])) +fig.show() +``` + +### Rectangular data with missing values + +If the dataset is not fully rectangular, missing values should be supplied as `None`. 
+ +```python +import plotly.express as px +import pandas as pd +vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None] +sectors = ["Tech", "Tech", "Finance", "Finance", None, + "Tech", "Tech", "Finance", "Finance", "Finance"] +regions = ["North", "North", "North", "North", "North", + "South", "South", "South", "South", "South"] +sales = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1] +df = pd.DataFrame( + dict(vendors=vendors, sectors=sectors, regions=regions, sales=sales) +) +print(df) +fig = px.treemap(df, path=['regions', 'sectors', 'vendors'], values='sales') +fig.show() +``` + + + ### Basic Treemap with go.Treemap If Plotly Express does not provide a good starting point, it is also possible to use the more generic `go.Treemap` function from `plotly.graph_objects`. From 437bbd72c77dce81952eb8b24e485d0743923d2f Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 20 Jan 2020 17:38:07 -0500 Subject: [PATCH 15/33] removed named agg which is valid only starting from pandas 0.25 --- packages/python/plotly/plotly/express/_core.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d4680fb6df9..939b9bb45b9 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1007,6 +1007,14 @@ def build_dataframe(args, attrables, array_attrables): return args +def _named_agg(colname, aggfunc, mode="old_pandas"): + if mode == "old_pandas": + return (colname, aggfunc) + else: + # switch to this mode when tuples become deprecated + return pd.NamedAgg(colname, aggfunc) + + def process_dataframe_hierarchy(args): """ Build dataframe for sunburst or treemap when the path argument is provided. 
@@ -1046,15 +1054,15 @@ def process_dataframe_hierarchy(args): aggfunc_color = lambda x: np.average( x, weights=df.loc[x.index, count_colname] ) - agg_f[args["color"]] = pd.NamedAgg(column=args["color"], aggfunc=aggfunc_color) + agg_f[args["color"]] = _named_agg(colname=args["color"], aggfunc=aggfunc_color) if args["color"] or args["values"]: - agg_f[count_colname] = pd.NamedAgg(column=count_colname, aggfunc="sum") + agg_f[count_colname] = _named_agg(colname=count_colname, aggfunc="sum") # Other columns (for color, hover_data, custom_data etc.) cols = list(set(df.columns).difference(path)) for col in cols: # for hover_data, custom_data etc. if col not in agg_f: - agg_f[col] = pd.NamedAgg(column=col, aggfunc=lambda_discrete) + agg_f[col] = _named_agg(colname=col, aggfunc=lambda_discrete) # ---------------------------------------------------------------------------- df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols) From fb9d9922915ed6257681addb2da377db61fab8ae Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 20 Jan 2020 18:39:39 -0500 Subject: [PATCH 16/33] version hopefully compatible with older pandas --- packages/python/plotly/plotly/express/_core.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 939b9bb45b9..b9aa7c88a45 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1007,14 +1007,6 @@ def build_dataframe(args, attrables, array_attrables): return args -def _named_agg(colname, aggfunc, mode="old_pandas"): - if mode == "old_pandas": - return (colname, aggfunc) - else: - # switch to this mode when tuples become deprecated - return pd.NamedAgg(colname, aggfunc) - - def process_dataframe_hierarchy(args): """ Build dataframe for sunburst or treemap when the path argument is provided. 
@@ -1054,15 +1046,15 @@ def process_dataframe_hierarchy(args): aggfunc_color = lambda x: np.average( x, weights=df.loc[x.index, count_colname] ) - agg_f[args["color"]] = _named_agg(colname=args["color"], aggfunc=aggfunc_color) + agg_f[args["color"]] = aggfunc_color if args["color"] or args["values"]: - agg_f[count_colname] = _named_agg(colname=count_colname, aggfunc="sum") + agg_f[count_colname] = "sum" # Other columns (for color, hover_data, custom_data etc.) cols = list(set(df.columns).difference(path)) for col in cols: # for hover_data, custom_data etc. if col not in agg_f: - agg_f[col] = _named_agg(colname=col, aggfunc=lambda_discrete) + agg_f[col] = lambda_discrete # ---------------------------------------------------------------------------- df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols) @@ -1074,7 +1066,7 @@ def process_dataframe_hierarchy(args): if not agg_f: dfg = df.groupby(path[i:]).sum(numerical_only=True) else: - dfg = df.groupby(path[i:]).agg(**agg_f) + dfg = df.groupby(path[i:]).agg(agg_f) dfg = dfg.reset_index() df_tree["labels"] = dfg[level].copy().astype(str) df_tree["parent"] = "" From a57b0279f3136d147bb5548b0d26cc02fd4241b8 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 20 Jan 2020 21:40:07 -0500 Subject: [PATCH 17/33] still debugging --- packages/python/plotly/plotly/express/_core.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index b9aa7c88a45..f50356ab4c4 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1007,6 +1007,13 @@ def build_dataframe(args, attrables, array_attrables): return args +def _discrete_agg(x): + if len(x) == 1: + return x.iloc[0] + else: + return "" + + def process_dataframe_hierarchy(args): """ Build dataframe for sunburst or treemap when the path argument is provided. 
@@ -1015,7 +1022,7 @@ def process_dataframe_hierarchy(args): path = args["path"][::-1] # ------------ Define aggregation functions -------------------------------- - lambda_discrete = lambda x: x[0] if len(x) == 1 else "" + lambda_discrete = _discrete_agg agg_f = {} aggfunc_color = None if args["values"]: From bf8da4b576f8a87374e23bb643eb0f1e16a18fa9 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 21 Jan 2020 16:32:48 -0500 Subject: [PATCH 18/33] do not use lambdas --- .../python/plotly/plotly/express/_core.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index f50356ab4c4..307e7a6cb4d 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1007,13 +1007,6 @@ def build_dataframe(args, attrables, array_attrables): return args -def _discrete_agg(x): - if len(x) == 1: - return x.iloc[0] - else: - return "" - - def process_dataframe_hierarchy(args): """ Build dataframe for sunburst or treemap when the path argument is provided. 
@@ -1022,7 +1015,12 @@ def process_dataframe_hierarchy(args): path = args["path"][::-1] # ------------ Define aggregation functions -------------------------------- - lambda_discrete = _discrete_agg + def aggfunc_discrete(x): + if len(x) == 1: + return x.iloc[0] + else: + return "" + agg_f = {} aggfunc_color = None if args["values"]: @@ -1048,11 +1046,13 @@ def process_dataframe_hierarchy(args): if args["color"]: if df[args["color"]].dtype.kind not in "bifc": - aggfunc_color = lambda_discrete + aggfunc_color = aggfunc_discrete elif not aggfunc_color: - aggfunc_color = lambda x: np.average( - x, weights=df.loc[x.index, count_colname] - ) + + def aggfunc_continuous(x): + return np.average(x, weights=df.loc[x.index, count_colname]) + + aggfunc_color = aggfunc_continuous agg_f[args["color"]] = aggfunc_color if args["color"] or args["values"]: agg_f[count_colname] = "sum" @@ -1061,7 +1061,7 @@ def process_dataframe_hierarchy(args): cols = list(set(df.columns).difference(path)) for col in cols: # for hover_data, custom_data etc. 
if col not in agg_f: - agg_f[col] = lambda_discrete + agg_f[col] = aggfunc_discrete # ---------------------------------------------------------------------------- df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols) From 9e238908980fb62062f5ad343ebfb5b9646ba811 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 21 Jan 2020 16:35:16 -0500 Subject: [PATCH 19/33] removed redundant else --- packages/python/plotly/plotly/express/_core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 307e7a6cb4d..c8294cd58e0 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1084,8 +1084,6 @@ def aggfunc_continuous(x): df_tree["parent"] += dfg[path[j]].copy().astype(str) df_tree["id"] += dfg[path[j]].copy().astype(str) j += 1 - else: - df_tree["parent"] = "" if cols: df_tree[cols] = dfg[cols] From f67602f75c81f7ee67a9e7b26018e013ea854aae Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 21 Jan 2020 21:34:00 -0500 Subject: [PATCH 20/33] discrete color --- packages/python/plotly/plotly/express/_core.py | 4 ++++ .../plotly/tests/test_core/test_px/test_px_functions.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index c8294cd58e0..6571cec4c85 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1014,6 +1014,10 @@ def process_dataframe_hierarchy(args): df = args["data_frame"] path = args["path"][::-1] + if args["color"] and args["color"] in path: + series_to_copy = df[args["color"]] + args["color"] = str(args["color"]) + "additional_col_for_px" + df[args["color"]] = series_to_copy # ------------ Define aggregation functions -------------------------------- def aggfunc_discrete(x): if len(x) == 1: diff --git 
a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 24000ac3e3f..78faf813024 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -205,6 +205,11 @@ def test_sunburst_treemap_with_path_color(): assert np.all(custom[:8] == df["hover"]) assert np.all(custom[8:] == "") + # Discrete color + fig = px.sunburst(df, path=path, color="vendors") + assert len(np.unique(fig.data[0].marker.colors[:8])) == 8 + assert len(np.unique(fig.data[0].marker.colors[8:])) == 1 + def test_sunburst_treemap_with_path_non_rectangular(): vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None] From 6b6a1052eb91a5c8087cc367e8f92be0059ede71 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 21 Jan 2020 21:51:54 -0500 Subject: [PATCH 21/33] always add a count column when no values column is passed --- packages/python/plotly/plotly/express/_core.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 6571cec4c85..1e352afe52e 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1041,12 +1041,16 @@ def aggfunc_discrete(x): aggfunc_color = "sum" count_colname = args["values"] else: - if args["color"]: # color passed but not value - # we need a count column for the weighted mean of color - # trick to be sure the col name is unused: take the sum of existing names - count_colname = "".join([str(el) for el in list(df.columns)]) - # we can modify df because it's a copy of the px argument - df[count_colname] = 1 + # we need a count column for the first groupby and the weighted mean of color + # trick to be sure the col name is unused: take the sum of existing names + 
count_colname = ( + "count" + if "count" not in df.columns + else "".join([str(el) for el in list(df.columns)]) + ) + # we can modify df because it's a copy of the px argument + df[count_colname] = 1 + args["values"] = count_colname if args["color"]: if df[args["color"]].dtype.kind not in "bifc": From 99967319796854c7a8d4de5a034f42f5d7cad970 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 21 Jan 2020 22:00:41 -0500 Subject: [PATCH 22/33] removed if which is not required any more --- packages/python/plotly/plotly/express/_core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 1e352afe52e..f7ccdf2e366 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1051,6 +1051,7 @@ def aggfunc_discrete(x): # we can modify df because it's a copy of the px argument df[count_colname] = 1 args["values"] = count_colname + agg_f[count_colname] = "sum" if args["color"]: if df[args["color"]].dtype.kind not in "bifc": @@ -1062,8 +1063,6 @@ def aggfunc_continuous(x): aggfunc_color = aggfunc_continuous agg_f[args["color"]] = aggfunc_color - if args["color"] or args["values"]: - agg_f[count_colname] = "sum" # Other columns (for color, hover_data, custom_data etc.) 
cols = list(set(df.columns).difference(path)) From f3e7e279c00667095f9df1c36ac838f755224e51 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 21 Jan 2020 22:16:06 -0500 Subject: [PATCH 23/33] nicer labels with / --- packages/python/plotly/plotly/express/_core.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index f7ccdf2e366..7d06464a305 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1082,6 +1082,7 @@ def aggfunc_continuous(x): else: dfg = df.groupby(path[i:]).agg(agg_f) dfg = dfg.reset_index() + # Path label massaging df_tree["labels"] = dfg[level].copy().astype(str) df_tree["parent"] = "" df_tree["id"] = dfg[level].copy().astype(str) @@ -1089,7 +1090,9 @@ def aggfunc_continuous(x): j = i + 1 while j < len(path): df_tree["parent"] += dfg[path[j]].copy().astype(str) - df_tree["id"] += dfg[path[j]].copy().astype(str) + if j < len(path) - 1: + df_tree["parent"] += "/" + df_tree["id"] += "/" + dfg[path[j]].copy().astype(str) j += 1 if cols: From 8cd227a898dbfb48f86266b5d60516f57c714776 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 21 Jan 2020 22:21:04 -0500 Subject: [PATCH 24/33] simplified code --- packages/python/plotly/plotly/express/_core.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 7d06464a305..5536b37559f 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1077,10 +1077,7 @@ def aggfunc_continuous(x): df_all_trees[col] = df_all_trees[col].astype(df[col].dtype) for i, level in enumerate(path): df_tree = pd.DataFrame(columns=df_all_trees.columns) - if not agg_f: - dfg = df.groupby(path[i:]).sum(numerical_only=True) - else: - dfg = 
df.groupby(path[i:]).agg(agg_f) + dfg = df.groupby(path[i:]).agg(agg_f) dfg = dfg.reset_index() # Path label massaging df_tree["labels"] = dfg[level].copy().astype(str) From 8b66c908fab6b39cdba7a8f81cec6f924a4c3df2 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 22 Jan 2020 09:24:08 -0500 Subject: [PATCH 25/33] better id labels --- packages/python/plotly/plotly/express/_core.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 5536b37559f..dc465747aa7 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1086,12 +1086,11 @@ def aggfunc_continuous(x): if i < len(path) - 1: j = i + 1 while j < len(path): - df_tree["parent"] += dfg[path[j]].copy().astype(str) - if j < len(path) - 1: - df_tree["parent"] += "/" - df_tree["id"] += "/" + dfg[path[j]].copy().astype(str) + df_tree["parent"] = dfg[path[j]].copy().astype(str) + "/" + df_tree["parent"] + df_tree["id"] = dfg[path[j]].copy().astype(str) + "/" + df_tree["id"] j += 1 + df_tree["parent"] = df_tree["parent"].str.rstrip('/') if cols: df_tree[cols] = dfg[cols] df_all_trees = df_all_trees.append(df_tree, ignore_index=True) From 19b81ac12f4785d920d1b4ffa7f4f076e4b64f44 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 22 Jan 2020 09:33:09 -0500 Subject: [PATCH 26/33] discrete colors --- packages/python/plotly/plotly/express/_core.py | 13 ++++++++----- .../tests/test_core/test_px/test_px_functions.py | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index dc465747aa7..b09399734c4 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1020,10 +1020,11 @@ def process_dataframe_hierarchy(args): df[args["color"]] = series_to_copy # 
------------ Define aggregation functions -------------------------------- def aggfunc_discrete(x): - if len(x) == 1: - return x.iloc[0] + uniques = x.unique() + if len(uniques) == 1: + return uniques[0] else: - return "" + return "(?)" agg_f = {} aggfunc_color = None @@ -1086,11 +1087,13 @@ def aggfunc_continuous(x): if i < len(path) - 1: j = i + 1 while j < len(path): - df_tree["parent"] = dfg[path[j]].copy().astype(str) + "/" + df_tree["parent"] + df_tree["parent"] = ( + dfg[path[j]].copy().astype(str) + "/" + df_tree["parent"] + ) df_tree["id"] = dfg[path[j]].copy().astype(str) + "/" + df_tree["id"] j += 1 - df_tree["parent"] = df_tree["parent"].str.rstrip('/') + df_tree["parent"] = df_tree["parent"].str.rstrip("/") if cols: df_tree[cols] = dfg[cols] df_all_trees = df_all_trees.append(df_tree, ignore_index=True) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 78faf813024..d0c259aace8 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -203,7 +203,7 @@ def test_sunburst_treemap_with_path_color(): fig = px.sunburst(df, path=path, color="calls", hover_data=["hover"]) custom = fig.data[0].customdata.ravel() assert np.all(custom[:8] == df["hover"]) - assert np.all(custom[8:] == "") + assert np.all(custom[8:] == "(?)") # Discrete color fig = px.sunburst(df, path=path, color="vendors") From ba6ec191af7b0cfb0b1efac9b2e43e431cdc8fcd Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 22 Jan 2020 11:53:05 -0500 Subject: [PATCH 27/33] raise ValueError for non-leaves with None --- doc/python/sunburst-charts.md | 8 ++++---- doc/python/treemaps.md | 7 ++----- packages/python/plotly/plotly/express/_core.py | 16 ++++++++++++++++ .../tests/test_core/test_px/test_px_functions.py | 4 ++++ 4 files changed, 26 insertions(+), 9 
deletions(-) diff --git a/doc/python/sunburst-charts.md b/doc/python/sunburst-charts.md index 522d82ac9b9..60fcc5751b3 100644 --- a/doc/python/sunburst-charts.md +++ b/doc/python/sunburst-charts.md @@ -69,7 +69,7 @@ Hierarchical data are often stored as a rectangular dataframe, with different co ```python import plotly.express as px df = px.data.tips() -fig = px.sunburst(df, path=['day', 'time', 'sex'], values='total_bill') +fig = px.sunburst(df, path=['day', 'time', 'sex'], values='total_bill', color='tip') fig.show() ``` @@ -90,14 +90,14 @@ fig.show() ### Rectangular data with missing values -If the dataset is not fully rectangular, missing values should be supplied as `None`. +If the dataset is not fully rectangular, missing values should be supplied as `None`. Note that the parents of `None` entries must be a leaf, i.e. it cannot have other children than `None` (otherwise a `ValueError` is raised). ```python import plotly.express as px import pandas as pd vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None] -sectors = ["Tech", "Tech", "Finance", "Finance", None, - "Tech", "Tech", "Finance", "Finance", "Finance"] +sectors = ["Tech", "Tech", "Finance", "Finance", "Other", + "Tech", "Tech", "Finance", "Finance", "Other"] regions = ["North", "North", "North", "North", "North", "South", "South", "South", "South", "South"] sales = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1] diff --git a/doc/python/treemaps.md b/doc/python/treemaps.md index 5f6590e5981..8b1cc97d1f1 100644 --- a/doc/python/treemaps.md +++ b/doc/python/treemaps.md @@ -85,8 +85,8 @@ If the dataset is not fully rectangular, missing values should be supplied as `N import plotly.express as px import pandas as pd vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None] -sectors = ["Tech", "Tech", "Finance", "Finance", None, - "Tech", "Tech", "Finance", "Finance", "Finance"] +sectors = ["Tech", "Tech", "Finance", "Finance", "Other", + "Tech", "Tech", "Finance", "Finance", "Other"] regions = ["North", 
"North", "North", "North", "North", "South", "South", "South", "South", "South"] sales = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1] @@ -97,9 +97,6 @@ print(df) fig = px.treemap(df, path=['regions', 'sectors', 'vendors'], values='sales') fig.show() ``` - - - ### Basic Treemap with go.Treemap If Plotly Express does not provide a good starting point, it is also possible to use the more generic `go.Treemap` function from `plotly.graph_objects`. diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index b09399734c4..28d63e17b0e 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1007,12 +1007,28 @@ def build_dataframe(args, attrables, array_attrables): return args +def _check_dataframe_all_leaves(df): + df_sorted = df.sort_values(by=list(df.columns)) + null_mask = df_sorted.isnull() + null_indices = null_mask.any(axis=1).to_numpy().nonzero()[0] + df_sorted[null_mask] = "" + row_strings = list(df_sorted.apply(lambda x: "".join(x), axis=1)) + for i, row in enumerate(row_strings[:-1]): + if row_strings[i + 1] in row and (i + 1) in null_indices: + raise ValueError( + "Non-leaves rows are not permitted in the dataframe \n", + df_sorted.iloc[i + 1], + "is not a leaf.", + ) + + def process_dataframe_hierarchy(args): """ Build dataframe for sunburst or treemap when the path argument is provided. 
""" df = args["data_frame"] path = args["path"][::-1] + _check_dataframe_all_leaves(df[path[::-1]]) if args["color"] and args["color"] in path: series_to_copy = df[args["color"]] diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index d0c259aace8..40d9b166cca 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -249,6 +249,10 @@ def test_sunburst_treemap_with_path_non_rectangular(): ) ) path = ["total", "regions", "sectors", "vendors"] + msg = "Non-leaves rows are not permitted in the dataframe" + with pytest.raises(ValueError, match=msg): + fig = px.sunburst(df, path=path, values="values") + df.loc[df["vendors"].isnull(), "sectors"] = "Other" fig = px.sunburst(df, path=path, values="values") assert fig.data[0].values[-1] == np.sum(values) From c0cbce0b6eef1f6d153ca57353de9a6c34e2b9a2 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 22 Jan 2020 12:40:58 -0500 Subject: [PATCH 28/33] other check --- packages/python/plotly/plotly/express/_core.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 28d63e17b0e..bad39a2b3a0 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1011,6 +1011,14 @@ def _check_dataframe_all_leaves(df): df_sorted = df.sort_values(by=list(df.columns)) null_mask = df_sorted.isnull() null_indices = null_mask.any(axis=1).to_numpy().nonzero()[0] + for null_row_index in null_indices: + row = null_mask.iloc[null_row_index] + indices = row.to_numpy().nonzero()[0] + if not row[indices[0] :].all(): + raise ValueError( + "None entries cannot have not-None children", + df_sorted.iloc[null_row_index], + ) df_sorted[null_mask] = "" row_strings 
= list(df_sorted.apply(lambda x: "".join(x), axis=1)) for i, row in enumerate(row_strings[:-1]): From 57503b4e74a0d6f3e134180f48118e6272486f34 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 22 Jan 2020 12:56:04 -0500 Subject: [PATCH 29/33] discrete color other comes first --- packages/python/plotly/plotly/express/_core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index bad39a2b3a0..0ff1fbdc8ac 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1122,6 +1122,9 @@ def aggfunc_continuous(x): df_tree[cols] = dfg[cols] df_all_trees = df_all_trees.append(df_tree, ignore_index=True) + if args["color"]: + df_all_trees = df_all_trees.sort_values(by=args["color"]) + # Now modify arguments args["data_frame"] = df_all_trees args["path"] = None From 0ab2afde2c27387b961529e069dc3e62687ce6b3 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 22 Jan 2020 13:10:47 -0500 Subject: [PATCH 30/33] fixed tests --- packages/python/plotly/plotly/express/_core.py | 4 +++- .../plotly/tests/test_core/test_px/test_px_functions.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 0ff1fbdc8ac..fee1c5b1ae0 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1037,6 +1037,7 @@ def process_dataframe_hierarchy(args): df = args["data_frame"] path = args["path"][::-1] _check_dataframe_all_leaves(df[path[::-1]]) + discrete_color = False if args["color"] and args["color"] in path: series_to_copy = df[args["color"]] @@ -1081,6 +1082,7 @@ def aggfunc_discrete(x): if args["color"]: if df[args["color"]].dtype.kind not in "bifc": aggfunc_color = aggfunc_discrete + discrete_color = True elif not aggfunc_color: def 
aggfunc_continuous(x): @@ -1122,7 +1124,7 @@ def aggfunc_continuous(x): df_tree[cols] = dfg[cols] df_all_trees = df_all_trees.append(df_tree, ignore_index=True) - if args["color"]: + if args["color"] and discrete_color: df_all_trees = df_all_trees.sort_values(by=args["color"]) # Now modify arguments diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 40d9b166cca..2a43dab7234 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -207,8 +207,7 @@ def test_sunburst_treemap_with_path_color(): # Discrete color fig = px.sunburst(df, path=path, color="vendors") - assert len(np.unique(fig.data[0].marker.colors[:8])) == 8 - assert len(np.unique(fig.data[0].marker.colors[8:])) == 1 + assert len(np.unique(fig.data[0].marker.colors)) == 9 def test_sunburst_treemap_with_path_non_rectangular(): From 0d86998e9e78f08110d06341dce58ae1783b8471 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 22 Jan 2020 13:32:25 -0500 Subject: [PATCH 31/33] hover --- doc/python/sunburst-charts.md | 2 +- packages/python/plotly/plotly/express/_core.py | 12 ++++++++++++ .../tests/test_core/test_px/test_px_functions.py | 8 ++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/doc/python/sunburst-charts.md b/doc/python/sunburst-charts.md index 60fcc5751b3..db659f99092 100644 --- a/doc/python/sunburst-charts.md +++ b/doc/python/sunburst-charts.md @@ -69,7 +69,7 @@ Hierarchical data are often stored as a rectangular dataframe, with different co ```python import plotly.express as px df = px.data.tips() -fig = px.sunburst(df, path=['day', 'time', 'sex'], values='total_bill', color='tip') +fig = px.sunburst(df, path=['day', 'time', 'sex'], values='total_bill') fig.show() ``` diff --git a/packages/python/plotly/plotly/express/_core.py 
b/packages/python/plotly/plotly/express/_core.py index fee1c5b1ae0..0009a23d720 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1043,6 +1043,18 @@ def process_dataframe_hierarchy(args): series_to_copy = df[args["color"]] args["color"] = str(args["color"]) + "additional_col_for_px" df[args["color"]] = series_to_copy + if args["hover_data"]: + for col_name in args["hover_data"]: + if col_name == args["color"]: + series_to_copy = df[col_name] + new_col_name = str(args["color"]) + "additional_col_for_hover" + df[new_col_name] = series_to_copy + args["color"] = new_col_name + elif col_name in path: + series_to_copy = df[col_name] + new_col_name = col_name + "additional_col_for_hover" + path = [new_col_name if x == col_name else x for x in path] + df[new_col_name] = series_to_copy # ------------ Define aggregation functions -------------------------------- def aggfunc_discrete(x): uniques = x.unique() diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 2a43dab7234..8ae1e9ea791 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -164,6 +164,14 @@ def test_sunburst_treemap_with_path(): assert fig.data[0].branchvalues == "total" +def test_sunburst_treemap_with_path_and_hover(): + df = px.data.tips() + fig = px.sunburst( + df, path=["sex", "day", "time", "smoker"], color="smoker", hover_data=["smoker"] + ) + assert "smoker" in fig.data[0].hovertemplate + + def test_sunburst_treemap_with_path_color(): vendors = ["A", "B", "C", "D", "E", "F", "G", "H"] sectors = [ From d63d4bd137cb839ef782ce3d638b89720a9d0654 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 22 Jan 2020 13:39:53 -0500 Subject: [PATCH 32/33] fixed pandas API pb --- 
packages/python/plotly/plotly/express/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 0009a23d720..8eff25d4220 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1010,10 +1010,10 @@ def build_dataframe(args, attrables, array_attrables): def _check_dataframe_all_leaves(df): df_sorted = df.sort_values(by=list(df.columns)) null_mask = df_sorted.isnull() - null_indices = null_mask.any(axis=1).to_numpy().nonzero()[0] + null_indices = np.nonzero(null_mask.any(axis=1))[0] for null_row_index in null_indices: row = null_mask.iloc[null_row_index] - indices = row.to_numpy().nonzero()[0] + indices = np.nonzero(row)[0] if not row[indices[0] :].all(): raise ValueError( "None entries cannot have not-None children", From 9b217f824c26abd7a2d24e3bf1126e3d76ff3991 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 22 Jan 2020 14:05:04 -0500 Subject: [PATCH 33/33] pandas stuff --- packages/python/plotly/plotly/express/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 8eff25d4220..e43b4beb76b 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1010,10 +1010,10 @@ def build_dataframe(args, attrables, array_attrables): def _check_dataframe_all_leaves(df): df_sorted = df.sort_values(by=list(df.columns)) null_mask = df_sorted.isnull() - null_indices = np.nonzero(null_mask.any(axis=1))[0] + null_indices = np.nonzero(null_mask.any(axis=1).values)[0] for null_row_index in null_indices: row = null_mask.iloc[null_row_index] - indices = np.nonzero(row)[0] + indices = np.nonzero(row.values)[0] if not row[indices[0] :].all(): raise ValueError( "None entries cannot have not-None children",