Skip to content

Commit dd2137b

Browse files
Merge pull request #3767 from jvdd/one_group_short_circuit
♻️ check for all same groups
2 parents b97d197 + 73b3c7c commit dd2137b

File tree

1 file changed

+19
-10
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+19
-10
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -1904,7 +1904,7 @@ def infer_config(args, constructor, trace_patch, layout_patch):
19041904
return trace_specs, grouped_mappings, sizeref, show_colorbar
19051905

19061906

1907-
def get_orderings(args, grouper, grouped):
1907+
def get_orderings(args, grouper, grouped, all_same_group):
19081908
"""
19091909
`orders` is the user-supplied ordering with the remaining data-frame-supplied
19101910
ordering appended if the column is used for grouping. It includes anything the user
@@ -1916,10 +1916,17 @@ def get_orderings(args, grouper, grouped):
19161916
of a single dimension-group
19171917
"""
19181918
orders = {} if "category_orders" not in args else args["category_orders"].copy()
1919+
sorted_group_names = []
19191920

1920-
if _all_one_group(grouper):
1921-
sorted_group_names = [("",) * len(grouper)]
1922-
return orders, sorted_group_names
1921+
if all_same_group:
1922+
for col in grouper:
1923+
if col != one_group:
1924+
single_val = args["data_frame"][col].iloc[0]
1925+
sorted_group_names.append(single_val)
1926+
orders[col] = [single_val]
1927+
else:
1928+
sorted_group_names.append("")
1929+
return orders, [tuple(sorted_group_names)]
19231930

19241931
for col in grouper:
19251932
if col != one_group:
@@ -1929,7 +1936,6 @@ def get_orderings(args, grouper, grouped):
19291936
else:
19301937
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
19311938

1932-
sorted_group_names = []
19331939
for group_name in grouped.groups:
19341940
if len(grouper) == 1:
19351941
group_name = (group_name,)
@@ -1944,10 +1950,12 @@ def get_orderings(args, grouper, grouped):
19441950
return orders, sorted_group_names
19451951

19461952

1947-
def _all_one_group(grouper):
1948-
for g in grouper:
1953+
def _all_same_group(args, grouper):
1954+
for g in set(grouper):
19491955
if g != one_group:
1950-
return False
1956+
arr = args["data_frame"][g].values
1957+
if not (arr[0] == arr).all(axis=0):
1958+
return False
19511959
return True
19521960

19531961

@@ -1968,10 +1976,11 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19681976
)
19691977
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
19701978
grouped = None
1971-
if not _all_one_group(grouper):
1979+
all_same_group = _all_same_group(args, grouper)
1980+
if not all_same_group:
19721981
grouped = args["data_frame"].groupby(grouper, sort=False)
19731982

1974-
orders, sorted_group_names = get_orderings(args, grouper, grouped)
1983+
orders, sorted_group_names = get_orderings(args, grouper, grouped, all_same_group)
19751984

19761985
col_labels = []
19771986
row_labels = []

0 commit comments

Comments
 (0)