Skip to content

Commit a457f0e

Browse files
committed
Consider necessary columns from complex arguments when interchanging dataframes
1 parent 91060d3 commit a457f0e

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1419,9 +1419,17 @@ def build_dataframe(args, constructor):
14191419
else:
14201420
# Save precious resources by only interchanging columns that are
14211421
# actually going to be plotted.
1422-
columns = [
1422+
necessary_columns = [
14231423
i for i in args.values() if isinstance(i, str) and i in columns
14241424
]
1425+
for field in args:
1426+
if field in array_attrables and isinstance(
1427+
args[field], (list, dict)
1428+
):
1429+
necessary_columns.extend(
1430+
[i for i in args[field] if i in columns]
1431+
)
1432+
columns = list(dict.fromkeys(necessary_columns))
14251433
args["data_frame"] = pd.api.interchange.from_dataframe(
14261434
args["data_frame"].select_columns_by_name(columns)
14271435
)

Diff for: packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py

+26
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,32 @@ def test_build_df_from_vaex_and_polars(test_lib):
327327
)
328328

329329

330+
@pytest.mark.skipif(
331+
version.parse(pd.__version__) < version.parse("2.0.2"),
332+
reason="plotly doesn't use a dataframe interchange protocol for pandas < 2.0.2",
333+
)
334+
@pytest.mark.parametrize("test_lib", ["vaex", "polars"])
335+
def test_build_df_with_hover_data_from_vaex_and_polars(test_lib):
336+
if test_lib == "vaex":
337+
import vaex as lib
338+
else:
339+
import polars as lib
340+
341+
# take out the 'species' columns since the vaex implementation does not cover strings yet
342+
iris_pandas = px.data.iris()[["petal_width", "sepal_length", "sepal_width"]]
343+
iris_vaex = lib.from_pandas(iris_pandas)
344+
args = dict(
345+
data_frame=iris_vaex,
346+
x="petal_width",
347+
y="sepal_length",
348+
hover_data=["sepal_width"],
349+
)
350+
out = build_dataframe(args, go.Scatter)
351+
assert_frame_equal(
352+
iris_pandas.reset_index()[out["data_frame"].columns], out["data_frame"]
353+
)
354+
355+
330356
def test_timezones():
331357
df = pd.DataFrame({"date": ["2015-04-04 19:31:30+1:00"], "value": [3]})
332358
df["date"] = pd.to_datetime(df["date"])

0 commit comments

Comments
 (0)