Skip to content

Px special inputs #2330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/python/plotly/plotly/express/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
set_mapbox_access_token,
defaults,
get_trendline_results,
IdentityMap,
Constant,
)

from . import data, colors # noqa: F401
Expand Down Expand Up @@ -95,4 +97,6 @@
"colors",
"set_mapbox_access_token",
"get_trendline_results",
"IdentityMap",
"Constant",
]
45 changes: 41 additions & 4 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,30 @@ def __init__(self):
defaults = PxDefaults()
del PxDefaults


class IdentityMap(object):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could these two classes be moved to a _special_inputs.py file to shorten a bit _core.py?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, if needed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as you wish, I just feel that the length of this file makes it a bit overwhelming.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

"""
`dict`-like object which can be passed in to arguments like `color_discrete_map` to
use the provided data values as colors, rather than mapping them to colors cycled
from `color_discrete_sequence`.
"""

def __getitem__(self, key):
return key

def __contains__(self, key):
return True

def copy(self):
return self


class Constant(object):
def __init__(self, value, label=None):
self.value = value
self.label = label


MAPBOX_TOKEN = None


Expand Down Expand Up @@ -919,6 +943,8 @@ def build_dataframe(args, attrables, array_attrables):
else:
df_output[df_input.columns] = df_input[df_input.columns]

constants = dict()

# Loop over possible arguments
for field_name in attrables:
# Massaging variables
Expand Down Expand Up @@ -950,8 +976,15 @@ def build_dataframe(args, attrables, array_attrables):
"pandas MultiIndex is not supported by plotly express "
"at the moment." % field
)
# ----------------- argument is a constant ----------------------
if isinstance(argument, Constant):
col_name = _check_name_not_reserved(
str(argument.label) if argument.label is not None else field,
reserved_names,
)
constants[col_name] = argument.value
# ----------------- argument is a col name ----------------------
if isinstance(argument, str) or isinstance(
elif isinstance(argument, str) or isinstance(
argument, int
): # just a column name given as str or int
if not df_provided:
Expand Down Expand Up @@ -1032,6 +1065,9 @@ def build_dataframe(args, attrables, array_attrables):
else:
args[field_name][i] = str(col_name)

for col_name in constants:
df_output[col_name] = constants[col_name]

args["data_frame"] = df_output
return args

Expand Down Expand Up @@ -1402,9 +1438,10 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
for col, val, m in zip(grouper, group_name, grouped_mappings):
if col != one_group:
key = get_label(args, col)
mapping_labels[key] = str(val)
if m.show_in_trace_name:
trace_name_labels[key] = str(val)
if not isinstance(m.val_map, IdentityMap):
mapping_labels[key] = str(val)
if m.show_in_trace_name:
trace_name_labels[key] = str(val)
if m.variable == "animation_frame":
frame_name = val
trace_name = ", ".join(trace_name_labels.values())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,48 @@ def test_size_column():
df = px.data.tips()
fig = px.scatter(df, x=df["size"], y=df.tip)
assert fig.data[0].hovertemplate == "size=%{x}<br>tip=%{y}<extra></extra>"


def test_identity_map():
fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=["red", "blue"],
color_discrete_map=px.IdentityMap(),
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"
assert "color=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"


def test_constants():
fig = px.scatter(x=px.Constant(1), y=[1, 2])
assert fig.data[0].x[0] == 1
assert fig.data[0].x[1] == 1
assert "x=" in fig.data[0].hovertemplate

fig = px.scatter(x=px.Constant(1, label="time"), y=[1, 2])
assert fig.data[0].x[0] == 1
assert fig.data[0].x[1] == 1
assert "x=" not in fig.data[0].hovertemplate
assert "time=" in fig.data[0].hovertemplate

fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=px.Constant("red", label="the_identity_label"),
hover_data=[px.Constant("data", label="the_data")],
color_discrete_map=px.IdentityMap(),
)
assert fig.data[0].marker.color == "red"
assert fig.data[0].customdata[0][0] == "data"
assert fig.data[1].marker.color == "red"
assert "color=" not in fig.data[0].hovertemplate
assert "the_identity_label=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert "the_data=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"