Skip to content

PX shouldn't modify attrs controlled by template #1875

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 40 additions & 40 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,16 +375,18 @@ def configure_cartesian_marginal_axes(args, fig, orders):

# Configure axis ticks on marginal subplots
if args["marginal_x"]:
fig.update_yaxes(
showticklabels=False, showgrid=args["marginal_x"] == "histogram", row=nrows
)
fig.update_xaxes(showgrid=True, row=nrows)
fig.update_yaxes(showticklabels=False, row=nrows)
if args["template"].layout.yaxis.showgrid is None:
fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows)
if args["template"].layout.xaxis.showgrid is None:
fig.update_xaxes(showgrid=True, row=nrows)

if args["marginal_y"]:
fig.update_xaxes(
showticklabels=False, showgrid=args["marginal_y"] == "histogram", col=ncols
)
fig.update_yaxes(showgrid=True, col=ncols)
fig.update_xaxes(showticklabels=False, col=ncols)
if args["template"].layout.xaxis.showgrid is None:
fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols)
if args["template"].layout.yaxis.showgrid is None:
fig.update_yaxes(showgrid=True, col=ncols)

# Add axis titles to non-marginal subplots
y_title = get_decorated_label(args, args["y"], "y")
Expand Down Expand Up @@ -687,55 +689,47 @@ def apply_default_cascade(args):
else:
args["template"] = "plotly"

# retrieve the actual template if we were given a name
try:
template = pio.templates[args["template"]]
# retrieve the actual template if we were given a name
args["template"] = pio.templates[args["template"]]
except Exception:
template = args["template"]
# otherwise try to build a real template
args["template"] = go.layout.Template(args["template"])

# if colors not set explicitly or in px.defaults, defer to a template
# if the template doesn't have one, we set some final fallback defaults
if "color_continuous_scale" in args:
if args["color_continuous_scale"] is None:
try:
args["color_continuous_scale"] = [
x[1] for x in template.layout.colorscale.sequential
]
except (AttributeError, TypeError):
pass
if (
args["color_continuous_scale"] is None
and args["template"].layout.colorscale.sequential
):
args["color_continuous_scale"] = [
x[1] for x in args["template"].layout.colorscale.sequential
]
if args["color_continuous_scale"] is None:
args["color_continuous_scale"] = sequential.Viridis

if "color_discrete_sequence" in args:
if args["color_discrete_sequence"] is None:
try:
args["color_discrete_sequence"] = template.layout.colorway
except (AttributeError, TypeError):
pass
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
args["color_discrete_sequence"] = args["template"].layout.colorway
if args["color_discrete_sequence"] is None:
args["color_discrete_sequence"] = qualitative.D3

# if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
# see if we can defer to template. If not, set reasonable defaults
if "symbol_sequence" in args:
if args["symbol_sequence"] is None:
try:
args["symbol_sequence"] = [
scatter.marker.symbol for scatter in template.data.scatter
]
except (AttributeError, TypeError):
pass
if args["symbol_sequence"] is None and args["template"].data.scatter:
args["symbol_sequence"] = [
scatter.marker.symbol for scatter in args["template"].data.scatter
]
if not args["symbol_sequence"] or not any(args["symbol_sequence"]):
args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"]

if "line_dash_sequence" in args:
if args["line_dash_sequence"] is None:
try:
args["line_dash_sequence"] = [
scatter.line.dash for scatter in template.data.scatter
]
except (AttributeError, TypeError):
pass
if args["line_dash_sequence"] is None and args["template"].data.scatter:
args["line_dash_sequence"] = [
scatter.line.dash for scatter in args["template"].data.scatter
]
if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]):
args["line_dash_sequence"] = [
"solid",
Expand Down Expand Up @@ -1264,13 +1258,17 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
cmax=range_color[1],
colorbar=dict(title=get_decorated_label(args, args[colorvar], colorvar)),
)
for v in ["title", "height", "width", "template"]:
for v in ["title", "height", "width"]:
if args[v]:
layout_patch[v] = args[v]
layout_patch["legend"] = {"tracegroupgap": 0}
if "title" not in layout_patch:
if "title" not in layout_patch and args["template"].layout.margin.t is None:
layout_patch["margin"] = {"t": 60}
if "size" in args and args["size"]:
if (
"size" in args
and args["size"]
and args["template"].layout.legend.itemsizing is None
):
layout_patch["legend"]["itemsizing"] = "constant"

fig = init_figure(
Expand All @@ -1295,6 +1293,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
# Add traces, layout and frames to figure
fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else [])
fig.layout.update(layout_patch)
if "template" in args and args["template"] is not None:
fig.update_layout(template=args["template"], overwrite=True)
fig.frames = frame_list if len(frames) > 1 else []

fig._px_trendlines = pd.DataFrame(trendline_rows)
Expand Down
107 changes: 107 additions & 0 deletions packages/python/plotly/plotly/tests/test_core/test_px/test_px.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,110 @@ def test_custom_data_scatter():
fig.data[0].hovertemplate
== "sepal_width=%{x}<br>sepal_length=%{y}<br>petal_length=%{customdata[2]}<br>petal_width=%{customdata[3]}<br>species_id=%{customdata[0]}"
)


def test_px_templates():
import plotly.io as pio
import plotly.graph_objects as go

tips = px.data.tips()

# use the normal defaults
fig = px.scatter()
assert fig.layout.template == pio.templates[pio.templates.default]

# respect changes to defaults
pio.templates.default = "seaborn"
fig = px.scatter()
assert fig.layout.template == pio.templates["seaborn"]

# special px-level defaults over pio defaults
pio.templates.default = "seaborn"
px.defaults.template = "ggplot2"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should document this in https://plot.ly/python/templates/, I did not know this was possible

fig = px.scatter()
assert fig.layout.template == pio.templates["ggplot2"]

# accept names in args over pio and px defaults
fig = px.scatter(template="seaborn")
assert fig.layout.template == pio.templates["seaborn"]

# accept objects in args
fig = px.scatter(template={})
assert fig.layout.template == go.layout.Template()

# read colorway from the template
fig = px.scatter(
tips,
x="total_bill",
y="tip",
color="sex",
template=dict(layout_colorway=["red", "blue"]),
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"

# default colorway fallback
fig = px.scatter(tips, x="total_bill", y="tip", color="sex", template=dict())
assert fig.data[0].marker.color == px.colors.qualitative.D3[0]
assert fig.data[1].marker.color == px.colors.qualitative.D3[1]

# pio default template colorway fallback
pio.templates.default = "seaborn"
px.defaults.template = None
fig = px.scatter(tips, x="total_bill", y="tip", color="sex")
assert fig.data[0].marker.color == pio.templates["seaborn"].layout.colorway[0]
assert fig.data[1].marker.color == pio.templates["seaborn"].layout.colorway[1]

# pio default template colorway fallback
pio.templates.default = "seaborn"
px.defaults.template = "ggplot2"
fig = px.scatter(tips, x="total_bill", y="tip", color="sex")
assert fig.data[0].marker.color == pio.templates["ggplot2"].layout.colorway[0]
assert fig.data[1].marker.color == pio.templates["ggplot2"].layout.colorway[1]

# don't overwrite top margin when set in template
fig = px.scatter(title="yo")
assert fig.layout.margin.t is None

fig = px.scatter()
assert fig.layout.margin.t == 60

fig = px.scatter(template=dict(layout_margin_t=2))
assert fig.layout.margin.t is None

# don't force histogram gridlines when set in template
pio.templates.default = "none"
px.defaults.template = None
fig = px.scatter(
tips, x="total_bill", y="tip", marginal_x="histogram", marginal_y="histogram"
)
assert fig.layout.xaxis2.showgrid
assert fig.layout.xaxis3.showgrid
assert fig.layout.yaxis2.showgrid
assert fig.layout.yaxis3.showgrid

fig = px.scatter(
tips,
x="total_bill",
y="tip",
marginal_x="histogram",
marginal_y="histogram",
template=dict(layout_yaxis_showgrid=False),
)
assert fig.layout.xaxis2.showgrid
assert fig.layout.xaxis3.showgrid
assert fig.layout.yaxis2.showgrid is None
assert fig.layout.yaxis3.showgrid is None

fig = px.scatter(
tips,
x="total_bill",
y="tip",
marginal_x="histogram",
marginal_y="histogram",
template=dict(layout_xaxis_showgrid=False),
)
assert fig.layout.xaxis2.showgrid is None
assert fig.layout.xaxis3.showgrid is None
assert fig.layout.yaxis2.showgrid
assert fig.layout.yaxis3.showgrid