Skip to content

Commit b97d197

Browse files
more correct optimization
1 parent 9941749 commit b97d197

File tree

1 file changed

+23
-17
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+23
-17
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -1898,7 +1898,6 @@ def infer_config(args, constructor, trace_patch, layout_patch):
18981898

18991899
# Create grouped mappings
19001900
grouped_mappings = [make_mapping(args, a) for a in grouped_attrs]
1901-
grouped_mappings = [x for x in grouped_mappings if x.grouper]
19021901

19031902
# Create trace specs
19041903
trace_specs = make_trace_spec(args, constructor, attrs, trace_patch)
@@ -1918,16 +1917,17 @@ def get_orderings(args, grouper, grouped):
19181917
"""
19191918
orders = {} if "category_orders" not in args else args["category_orders"].copy()
19201919

1921-
if grouper == [one_group]:
1922-
sorted_group_names = [("",)]
1920+
if _all_one_group(grouper):
1921+
sorted_group_names = [("",) * len(grouper)]
19231922
return orders, sorted_group_names
19241923

19251924
for col in grouper:
1926-
uniques = list(args["data_frame"][col].unique())
1927-
if col not in orders:
1928-
orders[col] = uniques
1929-
else:
1930-
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
1925+
if col != one_group:
1926+
uniques = list(args["data_frame"][col].unique())
1927+
if col not in orders:
1928+
orders[col] = uniques
1929+
else:
1930+
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
19311931

19321932
sorted_group_names = []
19331933
for group_name in grouped.groups:
@@ -1936,13 +1936,21 @@ def get_orderings(args, grouper, grouped):
19361936
sorted_group_names.append(group_name)
19371937

19381938
for i, col in reversed(list(enumerate(grouper))):
1939-
sorted_group_names = sorted(
1940-
sorted_group_names,
1941-
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
1942-
)
1939+
if col != one_group:
1940+
sorted_group_names = sorted(
1941+
sorted_group_names,
1942+
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
1943+
)
19431944
return orders, sorted_group_names
19441945

19451946

1947+
def _all_one_group(grouper):
1948+
for g in grouper:
1949+
if g != one_group:
1950+
return False
1951+
return True
1952+
1953+
19461954
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19471955
trace_patch = trace_patch or {}
19481956
layout_patch = layout_patch or {}
@@ -1958,12 +1966,10 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19581966
trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
19591967
args, constructor, trace_patch, layout_patch
19601968
)
1961-
if len(grouped_mappings):
1962-
grouper = [x.grouper for x in grouped_mappings]
1969+
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
1970+
grouped = None
1971+
if not _all_one_group(grouper):
19631972
grouped = args["data_frame"].groupby(grouper, sort=False)
1964-
else:
1965-
grouper = [one_group]
1966-
grouped = None
19671973

19681974
orders, sorted_group_names = get_orderings(args, grouper, grouped)
19691975

0 commit comments

Comments
 (0)