Skip to content

Commit 7f2ce44

Browse files
committed
Merge remote-tracking branch 'upstream/master' into avoid-iter-row
2 parents 52d562a + 817fef7 commit 7f2ce44

File tree

3 files changed

+45
-24
lines changed

3 files changed

+45
-24
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+23-23
Original file line numberDiff line numberDiff line change
@@ -652,9 +652,6 @@ def set_cartesian_axis_opts(args, axis, letter, orders):
652652

653653

654654
def configure_cartesian_marginal_axes(args, fig, orders):
655-
if "histogram" in [args["marginal_x"], args["marginal_y"]]:
656-
fig.layout["barmode"] = "overlay"
657-
658655
nrows = len(fig._grid_ref)
659656
ncols = len(fig._grid_ref[0])
660657

@@ -1489,25 +1486,21 @@ def build_dataframe(args, constructor):
14891486
# PySpark to pandas.
14901487
is_pd_like = False
14911488

1492-
# Flag that indicates if data_frame requires to be converted to arrow via the
1493-
# dataframe interchange protocol.
1494-
# True if Ibis, DuckDB, Vaex or implements __dataframe__
1489+
# Flag that indicates if data_frame needs to be converted to PyArrow.
1490+
# True if Ibis, DuckDB, Vaex, or implements __dataframe__
14951491
needs_interchanging = False
14961492

14971493
# If data_frame is provided, we parse it into a narwhals DataFrame, while accounting
14981494
# for compatibility with pandas specific paths (e.g. Index/MultiIndex case).
14991495
if df_provided:
1500-
15011496
# data_frame is pandas-like DataFrame (pandas, modin.pandas, cudf)
15021497
if nw.dependencies.is_pandas_like_dataframe(args["data_frame"]):
1503-
15041498
columns = args["data_frame"].columns # This can be multi index
15051499
args["data_frame"] = nw.from_native(args["data_frame"], eager_only=True)
15061500
is_pd_like = True
15071501

15081502
# data_frame is pandas-like Series (pandas, modin.pandas, cudf)
15091503
elif nw.dependencies.is_pandas_like_series(args["data_frame"]):
1510-
15111504
args["data_frame"] = nw.from_native(
15121505
args["data_frame"], series_only=True
15131506
).to_frame()
@@ -1993,7 +1986,6 @@ def process_dataframe_hierarchy(args):
19931986

19941987
if args["color"]:
19951988
if discrete_color:
1996-
19971989
discrete_aggs.append(args["color"])
19981990
agg_f[args["color"]] = nw.col(args["color"]).max()
19991991
agg_f[f'{args["color"]}{n_unique_token}'] = (
@@ -2048,7 +2040,6 @@ def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFra
20482040
).drop([f"{col}{n_unique_token}" for col in discrete_aggs])
20492041

20502042
for i, level in enumerate(path):
2051-
20522043
dfg = (
20532044
df.group_by(path[i:], drop_null_keys=True)
20542045
.agg(**agg_f)
@@ -2425,7 +2416,6 @@ def get_groups_and_orders(args, grouper):
24252416
# figure out orders and what the single group name would be if there were one
24262417
single_group_name = []
24272418
unique_cache = dict()
2428-
grp_to_idx = dict()
24292419

24302420
for i, col in enumerate(grouper):
24312421
if col == one_group:
@@ -2443,27 +2433,28 @@ def get_groups_and_orders(args, grouper):
24432433
else:
24442434
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
24452435

2446-
grp_to_idx = {k: i for i, k in enumerate(orders)}
2447-
24482436
if len(single_group_name) == len(grouper):
24492437
# we have a single group, so we can skip all group-by operations!
24502438
groups = {tuple(single_group_name): df}
24512439
else:
2452-
required_grouper = list(orders.keys())
2440+
required_grouper = [group for group in orders if group in grouper]
24532441
grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__())
2454-
sorted_group_names = list(grouped.keys())
24552442

2456-
for i, col in reversed(list(enumerate(required_grouper))):
2457-
sorted_group_names = sorted(
2458-
sorted_group_names,
2459-
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
2460-
)
2443+
sorted_group_names = sorted(
2444+
grouped.keys(),
2445+
key=lambda values: [
2446+
orders[group].index(value) if value in orders[group] else -1
2447+
for group, value in zip(required_grouper, values)
2448+
],
2449+
)
24612450

24622451
# calculate the full group_names by inserting "" in the tuple index for one_group groups
24632452
full_sorted_group_names = [
24642453
tuple(
24652454
[
2466-
"" if col == one_group else sub_group_names[grp_to_idx[col]]
2455+
""
2456+
if col == one_group
2457+
else sub_group_names[required_grouper.index(col)]
24672458
for col in grouper
24682459
]
24692460
)
@@ -2490,6 +2481,10 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
24902481
constructor = go.Bar
24912482
args = process_dataframe_timeline(args)
24922483

2484+
# If we have marginal histograms, set barmode to "overlay"
2485+
if "histogram" in [args.get("marginal_x"), args.get("marginal_y")]:
2486+
layout_patch["barmode"] = "overlay"
2487+
24932488
trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
24942489
args, constructor, trace_patch, layout_patch
24952490
)
@@ -2561,7 +2556,12 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
25612556
legendgroup=trace_name,
25622557
showlegend=(trace_name != "" and trace_name not in trace_names),
25632558
)
2564-
if trace_spec.constructor in [go.Bar, go.Violin, go.Box, go.Histogram]:
2559+
2560+
# Set 'offsetgroup' only in group barmode (or if no barmode is set)
2561+
barmode = layout_patch.get("barmode")
2562+
if trace_spec.constructor in [go.Bar, go.Box, go.Violin, go.Histogram] and (
2563+
barmode == "group" or barmode is None
2564+
):
25652565
trace.update(alignmentgroup=True, offsetgroup=trace_name)
25662566
trace_names.add(trace_name)
25672567

Diff for: packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py

+21
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,27 @@ def test_orthogonal_orderings(backend, days, times):
289289
assert_orderings(backend, days, days, times, times)
290290

291291

292+
def test_category_order_with_category_as_x(backend):
293+
# https://github.com/plotly/plotly.py/issues/4875
294+
tips = nw.from_native(px.data.tips(return_type=backend))
295+
fig = px.bar(
296+
tips,
297+
x="day",
298+
y="total_bill",
299+
color="smoker",
300+
barmode="group",
301+
facet_col="sex",
302+
category_orders={
303+
"day": ["Thur", "Fri", "Sat", "Sun"],
304+
"smoker": ["Yes", "No"],
305+
"sex": ["Male", "Female"],
306+
},
307+
)
308+
assert fig["layout"]["xaxis"]["categoryarray"] == ("Thur", "Fri", "Sat", "Sun")
309+
for trace in fig["data"]:
310+
assert set(trace["x"]) == {"Thur", "Fri", "Sat", "Sun"}
311+
312+
292313
def test_permissive_defaults():
293314
msg = "'PxDefaults' object has no attribute 'should_not_work'"
294315
with pytest.raises(AttributeError, match=msg):

Diff for: packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def np_inf():
9595
columns=["col 1"], data=[1, 2, 3, dt(2014, 1, 5), pd.NaT, np_nan(), np_inf()]
9696
)
9797

98-
rng = pd.date_range("1/1/2011", periods=2, freq="H")
98+
rng = pd.date_range("1/1/2011", periods=2, freq="h")
9999
ts = pd.Series([1.5, 2.5], index=rng)
100100

101101

0 commit comments

Comments
 (0)