From ad2c0a52de658892c70f26d315064c5b6b0e70a6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 14 Nov 2024 11:21:26 +0000 Subject: [PATCH 1/6] fix: Setting category_orders was leading to missing data --- .../python/plotly/plotly/express/_core.py | 25 +++++++++++-------- .../tests/test_optional/test_px/test_px.py | 21 ++++++++++++++++ 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index b3bcd096d34..9275dfaa730 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2422,7 +2422,6 @@ def get_groups_and_orders(args, grouper): # figure out orders and what the single group name would be if there were one single_group_name = [] unique_cache = dict() - grp_to_idx = dict() for i, col in enumerate(grouper): if col == one_group: @@ -2440,27 +2439,31 @@ def get_groups_and_orders(args, grouper): else: orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques)) - grp_to_idx = {k: i for i, k in enumerate(orders)} - if len(single_group_name) == len(grouper): # we have a single group, so we can skip all group-by operations! groups = {tuple(single_group_name): df} else: - required_grouper = list(orders.keys()) + required_grouper = [ + key for key in orders if key in grouper and key != one_group + ] + order_mapping = {key: orders[key] for key in required_grouper} grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__()) - sorted_group_names = list(grouped.keys()) - for i, col in reversed(list(enumerate(required_grouper))): - sorted_group_names = sorted( - sorted_group_names, - key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1, - ) + sorted_group_names = sorted( + grouped.keys(), + key=lambda group: [ + order_mapping[key].index(value) if value in order_mapping[key] else -1 + for key, value in zip(required_grouper, group) + ], + ) # calculate the full group_names by inserting "" in the tuple index for one_group groups full_sorted_group_names = [ tuple( [ - "" if col == one_group else sub_group_names[grp_to_idx[col]] + "" + if col == one_group + else sub_group_names[required_grouper.index(col)] for col in grouper ] ) diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py index 8d091df3ae2..76a204c7249 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py @@ -289,6 +289,27 @@ def test_orthogonal_orderings(backend, days, times): assert_orderings(backend, days, days, times, times) +def test_category_order_with_category_as_x(backend): + # https://github.com/plotly/plotly.py/issues/4875 + tips = nw.from_native(px.data.tips(return_type=backend)) + fig = px.bar( + tips, + x="day", + y="total_bill", + color="smoker", + barmode="group", + facet_col="sex", + category_orders={ + "day": ["Thur", "Fri", "Sat", "Sun"], + "smoker": ["Yes", "No"], + "sex": ["Male", "Female"], + }, + ) + assert fig["layout"]["xaxis"]["categoryarray"] == ("Thur", "Fri", "Sat", "Sun") + for trace in fig["data"]: + assert sorted(set(trace["x"])) == ["Fri", "Sat", "Sun", "Thur"] + + def test_permissive_defaults(): msg = "'PxDefaults' object has no attribute 'should_not_work'" with pytest.raises(AttributeError, match=msg): From eadc70b38fe6058ff064502ca21847a446851fce Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:01:14 +0000 Subject: [PATCH 2/6] simplify --- packages/python/plotly/plotly/express/_core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 9275dfaa730..6570ad8846a 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2443,9 +2443,7 @@ def get_groups_and_orders(args, grouper): # we have a single group, so we can skip all group-by operations! groups = {tuple(single_group_name): df} else: - required_grouper = [ - key for key in orders if key in grouper and key != one_group - ] + required_grouper = [key for key in orders if key in grouper] order_mapping = {key: orders[key] for key in required_grouper} grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__()) From 5e6fb93999f2ef7ddfda265b125d4c2167e53c81 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:06:53 +0000 Subject: [PATCH 3/6] it gets simpler --- packages/python/plotly/plotly/express/_core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 6570ad8846a..4b5cddbdd42 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2444,13 +2444,12 @@ def get_groups_and_orders(args, grouper): groups = {tuple(single_group_name): df} else: required_grouper = [key for key in orders if key in grouper] - order_mapping = {key: orders[key] for key in required_grouper} grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__()) sorted_group_names = sorted( grouped.keys(), key=lambda group: [ - order_mapping[key].index(value) if value in order_mapping[key] else -1 + orders[key].index(value) if value in orders[key] else -1 for key, value in zip(required_grouper, group) ], ) From 14a1da6dd087481ef9cdedc3febf5b9aebc84d1f Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Thu, 14 Nov 2024 15:12:44 +0000 Subject: [PATCH 4/6] Update packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py Co-authored-by: Emily KL <4672118+emilykl@users.noreply.github.com> --- .../python/plotly/plotly/tests/test_optional/test_px/test_px.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py index 76a204c7249..5ae751f6663 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py @@ -307,7 +307,7 @@ def test_category_order_with_category_as_x(backend): ) assert fig["layout"]["xaxis"]["categoryarray"] == ("Thur", "Fri", "Sat", "Sun") for trace in fig["data"]: - assert sorted(set(trace["x"])) == ["Fri", "Sat", "Sun", "Thur"] + assert set(trace["x"]) == {"Thur", "Fri", "Sat", "Sun"} def test_permissive_defaults(): From 93a39b6026207cbefbdb4d715e34067e4bacc97a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:19:00 +0000 Subject: [PATCH 5/6] better variable names --- packages/python/plotly/plotly/express/_core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 4b5cddbdd42..aedffad9eca 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2448,9 +2448,9 @@ def get_groups_and_orders(args, grouper): sorted_group_names = sorted( grouped.keys(), - key=lambda group: [ - orders[key].index(value) if value in orders[key] else -1 - for key, value in zip(required_grouper, group) + key=lambda values: [ + orders[group].index(value) if value in orders[group] else -1 + for group, value in zip(required_grouper, values) ], ) From b8b4d29f015b7b84a54656dda74f545e77cfae82 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:20:17 +0000 Subject: [PATCH 6/6] better variable names --- packages/python/plotly/plotly/express/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index aedffad9eca..27376ab037c 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2443,7 +2443,7 @@ def get_groups_and_orders(args, grouper): # we have a single group, so we can skip all group-by operations! groups = {tuple(single_group_name): df} else: - required_grouper = [key for key in orders if key in grouper] + required_grouper = [group for group in orders if group in grouper] grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__()) sorted_group_names = sorted(