Skip to content

Commit 6640d8f

Browse files
committed
♻️ check for all same groups
1 parent b97d197 commit 6640d8f

File tree

1 file changed

+10
-7
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+10
-7
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -1904,7 +1904,7 @@ def infer_config(args, constructor, trace_patch, layout_patch):
19041904
return trace_specs, grouped_mappings, sizeref, show_colorbar
19051905

19061906

1907-
def get_orderings(args, grouper, grouped):
1907+
def get_orderings(args, grouper, grouped, all_same_group):
19081908
"""
19091909
`orders` is the user-supplied ordering with the remaining data-frame-supplied
19101910
ordering appended if the column is used for grouping. It includes anything the user
@@ -1917,7 +1917,7 @@ def get_orderings(args, grouper, grouped):
19171917
"""
19181918
orders = {} if "category_orders" not in args else args["category_orders"].copy()
19191919

1920-
if _all_one_group(grouper):
1920+
if all_same_group:
19211921
sorted_group_names = [("",) * len(grouper)]
19221922
return orders, sorted_group_names
19231923

@@ -1944,10 +1944,12 @@ def get_orderings(args, grouper, grouped):
19441944
return orders, sorted_group_names
19451945

19461946

1947-
def _all_one_group(grouper):
1948-
for g in grouper:
1947+
def _all_same_group(args, grouper):
1948+
for g in set(grouper):
19491949
if g != one_group:
1950-
return False
1950+
arr = args["data_frame"][g].values
1951+
if not (arr[0] == arr).all(axis=0):
1952+
return False
19511953
return True
19521954

19531955

@@ -1968,10 +1970,11 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19681970
)
19691971
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
19701972
grouped = None
1971-
if not _all_one_group(grouper):
1973+
all_same_group = _all_same_group(args, grouper)
1974+
if not all_same_group:
19721975
grouped = args["data_frame"].groupby(grouper, sort=False)
19731976

1974-
orders, sorted_group_names = get_orderings(args, grouper, grouped)
1977+
orders, sorted_group_names = get_orderings(args, grouper, grouped, all_same_group)
19751978

19761979
col_labels = []
19771980
row_labels = []

0 commit comments

Comments
 (0)