diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 21c0ca03cc1..d1f8854f6df 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1136,11 +1136,19 @@ def infer_config(args, constructor, trace_patch): def get_orderings(args, grouper, grouped): """ `orders` is the user-supplied ordering (with the remaining data-frame-supplied - ordering appended if the column is used for grouping) + ordering appended if the column is used for grouping). It includes anything the user + gave, for any variable, including values not present in the dataset. It is used + downstream to set e.g. `categoryarray` for cartesian axes + `group_names` is the set of groups, ordered by the order above + + `group_values` is a subset of `orders` in both keys and values. It contains a key + for every grouped mapping and its values are the sorted *data* values for these + mappings. """ orders = {} if "category_orders" not in args else args["category_orders"].copy() group_names = [] + group_values = {} for group_name in grouped.groups: if len(grouper) == 1: group_name = (group_name,) @@ -1154,6 +1162,7 @@ def get_orderings(args, grouper, grouped): for val in uniques: if val not in orders[col]: orders[col].append(val) + group_values[col] = sorted(uniques, key=orders[col].index) for i, col in reversed(list(enumerate(grouper))): if col != one_group: @@ -1162,7 +1171,7 @@ def get_orderings(args, grouper, grouped): key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1, ) - return orders, group_names + return orders, group_names, group_values def make_figure(args, constructor, trace_patch={}, layout_patch={}): @@ -1174,7 +1183,24 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] grouped = args["data_frame"].groupby(grouper, sort=False) - orders, sorted_group_names = get_orderings(args, grouper, grouped) + orders, sorted_group_names, sorted_group_values = get_orderings( + args, grouper, grouped + ) + + col_labels = [] + row_labels = [] + + for m in grouped_mappings: + if m.grouper: + if m.facet == "col": + prefix = get_label(args, args["facet_col"]) + "=" + col_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]] + if m.facet == "row": + prefix = get_label(args, args["facet_row"]) + "=" + row_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]] + for val in sorted_group_values[m.grouper]: + if val not in m.val_map: + m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)] subplot_type = _subplot_type_for_trace_type(constructor().type) @@ -1182,8 +1208,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): frames = OrderedDict() trendline_rows = [] nrows = ncols = 1 - col_labels = [] - row_labels = [] trace_name_labels = None for group_name in sorted_group_names: group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) @@ -1281,10 +1305,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): # Find row for trace, handling facet_row and marginal_x if m.facet == "row": row = m.val_map[val] - if args["facet_row"] and len(row_labels) < row: - row_labels.append( - get_label(args, args["facet_row"]) + "=" + str(val) - ) else: if ( bool(args.get("marginal_x", False)) @@ -1298,10 +1318,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): # Find col for trace, handling facet_col and marginal_y if m.facet == "col": col = m.val_map[val] - if args["facet_col"] and len(col_labels) < col: - col_labels.append( - get_label(args, args["facet_col"]) + "=" + str(val) - ) if facet_col_wrap: # assumes no facet_row, no marginals row = 1 + ((col - 1) // facet_col_wrap) col = 1 + ((col - 1) % facet_col_wrap) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py index 99284e02ba2..9ace6a7b4a9 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px.py @@ -182,3 +182,53 @@ def test_px_templates(): assert fig.layout.xaxis3.showgrid is None assert fig.layout.yaxis2.showgrid assert fig.layout.yaxis3.showgrid + + +def test_orthogonal_orderings(): + from itertools import permutations + + df = px.data.tips() + + symbol_sequence = ["circle", "diamond", "square", "cross"] + color_sequence = ["red", "blue"] + + def assert_orderings(days_order, days_check, times_order, times_check): + fig = px.scatter( + df, + x="total_bill", + y="tip", + facet_row="time", + facet_col="day", + color="time", + symbol="day", + symbol_sequence=symbol_sequence, + color_discrete_sequence=color_sequence, + category_orders=dict(day=days_order, time=times_order), + ) + + for col in range(len(days_check)): + for trace in fig.select_traces(col=col + 1): + assert days_check[col] in trace.hovertemplate + + for row in range(len(times_check)): + for trace in fig.select_traces(row=2 - row): + assert times_check[row] in trace.hovertemplate + + for trace in fig.data: + for i, day in enumerate(days_check): + if day in trace.name: + assert trace.marker.symbol == symbol_sequence[i] + for i, time in enumerate(times_check): + if time in trace.name: + assert trace.marker.color == color_sequence[i] + + assert_orderings( + ["x", "Sun", "Sat", "y", "Fri", "z"], # add extra noise, missing Thur + ["Sun", "Sat", "Fri", "Thur"], # Thur is at the back + ["a", "Lunch", "b"], # add extra noise, missing Dinner + ["Lunch", "Dinner"], # Dinner is at the back + ) + + for days in permutations(df["day"].unique()): + for times in permutations(df["time"].unique()): + assert_orderings(days, days, times, times)