Skip to content

Px special inputs #2330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/python/plotly/plotly/express/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@
get_trendline_results,
)

from ._special_inputs import ( # noqa: F401
IdentityMap,
Constant,
)

from . import data, colors # noqa: F401

__all__ = [
Expand Down Expand Up @@ -95,4 +100,6 @@
"colors",
"set_mapbox_access_token",
"get_trendline_results",
"IdentityMap",
"Constant",
]
29 changes: 24 additions & 5 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import plotly.graph_objs as go
import plotly.io as pio
from collections import namedtuple, OrderedDict
from ._special_inputs import IdentityMap, Constant

from _plotly_utils.basevalidators import ColorscaleValidator
from .colors import qualitative, sequential
Expand Down Expand Up @@ -41,6 +42,7 @@ def __init__(self):
defaults = PxDefaults()
del PxDefaults


MAPBOX_TOKEN = None


Expand Down Expand Up @@ -137,11 +139,15 @@ def make_mapping(args, variable):
if variable == "dash":
arg_name = "line_dash"
vprefix = "line_dash"
if args[vprefix + "_map"] == "identity":
val_map = IdentityMap()
else:
val_map = args[vprefix + "_map"].copy()
return Mapping(
show_in_trace_name=True,
variable=variable,
grouper=args[arg_name],
val_map=args[vprefix + "_map"].copy(),
val_map=val_map,
sequence=args[vprefix + "_sequence"],
updater=lambda trace, v: trace.update({parent: {variable: v}}),
facet=None,
Expand Down Expand Up @@ -919,6 +925,8 @@ def build_dataframe(args, attrables, array_attrables):
else:
df_output[df_input.columns] = df_input[df_input.columns]

constants = dict()

# Loop over possible arguments
for field_name in attrables:
# Massaging variables
Expand Down Expand Up @@ -950,8 +958,15 @@ def build_dataframe(args, attrables, array_attrables):
"pandas MultiIndex is not supported by plotly express "
"at the moment." % field
)
# ----------------- argument is a constant ----------------------
if isinstance(argument, Constant):
col_name = _check_name_not_reserved(
str(argument.label) if argument.label is not None else field,
reserved_names,
)
constants[col_name] = argument.value
# ----------------- argument is a col name ----------------------
if isinstance(argument, str) or isinstance(
elif isinstance(argument, str) or isinstance(
argument, int
): # just a column name given as str or int
if not df_provided:
Expand Down Expand Up @@ -1032,6 +1047,9 @@ def build_dataframe(args, attrables, array_attrables):
else:
args[field_name][i] = str(col_name)

for col_name in constants:
df_output[col_name] = constants[col_name]

args["data_frame"] = df_output
return args

Expand Down Expand Up @@ -1402,9 +1420,10 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
for col, val, m in zip(grouper, group_name, grouped_mappings):
if col != one_group:
key = get_label(args, col)
mapping_labels[key] = str(val)
if m.show_in_trace_name:
trace_name_labels[key] = str(val)
if not isinstance(m.val_map, IdentityMap):
mapping_labels[key] = str(val)
if m.show_in_trace_name:
trace_name_labels[key] = str(val)
if m.variable == "animation_frame":
frame_name = val
trace_name = ", ".join(trace_name_labels.values())
Expand Down
29 changes: 29 additions & 0 deletions packages/python/plotly/plotly/express/_special_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
class IdentityMap(object):
"""
`dict`-like object which acts as if the value for any key is the key itself. Objects
of this class can be passed in to arguments like `color_discrete_map` to
use the provided data values as colors, rather than mapping them to colors cycled
from `color_discrete_sequence`. This works for any `_map` argument to Plotly Express
functions, such as `line_dash_map` and `symbol_map`.
"""

def __getitem__(self, key):
return key

def __contains__(self, key):
return True

def copy(self):
return self


class Constant(object):
"""
Objects of this class can be passed to Plotly Express functions that expect column
identifiers or list-like objects to indicate that this attribute should take on a
constant value. An optional label can be provided.
"""

def __init__(self, value, label=None):
self.value = value
self.label = label
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,61 @@ def test_size_column():
df = px.data.tips()
fig = px.scatter(df, x=df["size"], y=df.tip)
assert fig.data[0].hovertemplate == "size=%{x}<br>tip=%{y}<extra></extra>"


def test_identity_map():
fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=["red", "blue"],
color_discrete_map=px.IdentityMap(),
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"
assert "color=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"

fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=["red", "blue"],
color_discrete_map="identity",
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"
assert "color=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"


def test_constants():
fig = px.scatter(x=px.Constant(1), y=[1, 2])
assert fig.data[0].x[0] == 1
assert fig.data[0].x[1] == 1
assert "x=" in fig.data[0].hovertemplate

fig = px.scatter(x=px.Constant(1, label="time"), y=[1, 2])
assert fig.data[0].x[0] == 1
assert fig.data[0].x[1] == 1
assert "x=" not in fig.data[0].hovertemplate
assert "time=" in fig.data[0].hovertemplate

fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=px.Constant("red", label="the_identity_label"),
hover_data=[px.Constant("data", label="the_data")],
color_discrete_map=px.IdentityMap(),
)
assert fig.data[0].marker.color == "red"
assert fig.data[0].customdata[0][0] == "data"
assert fig.data[1].marker.color == "red"
assert "color=" not in fig.data[0].hovertemplate
assert "the_identity_label=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert "the_data=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"