diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py index fb334c1b973..9df5b21ac8a 100644 --- a/packages/python/plotly/plotly/express/__init__.py +++ b/packages/python/plotly/plotly/express/__init__.py @@ -55,6 +55,11 @@ get_trendline_results, ) +from ._special_inputs import ( # noqa: F401 + IdentityMap, + Constant, +) + from . import data, colors # noqa: F401 __all__ = [ @@ -95,4 +100,6 @@ "colors", "set_mapbox_access_token", "get_trendline_results", + "IdentityMap", + "Constant", ] diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index ff247942a3d..85a5ee2ea43 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1,6 +1,7 @@ import plotly.graph_objs as go import plotly.io as pio from collections import namedtuple, OrderedDict +from ._special_inputs import IdentityMap, Constant from _plotly_utils.basevalidators import ColorscaleValidator from .colors import qualitative, sequential @@ -41,6 +42,7 @@ def __init__(self): defaults = PxDefaults() del PxDefaults + MAPBOX_TOKEN = None @@ -137,11 +139,15 @@ def make_mapping(args, variable): if variable == "dash": arg_name = "line_dash" vprefix = "line_dash" + if args[vprefix + "_map"] == "identity": + val_map = IdentityMap() + else: + val_map = args[vprefix + "_map"].copy() return Mapping( show_in_trace_name=True, variable=variable, grouper=args[arg_name], - val_map=args[vprefix + "_map"].copy(), + val_map=val_map, sequence=args[vprefix + "_sequence"], updater=lambda trace, v: trace.update({parent: {variable: v}}), facet=None, @@ -919,6 +925,8 @@ def build_dataframe(args, attrables, array_attrables): else: df_output[df_input.columns] = df_input[df_input.columns] + constants = dict() + # Loop over possible arguments for field_name in attrables: # Massaging variables @@ -950,8 +958,15 @@ def build_dataframe(args, attrables, array_attrables): "pandas MultiIndex is not supported by plotly express " "at the moment." % field ) + # ----------------- argument is a constant ---------------------- + if isinstance(argument, Constant): + col_name = _check_name_not_reserved( + str(argument.label) if argument.label is not None else field, + reserved_names, + ) + constants[col_name] = argument.value # ----------------- argument is a col name ---------------------- - if isinstance(argument, str) or isinstance( + elif isinstance(argument, str) or isinstance( argument, int ): # just a column name given as str or int if not df_provided: @@ -1032,6 +1047,9 @@ def build_dataframe(args, attrables, array_attrables): else: args[field_name][i] = str(col_name) + for col_name in constants: + df_output[col_name] = constants[col_name] + args["data_frame"] = df_output return args @@ -1402,9 +1420,10 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): for col, val, m in zip(grouper, group_name, grouped_mappings): if col != one_group: key = get_label(args, col) - mapping_labels[key] = str(val) - if m.show_in_trace_name: - trace_name_labels[key] = str(val) + if not isinstance(m.val_map, IdentityMap): + mapping_labels[key] = str(val) + if m.show_in_trace_name: + trace_name_labels[key] = str(val) if m.variable == "animation_frame": frame_name = val trace_name = ", ".join(trace_name_labels.values()) diff --git a/packages/python/plotly/plotly/express/_special_inputs.py b/packages/python/plotly/plotly/express/_special_inputs.py new file mode 100644 index 00000000000..3dfff0f3c8e --- /dev/null +++ b/packages/python/plotly/plotly/express/_special_inputs.py @@ -0,0 +1,29 @@ +class IdentityMap(object): + """ + `dict`-like object which acts as if the value for any key is the key itself. Objects + of this class can be passed in to arguments like `color_discrete_map` to + use the provided data values as colors, rather than mapping them to colors cycled + from `color_discrete_sequence`. This works for any `_map` argument to Plotly Express + functions, such as `line_dash_map` and `symbol_map`. + """ + + def __getitem__(self, key): + return key + + def __contains__(self, key): + return True + + def copy(self): + return self + + +class Constant(object): + """ + Objects of this class can be passed to Plotly Express functions that expect column + identifiers or list-like objects to indicate that this attribute should take on a + constant value. An optional label can be provided. + """ + + def __init__(self, value, label=None): + self.value = value + self.label = label diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index e3786f6af90..b075e197d79 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -323,3 +323,61 @@ def test_size_column(): df = px.data.tips() fig = px.scatter(df, x=df["size"], y=df.tip) assert fig.data[0].hovertemplate == "size=%{x}
tip=%{y}" + + +def test_identity_map(): + fig = px.scatter( + x=[1, 2], + y=[1, 2], + symbol=["a", "b"], + color=["red", "blue"], + color_discrete_map=px.IdentityMap(), + ) + assert fig.data[0].marker.color == "red" + assert fig.data[1].marker.color == "blue" + assert "color=" not in fig.data[0].hovertemplate + assert "symbol=" in fig.data[0].hovertemplate + assert fig.layout.legend.title.text == "symbol" + + fig = px.scatter( + x=[1, 2], + y=[1, 2], + symbol=["a", "b"], + color=["red", "blue"], + color_discrete_map="identity", + ) + assert fig.data[0].marker.color == "red" + assert fig.data[1].marker.color == "blue" + assert "color=" not in fig.data[0].hovertemplate + assert "symbol=" in fig.data[0].hovertemplate + assert fig.layout.legend.title.text == "symbol" + + +def test_constants(): + fig = px.scatter(x=px.Constant(1), y=[1, 2]) + assert fig.data[0].x[0] == 1 + assert fig.data[0].x[1] == 1 + assert "x=" in fig.data[0].hovertemplate + + fig = px.scatter(x=px.Constant(1, label="time"), y=[1, 2]) + assert fig.data[0].x[0] == 1 + assert fig.data[0].x[1] == 1 + assert "x=" not in fig.data[0].hovertemplate + assert "time=" in fig.data[0].hovertemplate + + fig = px.scatter( + x=[1, 2], + y=[1, 2], + symbol=["a", "b"], + color=px.Constant("red", label="the_identity_label"), + hover_data=[px.Constant("data", label="the_data")], + color_discrete_map=px.IdentityMap(), + ) + assert fig.data[0].marker.color == "red" + assert fig.data[0].customdata[0][0] == "data" + assert fig.data[1].marker.color == "red" + assert "color=" not in fig.data[0].hovertemplate + assert "the_identity_label=" not in fig.data[0].hovertemplate + assert "symbol=" in fig.data[0].hovertemplate + assert "the_data=" in fig.data[0].hovertemplate + assert fig.layout.legend.title.text == "symbol"