Skip to content

Commit 067d4b0

Browse files
optimize group access
1 parent 3ae0645 commit 067d4b0

File tree

1 file changed

+16
-22
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+16
-22
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+16-22
Original file line numberDiff line numberDiff line change
@@ -1904,44 +1904,42 @@ def infer_config(args, constructor, trace_patch, layout_patch):
19041904
return trace_specs, grouped_mappings, sizeref, show_colorbar
19051905

19061906

1907-
def get_orderings(args, grouper):
1907+
def get_groups_and_orders(args, grouper):
19081908
"""
19091909
`orders` is the user-supplied ordering with the remaining data-frame-supplied
19101910
ordering appended if the column is used for grouping. It includes anything the user
19111911
gave, for any variable, including values not present in the dataset. It's a dict
19121912
where the keys are e.g. "x" or "color"
19131913
1914-
`sorted_group_names` is the set of groups, ordered by the order above. It's a list
1915-
of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
1914+
`groups` is the dicts of groups, ordered by the order above. Its keys are
1915+
tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
19161916
of a single dimension-group
19171917
"""
19181918
orders = {} if "category_orders" not in args else args["category_orders"].copy()
19191919

19201920
# figure out orders and what the single group name would be if there were one
19211921
single_group_name = []
1922+
unique_cache = dict()
19221923
for col in grouper:
19231924
if col == one_group:
19241925
single_group_name.append("")
19251926
else:
1926-
uniques = list(args["data_frame"][col].unique())
1927+
if col not in unique_cache:
1928+
unique_cache[col] = list(args["data_frame"][col].unique())
1929+
uniques = unique_cache[col]
19271930
if len(uniques) == 1:
19281931
single_group_name.append(uniques[0])
19291932
if col not in orders:
19301933
orders[col] = uniques
19311934
else:
19321935
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
1933-
1936+
df = args["data_frame"]
19341937
if len(single_group_name) == len(grouper):
19351938
# we have a single group, so we can skip all group-by operations!
1936-
grouped = None
1937-
sorted_group_names = [tuple(single_group_name)]
1939+
groups = {tuple(single_group_name): df}
19381940
else:
1939-
grouped = args["data_frame"].groupby(grouper, sort=False)
1940-
sorted_group_names = []
1941-
for group_name in grouped.groups:
1942-
if len(grouper) == 1:
1943-
group_name = (group_name,)
1944-
sorted_group_names.append(group_name)
1941+
group_indices = df.groupby(grouper, sort=False).indices
1942+
sorted_group_names = [g if len(grouper) != 1 else (g,) for g in group_indices]
19451943

19461944
for i, col in reversed(list(enumerate(grouper))):
19471945
if col != one_group:
@@ -1951,7 +1949,9 @@ def get_orderings(args, grouper):
19511949
if g[i] in orders[col]
19521950
else -1,
19531951
)
1954-
return grouped, orders, sorted_group_names
1952+
1953+
groups = {s: df.iloc[group_indices[s]] for s in sorted_group_names}
1954+
return groups, orders
19551955

19561956

19571957
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
@@ -1970,7 +1970,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19701970
args, constructor, trace_patch, layout_patch
19711971
)
19721972
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
1973-
grouped, orders, sorted_group_names = get_orderings(args, grouper)
1973+
groups, orders = get_groups_and_orders(args, grouper)
19741974

19751975
col_labels = []
19761976
row_labels = []
@@ -1999,13 +1999,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19991999
trendline_rows = []
20002000
trace_name_labels = None
20012001
facet_col_wrap = args.get("facet_col_wrap", 0)
2002-
for group_name in sorted_group_names:
2003-
if grouped is not None:
2004-
group = grouped.get_group(
2005-
group_name if len(group_name) > 1 else group_name[0]
2006-
)
2007-
else:
2008-
group = args["data_frame"]
2002+
for group_name, group in groups.items():
20092003
mapping_labels = OrderedDict()
20102004
trace_name_labels = OrderedDict()
20112005
frame_name = ""

0 commit comments

Comments
 (0)