diff --git a/CHANGELOG.md b/CHANGELOG.md index d015cdd2dd4..83e7465f8bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ This project adheres to [Semantic Versioning](http://semver.org/). - `pattern_shape` options now available in `px.timeline()` [#3774](https://github.com/plotly/plotly.py/pull/3774) - `facet_*` and `category_orders` now available in `px.pie()` [#3775](https://github.com/plotly/plotly.py/pull/3775) +### Performance + + - `px` methods no longer call `groupby` on the input dataframe when the result would be a single group, and no longer group by a lambda, for significant speedups [#3765](https://github.com/plotly/plotly.py/pull/3765) + ### Updated - Allow non-string extras in `flaglist` attributes, to support upcoming changes to `ax.automargin` in plotly.js [plotly.js#6193](https://github.com/plotly/plotly.js/pull/6193), [#3749](https://github.com/plotly/plotly.py/pull/3749) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 99a4c2a3701..dd2d53be1ed 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1920,40 +1920,66 @@ def infer_config(args, constructor, trace_patch, layout_patch): return trace_specs, grouped_mappings, sizeref, show_colorbar -def get_orderings(args, grouper, grouped): +def get_groups_and_orders(args, grouper): """ `orders` is the user-supplied ordering with the remaining data-frame-supplied ordering appended if the column is used for grouping. It includes anything the user gave, for any variable, including values not present in the dataset. It's a dict where the keys are e.g. "x" or "color" - `sorted_group_names` is the set of groups, ordered by the order above. It's a list - of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name + `groups` is the dict of groups, ordered by the order above. 
Its keys are + tuples like [("value1", ""), ("value2", "")] where each tuple contains the name of a single dimension-group """ - orders = {} if "category_orders" not in args else args["category_orders"].copy() + + # figure out orders and what the single group name would be if there were one + single_group_name = [] + unique_cache = dict() for col in grouper: - if col != one_group: - uniques = list(args["data_frame"][col].unique()) + if col == one_group: + single_group_name.append("") + else: + if col not in unique_cache: + unique_cache[col] = list(args["data_frame"][col].unique()) + uniques = unique_cache[col] + if len(uniques) == 1: + single_group_name.append(uniques[0]) if col not in orders: orders[col] = uniques else: orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques)) + df = args["data_frame"] + if len(single_group_name) == len(grouper): + # we have a single group, so we can skip all group-by operations! + groups = {tuple(single_group_name): df} + else: + required_grouper = [g for g in grouper if g != one_group] + grouped = df.groupby(required_grouper, sort=False) # skip one_group groupers + group_indices = grouped.indices + sorted_group_names = [ + g if len(required_grouper) != 1 else (g,) for g in group_indices + ] - sorted_group_names = [] - for group_name in grouped.groups: - if len(grouper) == 1: - group_name = (group_name,) - sorted_group_names.append(group_name) - - for i, col in reversed(list(enumerate(grouper))): - if col != one_group: + for i, col in reversed(list(enumerate(required_grouper))): sorted_group_names = sorted( sorted_group_names, key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1, ) - return orders, sorted_group_names + + # calculate the full group_names by inserting "" in the tuple index for one_group groups + full_sorted_group_names = [list(t) for t in sorted_group_names] + for i, col in enumerate(grouper): + if col == one_group: + for g in full_sorted_group_names: + g.insert(i, "") + 
full_sorted_group_names = [tuple(g) for g in full_sorted_group_names] + + groups = { + sf: grouped.get_group(s if len(s) > 1 else s[0]) + for sf, s in zip(full_sorted_group_names, sorted_group_names) + } + return groups, orders def make_figure(args, constructor, trace_patch=None, layout_patch=None): @@ -1974,9 +2000,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): args, constructor, trace_patch, layout_patch ) grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] - grouped = args["data_frame"].groupby(grouper, sort=False) - - orders, sorted_group_names = get_orderings(args, grouper, grouped) + groups, orders = get_groups_and_orders(args, grouper) col_labels = [] row_labels = [] @@ -2005,8 +2029,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): trendline_rows = [] trace_name_labels = None facet_col_wrap = args.get("facet_col_wrap", 0) - for group_name in sorted_group_names: - group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) + for group_name, group in groups.items(): mapping_labels = OrderedDict() trace_name_labels = OrderedDict() frame_name = "" @@ -2224,6 +2247,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): fig.update_layout(layout_patch) if "template" in args and args["template"] is not None: fig.update_layout(template=args["template"], overwrite=True) + for f in frame_list: + f["name"] = str(f["name"]) fig.frames = frame_list if len(frames) > 1 else [] if args.get("trendline") and args.get("trendline_scope", "trace") == "overall":