Skip to content

Commit b770468

Browse files
authored
Add custom_data argument to px functions (#1764)
1 parent 7ec7a69 commit b770468

File tree

4 files changed

+111
-8
lines changed

4 files changed

+111
-8
lines changed

Diff for: packages/python/plotly/plotly/express/_chart_types.py

+20
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def scatter(
1212
size=None,
1313
hover_name=None,
1414
hover_data=None,
15+
custom_data=None,
1516
text=None,
1617
facet_row=None,
1718
facet_col=None,
@@ -174,6 +175,7 @@ def line(
174175
line_dash=None,
175176
hover_name=None,
176177
hover_data=None,
178+
custom_data=None,
177179
text=None,
178180
facet_row=None,
179181
facet_col=None,
@@ -217,6 +219,7 @@ def area(
217219
color=None,
218220
hover_name=None,
219221
hover_data=None,
222+
custom_data=None,
220223
text=None,
221224
facet_row=None,
222225
facet_col=None,
@@ -262,6 +265,7 @@ def bar(
262265
facet_col=None,
263266
hover_name=None,
264267
hover_data=None,
268+
custom_data=None,
265269
text=None,
266270
error_x=None,
267271
error_x_minus=None,
@@ -368,6 +372,7 @@ def violin(
368372
facet_col=None,
369373
hover_name=None,
370374
hover_data=None,
375+
custom_data=None,
371376
animation_frame=None,
372377
animation_group=None,
373378
category_orders={},
@@ -418,6 +423,7 @@ def box(
418423
facet_col=None,
419424
hover_name=None,
420425
hover_data=None,
426+
custom_data=None,
421427
animation_frame=None,
422428
animation_group=None,
423429
category_orders={},
@@ -463,6 +469,7 @@ def strip(
463469
facet_col=None,
464470
hover_name=None,
465471
hover_data=None,
472+
custom_data=None,
466473
animation_frame=None,
467474
animation_group=None,
468475
category_orders={},
@@ -514,6 +521,7 @@ def scatter_3d(
514521
text=None,
515522
hover_name=None,
516523
hover_data=None,
524+
custom_data=None,
517525
error_x=None,
518526
error_x_minus=None,
519527
error_y=None,
@@ -564,6 +572,7 @@ def line_3d(
564572
line_group=None,
565573
hover_name=None,
566574
hover_data=None,
575+
custom_data=None,
567576
error_x=None,
568577
error_x_minus=None,
569578
error_y=None,
@@ -609,6 +618,7 @@ def scatter_ternary(
609618
text=None,
610619
hover_name=None,
611620
hover_data=None,
621+
custom_data=None,
612622
animation_frame=None,
613623
animation_group=None,
614624
category_orders={},
@@ -646,6 +656,7 @@ def line_ternary(
646656
line_group=None,
647657
hover_name=None,
648658
hover_data=None,
659+
custom_data=None,
649660
text=None,
650661
animation_frame=None,
651662
animation_group=None,
@@ -679,6 +690,7 @@ def scatter_polar(
679690
size=None,
680691
hover_name=None,
681692
hover_data=None,
693+
custom_data=None,
682694
text=None,
683695
animation_frame=None,
684696
animation_group=None,
@@ -721,6 +733,7 @@ def line_polar(
721733
line_dash=None,
722734
hover_name=None,
723735
hover_data=None,
736+
custom_data=None,
724737
line_group=None,
725738
text=None,
726739
animation_frame=None,
@@ -759,6 +772,7 @@ def bar_polar(
759772
color=None,
760773
hover_name=None,
761774
hover_data=None,
775+
custom_data=None,
762776
animation_frame=None,
763777
animation_group=None,
764778
category_orders={},
@@ -798,6 +812,7 @@ def choropleth(
798812
color=None,
799813
hover_name=None,
800814
hover_data=None,
815+
custom_data=None,
801816
size=None,
802817
animation_frame=None,
803818
animation_group=None,
@@ -838,6 +853,7 @@ def scatter_geo(
838853
text=None,
839854
hover_name=None,
840855
hover_data=None,
856+
custom_data=None,
841857
size=None,
842858
animation_frame=None,
843859
animation_group=None,
@@ -882,6 +898,7 @@ def line_geo(
882898
text=None,
883899
hover_name=None,
884900
hover_data=None,
901+
custom_data=None,
885902
line_group=None,
886903
animation_frame=None,
887904
animation_group=None,
@@ -920,6 +937,7 @@ def scatter_mapbox(
920937
text=None,
921938
hover_name=None,
922939
hover_data=None,
940+
custom_data=None,
923941
size=None,
924942
animation_frame=None,
925943
animation_group=None,
@@ -955,6 +973,7 @@ def line_mapbox(
955973
text=None,
956974
hover_name=None,
957975
hover_data=None,
976+
custom_data=None,
958977
line_group=None,
959978
animation_frame=None,
960979
animation_group=None,
@@ -985,6 +1004,7 @@ def scatter_matrix(
9851004
size=None,
9861005
hover_name=None,
9871006
hover_data=None,
1007+
custom_data=None,
9881008
category_orders={},
9891009
labels={},
9901010
color_discrete_sequence=None,

Diff for: packages/python/plotly/plotly/express/_core.py

+46-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .colors import qualitative, sequential
77
import math
88
import pandas
9+
import numpy as np
910

1011
from plotly.subplots import (
1112
make_subplots,
@@ -137,12 +138,35 @@ def make_mapping(args, variable):
137138

138139

139140
def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
140-
141+
"""Populates a dict with arguments to update trace
142+
143+
Parameters
144+
----------
145+
args : dict
146+
args to be used for the trace
147+
trace_spec : NamedTuple
148+
which kind of trace to be used (has constructor, marginal etc.
149+
attributes)
150+
g : pandas DataFrame
151+
data
152+
mapping_labels : dict
153+
to be used for hovertemplate
154+
sizeref : float
155+
marker sizeref
156+
157+
Returns
158+
-------
159+
result : dict
160+
dict to be used to update trace
161+
fit_results : dict
162+
fit information to be used for trendlines
163+
"""
141164
if "line_close" in args and args["line_close"]:
142165
g = g.append(g.iloc[0])
143166
result = trace_spec.trace_patch.copy() or {}
144167
fit_results = None
145168
hover_header = ""
169+
custom_data_len = 0
146170
for k in trace_spec.attrs:
147171
v = args[k]
148172
v_label = get_decorated_label(args, v, k)
@@ -194,7 +218,6 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
194218
elif k == "trendline":
195219
if v in ["ols", "lowess"] and args["x"] and args["y"] and len(g) > 1:
196220
import statsmodels.api as sm
197-
import numpy as np
198221

199222
# sorting is bad but trace_specs with "trendline" have no other attrs
200223
g2 = g.sort_values(by=args["x"])
@@ -231,6 +254,9 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
231254
if error_xy not in result:
232255
result[error_xy] = {}
233256
result[error_xy][arr] = g[v]
257+
elif k == "custom_data":
258+
result["customdata"] = g[v].values
259+
custom_data_len = len(v) # number of custom data columns
234260
elif k == "hover_name":
235261
if trace_spec.constructor not in [
236262
go.Histogram,
@@ -246,10 +272,20 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
246272
go.Histogram2d,
247273
go.Histogram2dContour,
248274
]:
249-
result["customdata"] = g[v].values
250-
for i, col in enumerate(v):
275+
for col in v:
276+
try:
277+
position = args["custom_data"].index(col)
278+
except (ValueError, AttributeError, KeyError):
279+
position = custom_data_len
280+
custom_data_len += 1
281+
if "customdata" in result:
282+
result["customdata"] = np.hstack(
283+
(result["customdata"], g[col].values[:, None])
284+
)
285+
else:
286+
result["customdata"] = g[col].values[:, None]
251287
v_label_col = get_decorated_label(args, col, None)
252-
mapping_labels[v_label_col] = "%%{customdata[%d]}" % i
288+
mapping_labels[v_label_col] = "%%{customdata[%d]}" % (position)
253289
elif k == "color":
254290
if trace_spec.constructor == go.Choropleth:
255291
result["z"] = g[v]
@@ -721,12 +757,13 @@ def apply_default_cascade(args):
721757
def infer_config(args, constructor, trace_patch):
722758
# Declare all supported attributes, across all plot types
723759
attrables = (
724-
["x", "y", "z", "a", "b", "c", "r", "theta", "size"]
725-
+ ["dimensions", "hover_name", "hover_data", "text", "error_x", "error_x_minus"]
760+
["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"]
761+
+ ["custom_data", "hover_name", "hover_data", "text"]
762+
+ ["error_x", "error_x_minus"]
726763
+ ["error_y", "error_y_minus", "error_z", "error_z_minus"]
727764
+ ["lat", "lon", "locations", "animation_group"]
728765
)
729-
array_attrables = ["dimensions", "hover_data"]
766+
array_attrables = ["dimensions", "custom_data", "hover_data"]
730767
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
731768

732769
# Validate that the strings provided as attribute values reference columns
@@ -916,6 +953,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
916953
if constructor_to_use == go.Scatter
917954
else go.Scatterpolargl
918955
)
956+
# Create the trace
919957
trace = constructor_to_use(name=trace_name)
920958
if trace_spec.constructor not in [
921959
go.Parcats,

Diff for: packages/python/plotly/plotly/express/_doc.py

+4
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@
111111
colref_list,
112112
"Values from these columns appear as extra data in the hover tooltip.",
113113
],
114+
custom_data=[
115+
colref_list,
116+
"Values from these columns are extra data, to be used in widgets or Dash callbacks for example. This data is not user-visible but is included in events emitted by the figure (lasso selection etc.)",
117+
],
114118
text=[colref, "Values from this column appear in the figure as text labels."],
115119
locationmode=[
116120
"(string, one of 'ISO-3', 'USA-states', 'country names')",

Diff for: packages/python/plotly/plotly/tests/test_core/test_px/test_px.py

+41
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,44 @@ def test_scatter():
1010
assert np.all(fig.data[0].y == iris.sepal_length)
1111
# test defaults
1212
assert fig.data[0].mode == "markers"
13+
14+
15+
def test_custom_data_scatter():
16+
iris = px.data.iris()
17+
# No hover, no custom data
18+
fig = px.scatter(iris, x="sepal_width", y="sepal_length", color="species")
19+
assert fig.data[0].customdata is None
20+
# Hover, no custom data
21+
fig = px.scatter(
22+
iris,
23+
x="sepal_width",
24+
y="sepal_length",
25+
color="species",
26+
hover_data=["petal_length", "petal_width"],
27+
)
28+
for data in fig.data:
29+
assert np.all(np.in1d(data.customdata[:, 1], iris.petal_width))
30+
# Hover and custom data, no repeated arguments
31+
fig = px.scatter(
32+
iris,
33+
x="sepal_width",
34+
y="sepal_length",
35+
hover_data=["petal_length", "petal_width"],
36+
custom_data=["species_id", "species"],
37+
)
38+
assert np.all(fig.data[0].customdata[:, 0] == iris.species_id)
39+
assert fig.data[0].customdata.shape[1] == 4
40+
# Hover and custom data, with repeated arguments
41+
fig = px.scatter(
42+
iris,
43+
x="sepal_width",
44+
y="sepal_length",
45+
hover_data=["petal_length", "petal_width", "species_id"],
46+
custom_data=["species_id", "species"],
47+
)
48+
assert np.all(fig.data[0].customdata[:, 0] == iris.species_id)
49+
assert fig.data[0].customdata.shape[1] == 4
50+
assert (
51+
fig.data[0].hovertemplate
52+
== "sepal_width=%{x}<br>sepal_length=%{y}<br>petal_length=%{customdata[2]}<br>petal_width=%{customdata[3]}<br>species_id=%{customdata[0]}"
53+
)

0 commit comments

Comments
 (0)