Skip to content

Add custom_data argument to px functions #1764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Sep 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions packages/python/plotly/plotly/express/_chart_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def scatter(
size=None,
hover_name=None,
hover_data=None,
custom_data=None,
text=None,
facet_row=None,
facet_col=None,
Expand Down Expand Up @@ -174,6 +175,7 @@ def line(
line_dash=None,
hover_name=None,
hover_data=None,
custom_data=None,
text=None,
facet_row=None,
facet_col=None,
Expand Down Expand Up @@ -217,6 +219,7 @@ def area(
color=None,
hover_name=None,
hover_data=None,
custom_data=None,
text=None,
facet_row=None,
facet_col=None,
Expand Down Expand Up @@ -262,6 +265,7 @@ def bar(
facet_col=None,
hover_name=None,
hover_data=None,
custom_data=None,
text=None,
error_x=None,
error_x_minus=None,
Expand Down Expand Up @@ -368,6 +372,7 @@ def violin(
facet_col=None,
hover_name=None,
hover_data=None,
custom_data=None,
animation_frame=None,
animation_group=None,
category_orders={},
Expand Down Expand Up @@ -418,6 +423,7 @@ def box(
facet_col=None,
hover_name=None,
hover_data=None,
custom_data=None,
animation_frame=None,
animation_group=None,
category_orders={},
Expand Down Expand Up @@ -463,6 +469,7 @@ def strip(
facet_col=None,
hover_name=None,
hover_data=None,
custom_data=None,
animation_frame=None,
animation_group=None,
category_orders={},
Expand Down Expand Up @@ -514,6 +521,7 @@ def scatter_3d(
text=None,
hover_name=None,
hover_data=None,
custom_data=None,
error_x=None,
error_x_minus=None,
error_y=None,
Expand Down Expand Up @@ -564,6 +572,7 @@ def line_3d(
line_group=None,
hover_name=None,
hover_data=None,
custom_data=None,
error_x=None,
error_x_minus=None,
error_y=None,
Expand Down Expand Up @@ -609,6 +618,7 @@ def scatter_ternary(
text=None,
hover_name=None,
hover_data=None,
custom_data=None,
animation_frame=None,
animation_group=None,
category_orders={},
Expand Down Expand Up @@ -646,6 +656,7 @@ def line_ternary(
line_group=None,
hover_name=None,
hover_data=None,
custom_data=None,
text=None,
animation_frame=None,
animation_group=None,
Expand Down Expand Up @@ -679,6 +690,7 @@ def scatter_polar(
size=None,
hover_name=None,
hover_data=None,
custom_data=None,
text=None,
animation_frame=None,
animation_group=None,
Expand Down Expand Up @@ -721,6 +733,7 @@ def line_polar(
line_dash=None,
hover_name=None,
hover_data=None,
custom_data=None,
line_group=None,
text=None,
animation_frame=None,
Expand Down Expand Up @@ -759,6 +772,7 @@ def bar_polar(
color=None,
hover_name=None,
hover_data=None,
custom_data=None,
animation_frame=None,
animation_group=None,
category_orders={},
Expand Down Expand Up @@ -798,6 +812,7 @@ def choropleth(
color=None,
hover_name=None,
hover_data=None,
custom_data=None,
size=None,
animation_frame=None,
animation_group=None,
Expand Down Expand Up @@ -838,6 +853,7 @@ def scatter_geo(
text=None,
hover_name=None,
hover_data=None,
custom_data=None,
size=None,
animation_frame=None,
animation_group=None,
Expand Down Expand Up @@ -882,6 +898,7 @@ def line_geo(
text=None,
hover_name=None,
hover_data=None,
custom_data=None,
line_group=None,
animation_frame=None,
animation_group=None,
Expand Down Expand Up @@ -920,6 +937,7 @@ def scatter_mapbox(
text=None,
hover_name=None,
hover_data=None,
custom_data=None,
size=None,
animation_frame=None,
animation_group=None,
Expand Down Expand Up @@ -955,6 +973,7 @@ def line_mapbox(
text=None,
hover_name=None,
hover_data=None,
custom_data=None,
line_group=None,
animation_frame=None,
animation_group=None,
Expand Down Expand Up @@ -985,6 +1004,7 @@ def scatter_matrix(
size=None,
hover_name=None,
hover_data=None,
custom_data=None,
category_orders={},
labels={},
color_discrete_sequence=None,
Expand Down
54 changes: 46 additions & 8 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .colors import qualitative, sequential
import math
import pandas
import numpy as np

from plotly.subplots import (
make_subplots,
Expand Down Expand Up @@ -137,12 +138,35 @@ def make_mapping(args, variable):


def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):

"""Populates a dict with arguments to update trace

Parameters
----------
args : dict
args to be used for the trace
trace_spec : NamedTuple
which kind of trace to be used (has constructor, marginal etc.
attributes)
g : pandas DataFrame
data
mapping_labels : dict
to be used for hovertemplate
sizeref : float
marker sizeref

Returns
-------
result : dict
dict to be used to update trace
fit_results : dict
fit information to be used for trendlines
"""
if "line_close" in args and args["line_close"]:
g = g.append(g.iloc[0])
result = trace_spec.trace_patch.copy() or {}
fit_results = None
hover_header = ""
custom_data_len = 0
for k in trace_spec.attrs:
v = args[k]
v_label = get_decorated_label(args, v, k)
Expand Down Expand Up @@ -194,7 +218,6 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
elif k == "trendline":
if v in ["ols", "lowess"] and args["x"] and args["y"] and len(g) > 1:
import statsmodels.api as sm
import numpy as np

# sorting is bad but trace_specs with "trendline" have no other attrs
g2 = g.sort_values(by=args["x"])
Expand Down Expand Up @@ -231,6 +254,9 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
if error_xy not in result:
result[error_xy] = {}
result[error_xy][arr] = g[v]
elif k == "custom_data":
result["customdata"] = g[v].values
custom_data_len = len(v) # number of custom data columns
elif k == "hover_name":
if trace_spec.constructor not in [
go.Histogram,
Expand All @@ -246,10 +272,20 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
go.Histogram2d,
go.Histogram2dContour,
]:
result["customdata"] = g[v].values
for i, col in enumerate(v):
for col in v:
try:
position = args["custom_data"].index(col)
except (ValueError, AttributeError, KeyError):
position = custom_data_len
custom_data_len += 1
if "customdata" in result:
result["customdata"] = np.hstack(
(result["customdata"], g[col].values[:, None])
)
else:
result["customdata"] = g[col].values[:, None]
v_label_col = get_decorated_label(args, col, None)
mapping_labels[v_label_col] = "%%{customdata[%d]}" % i
mapping_labels[v_label_col] = "%%{customdata[%d]}" % (position)
elif k == "color":
if trace_spec.constructor == go.Choropleth:
result["z"] = g[v]
Expand Down Expand Up @@ -721,12 +757,13 @@ def apply_default_cascade(args):
def infer_config(args, constructor, trace_patch):
# Declare all supported attributes, across all plot types
attrables = (
["x", "y", "z", "a", "b", "c", "r", "theta", "size"]
+ ["dimensions", "hover_name", "hover_data", "text", "error_x", "error_x_minus"]
["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"]
+ ["custom_data", "hover_name", "hover_data", "text"]
+ ["error_x", "error_x_minus"]
+ ["error_y", "error_y_minus", "error_z", "error_z_minus"]
+ ["lat", "lon", "locations", "animation_group"]
)
array_attrables = ["dimensions", "hover_data"]
array_attrables = ["dimensions", "custom_data", "hover_data"]
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]

# Validate that the strings provided as attribute values reference columns
Expand Down Expand Up @@ -916,6 +953,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
if constructor_to_use == go.Scatter
else go.Scatterpolargl
)
# Create the trace
trace = constructor_to_use(name=trace_name)
if trace_spec.constructor not in [
go.Parcats,
Expand Down
4 changes: 4 additions & 0 deletions packages/python/plotly/plotly/express/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@
colref_list,
"Values from these columns appear as extra data in the hover tooltip.",
],
custom_data=[
colref_list,
"Values from these columns are extra data, to be used in widgets or Dash callbacks for example. This data is not user-visible but is included in events emitted by the figure (lasso selection etc.)",
],
text=[colref, "Values from this column appear in the figure as text labels."],
locationmode=[
"(string, one of 'ISO-3', 'USA-states', 'country names')",
Expand Down
41 changes: 41 additions & 0 deletions packages/python/plotly/plotly/tests/test_core/test_px/test_px.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,44 @@ def test_scatter():
assert np.all(fig.data[0].y == iris.sepal_length)
# test defaults
assert fig.data[0].mode == "markers"


def test_custom_data_scatter():
iris = px.data.iris()
# No hover, no custom data
fig = px.scatter(iris, x="sepal_width", y="sepal_length", color="species")
assert fig.data[0].customdata is None
# Hover, no custom data
fig = px.scatter(
iris,
x="sepal_width",
y="sepal_length",
color="species",
hover_data=["petal_length", "petal_width"],
)
for data in fig.data:
assert np.all(np.in1d(data.customdata[:, 1], iris.petal_width))
# Hover and custom data, no repeated arguments
fig = px.scatter(
iris,
x="sepal_width",
y="sepal_length",
hover_data=["petal_length", "petal_width"],
custom_data=["species_id", "species"],
)
assert np.all(fig.data[0].customdata[:, 0] == iris.species_id)
assert fig.data[0].customdata.shape[1] == 4
# Hover and custom data, with repeated arguments
fig = px.scatter(
iris,
x="sepal_width",
y="sepal_length",
hover_data=["petal_length", "petal_width", "species_id"],
custom_data=["species_id", "species"],
)
assert np.all(fig.data[0].customdata[:, 0] == iris.species_id)
assert fig.data[0].customdata.shape[1] == 4
assert (
fig.data[0].hovertemplate
== "sepal_width=%{x}<br>sepal_length=%{y}<br>petal_length=%{customdata[2]}<br>petal_width=%{customdata[3]}<br>species_id=%{customdata[0]}"
)