From fa24c1ae4b86ab76db1d212f352f72782cab4238 Mon Sep 17 00:00:00 2001 From: jvdd Date: Mon, 6 Jun 2022 19:35:44 +0200 Subject: [PATCH 1/3] :zap: avoid expensive & unnecessary groupby in px --- .../python/plotly/plotly/express/_core.py | 40 +++++++++++++++++-- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index f4aa02cdc21..33007348a3a 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1955,10 +1955,36 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( args, constructor, trace_patch, layout_patch ) - grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] - grouped = args["data_frame"].groupby(grouper, sort=False) + grouper_ = [x.grouper or one_group for x in grouped_mappings] or [one_group] + + all_same = True # Variable indicating if grouping can be avoided + for g in grouper_: + if g is not one_group: + all_same &= (args["data_frame"][g].nunique() == 1) + if not all_same: break # early stopping if not all the same + + if all_same: + # Do not perform an expensive groupby operation when there are either + # no groups to group by, or when the group has only one (i.e., the same) value + grouper = [g for g in grouper_ if g is not one_group] + assert len(grouper) <= 1 + # -> create orders, sorted_group_names equivalent to those of get_ordings + orders = {g: [args["data_frame"][g].iloc[0]] for g in grouper} + sorted_group_names = [tuple(args["data_frame"][g].iloc[0] for g in orders)] + if len(sorted_group_names): # check for length to support also empty plots + assert len(sorted_group_names) == 1 # should be only for 1 variable + all_same_group_names = list(sorted_group_names[0]) # convert tuple to list + for idx in range(len(grouper_)): + # insert "" in the list when no grouping was used + if grouper_[idx] is one_group: + all_same_group_names.insert(idx, "") + all_same_group_names = tuple(all_same_group_names) # convert list to tuple + sorted_group_names = [all_same_group_names] + else: + grouper = grouper_ + grouped = args["data_frame"].groupby(grouper, sort=False) - orders, sorted_group_names = get_orderings(args, grouper, grouped) + orders, sorted_group_names = get_orderings(args, grouper, grouped) col_labels = [] row_labels = [] @@ -1988,7 +2014,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): trace_name_labels = None facet_col_wrap = args.get("facet_col_wrap", 0) for group_name in sorted_group_names: - group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) + if all_same: + # No expensive get_group operation when all data from the same group + group = args["data_frame"] + else: + group = grouped.get_group( + group_name if len(group_name) > 1 else group_name[0] + ) mapping_labels = OrderedDict() trace_name_labels = OrderedDict() frame_name = "" From 9e93ac89253c0152c5d7588e38c46ac49acb1c78 Mon Sep 17 00:00:00 2001 From: jvdd Date: Mon, 6 Jun 2022 20:24:50 +0200 Subject: [PATCH 2/3] :see_no_evil: fix formatting --- packages/python/plotly/plotly/express/_core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 33007348a3a..3bcbafb5982 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1960,8 +1960,9 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): all_same = True # Variable indicating if grouping can be avoided for g in grouper_: if g is not one_group: - all_same &= (args["data_frame"][g].nunique() == 1) - if not all_same: break # early stopping if not all the same + all_same &= args["data_frame"][g].nunique() == 1 + if not all_same: + break # early stopping if not all the same if all_same: # Do not perform an expensive groupby operation when there are either From 958014aca5cb20632303099b3bda132829b15daf Mon Sep 17 00:00:00 2001 From: jvdd Date: Tue, 7 Jun 2022 10:27:50 +0200 Subject: [PATCH 3/3] :zap: faster check for all same values + refactor code --- packages/python/plotly/plotly/express/_core.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 3bcbafb5982..f5c3acb9c9f 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1957,14 +1957,15 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): ) grouper_ = [x.grouper or one_group for x in grouped_mappings] or [one_group] - all_same = True # Variable indicating if grouping can be avoided + all_same_group = True # variable indicating if grouping can be avoided for g in grouper_: if g is not one_group: - all_same &= args["data_frame"][g].nunique() == 1 - if not all_same: + arr = args["data_frame"][g].values + all_same_group &= (arr[0] == arr).all(axis=0) + if not all_same_group: break # early stopping if not all the same - if all_same: + if all_same_group: # Do not perform an expensive groupby operation when there are either # no groups to group by, or when the group has only one (i.e., the same) value grouper = [g for g in grouper_ if g is not one_group] @@ -1974,13 +1975,12 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): sorted_group_names = [tuple(args["data_frame"][g].iloc[0] for g in orders)] if len(sorted_group_names): # check for length to support also empty plots assert len(sorted_group_names) == 1 # should be only for 1 variable - all_same_group_names = list(sorted_group_names[0]) # convert tuple to list + sorted_group_names = list(sorted_group_names[0]) # convert [tuple] to list for idx in range(len(grouper_)): # insert "" in the list when no grouping was used if grouper_[idx] is one_group: - all_same_group_names.insert(idx, "") - all_same_group_names = tuple(all_same_group_names) # convert list to tuple - sorted_group_names = [all_same_group_names] + sorted_group_names.insert(idx, "") + sorted_group_names = [tuple(sorted_group_names)] # convert list to [tuple] else: grouper = grouper_ grouped = args["data_frame"].groupby(grouper, sort=False) @@ -2015,7 +2015,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): trace_name_labels = None facet_col_wrap = args.get("facet_col_wrap", 0) for group_name in sorted_group_names: - if all_same: + if all_same_group: # No expensive get_group operation when all data from the same group group = args["data_frame"] else: