diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py
index fb334c1b973..9df5b21ac8a 100644
--- a/packages/python/plotly/plotly/express/__init__.py
+++ b/packages/python/plotly/plotly/express/__init__.py
@@ -55,6 +55,11 @@
get_trendline_results,
)
+from ._special_inputs import ( # noqa: F401
+ IdentityMap,
+ Constant,
+)
+
from . import data, colors # noqa: F401
__all__ = [
@@ -95,4 +100,6 @@
"colors",
"set_mapbox_access_token",
"get_trendline_results",
+ "IdentityMap",
+ "Constant",
]
diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py
index ff247942a3d..85a5ee2ea43 100644
--- a/packages/python/plotly/plotly/express/_core.py
+++ b/packages/python/plotly/plotly/express/_core.py
@@ -1,6 +1,7 @@
import plotly.graph_objs as go
import plotly.io as pio
from collections import namedtuple, OrderedDict
+from ._special_inputs import IdentityMap, Constant
from _plotly_utils.basevalidators import ColorscaleValidator
from .colors import qualitative, sequential
@@ -41,6 +42,7 @@ def __init__(self):
defaults = PxDefaults()
del PxDefaults
+
MAPBOX_TOKEN = None
@@ -137,11 +139,15 @@ def make_mapping(args, variable):
if variable == "dash":
arg_name = "line_dash"
vprefix = "line_dash"
+ if args[vprefix + "_map"] == "identity":
+ val_map = IdentityMap()
+ else:
+ val_map = args[vprefix + "_map"].copy()
return Mapping(
show_in_trace_name=True,
variable=variable,
grouper=args[arg_name],
- val_map=args[vprefix + "_map"].copy(),
+ val_map=val_map,
sequence=args[vprefix + "_sequence"],
updater=lambda trace, v: trace.update({parent: {variable: v}}),
facet=None,
@@ -919,6 +925,8 @@ def build_dataframe(args, attrables, array_attrables):
else:
df_output[df_input.columns] = df_input[df_input.columns]
+ constants = dict()
+
# Loop over possible arguments
for field_name in attrables:
# Massaging variables
@@ -950,8 +958,15 @@ def build_dataframe(args, attrables, array_attrables):
"pandas MultiIndex is not supported by plotly express "
"at the moment." % field
)
+ # ----------------- argument is a constant ----------------------
+ if isinstance(argument, Constant):
+ col_name = _check_name_not_reserved(
+ str(argument.label) if argument.label is not None else field,
+ reserved_names,
+ )
+ constants[col_name] = argument.value
# ----------------- argument is a col name ----------------------
- if isinstance(argument, str) or isinstance(
+ elif isinstance(argument, str) or isinstance(
argument, int
): # just a column name given as str or int
if not df_provided:
@@ -1032,6 +1047,9 @@ def build_dataframe(args, attrables, array_attrables):
else:
args[field_name][i] = str(col_name)
+ for col_name in constants:
+ df_output[col_name] = constants[col_name]
+
args["data_frame"] = df_output
return args
@@ -1402,9 +1420,10 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
for col, val, m in zip(grouper, group_name, grouped_mappings):
if col != one_group:
key = get_label(args, col)
- mapping_labels[key] = str(val)
- if m.show_in_trace_name:
- trace_name_labels[key] = str(val)
+ if not isinstance(m.val_map, IdentityMap):
+ mapping_labels[key] = str(val)
+ if m.show_in_trace_name:
+ trace_name_labels[key] = str(val)
if m.variable == "animation_frame":
frame_name = val
trace_name = ", ".join(trace_name_labels.values())
diff --git a/packages/python/plotly/plotly/express/_special_inputs.py b/packages/python/plotly/plotly/express/_special_inputs.py
new file mode 100644
index 00000000000..3dfff0f3c8e
--- /dev/null
+++ b/packages/python/plotly/plotly/express/_special_inputs.py
@@ -0,0 +1,29 @@
+class IdentityMap(object):
+ """
+ `dict`-like object which acts as if the value for any key is the key itself. Objects
+ of this class can be passed in to arguments like `color_discrete_map` to
+ use the provided data values as colors, rather than mapping them to colors cycled
+ from `color_discrete_sequence`. This works for any `_map` argument to Plotly Express
+ functions, such as `line_dash_map` and `symbol_map`.
+ """
+
+ def __getitem__(self, key):
+ return key
+
+ def __contains__(self, key):
+ return True
+
+ def copy(self):
+ return self
+
+
+class Constant(object):
+ """
+ Objects of this class can be passed to Plotly Express functions that expect column
+ identifiers or list-like objects to indicate that this attribute should take on a
+ constant value. An optional label can be provided.
+ """
+
+ def __init__(self, value, label=None):
+ self.value = value
+ self.label = label
diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py
index e3786f6af90..b075e197d79 100644
--- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py
+++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py
@@ -323,3 +323,61 @@ def test_size_column():
df = px.data.tips()
fig = px.scatter(df, x=df["size"], y=df.tip)
assert fig.data[0].hovertemplate == "size=%{x}
tip=%{y}"
+
+
+def test_identity_map():
+ fig = px.scatter(
+ x=[1, 2],
+ y=[1, 2],
+ symbol=["a", "b"],
+ color=["red", "blue"],
+ color_discrete_map=px.IdentityMap(),
+ )
+ assert fig.data[0].marker.color == "red"
+ assert fig.data[1].marker.color == "blue"
+ assert "color=" not in fig.data[0].hovertemplate
+ assert "symbol=" in fig.data[0].hovertemplate
+ assert fig.layout.legend.title.text == "symbol"
+
+ fig = px.scatter(
+ x=[1, 2],
+ y=[1, 2],
+ symbol=["a", "b"],
+ color=["red", "blue"],
+ color_discrete_map="identity",
+ )
+ assert fig.data[0].marker.color == "red"
+ assert fig.data[1].marker.color == "blue"
+ assert "color=" not in fig.data[0].hovertemplate
+ assert "symbol=" in fig.data[0].hovertemplate
+ assert fig.layout.legend.title.text == "symbol"
+
+
+def test_constants():
+ fig = px.scatter(x=px.Constant(1), y=[1, 2])
+ assert fig.data[0].x[0] == 1
+ assert fig.data[0].x[1] == 1
+ assert "x=" in fig.data[0].hovertemplate
+
+ fig = px.scatter(x=px.Constant(1, label="time"), y=[1, 2])
+ assert fig.data[0].x[0] == 1
+ assert fig.data[0].x[1] == 1
+ assert "x=" not in fig.data[0].hovertemplate
+ assert "time=" in fig.data[0].hovertemplate
+
+ fig = px.scatter(
+ x=[1, 2],
+ y=[1, 2],
+ symbol=["a", "b"],
+ color=px.Constant("red", label="the_identity_label"),
+ hover_data=[px.Constant("data", label="the_data")],
+ color_discrete_map=px.IdentityMap(),
+ )
+ assert fig.data[0].marker.color == "red"
+ assert fig.data[0].customdata[0][0] == "data"
+ assert fig.data[1].marker.color == "red"
+ assert "color=" not in fig.data[0].hovertemplate
+ assert "the_identity_label=" not in fig.data[0].hovertemplate
+ assert "symbol=" in fig.data[0].hovertemplate
+ assert "the_data=" in fig.data[0].hovertemplate
+ assert fig.layout.legend.title.text == "symbol"