Skip to content

Commit f367c72

Browse files
imshow defaults cascade
1 parent d4a4621 commit f367c72

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

Diff for: packages/python/plotly/plotly/express/_imshow.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import plotly.graph_objs as go
22
from _plotly_utils.basevalidators import ColorscaleValidator
3+
from ._core import apply_default_cascade
34
import numpy as np # is it fine to depend on np here?
45

56
_float_types = []
@@ -63,6 +64,10 @@ def imshow(
6364
color_continuous_scale=None,
6465
color_continuous_midpoint=None,
6566
range_color=None,
67+
title=None,
68+
template=None,
69+
width=None,
70+
height=None,
6671
):
6772
"""
6873
Display an image, i.e. data on a 2D regular raster.
@@ -118,6 +123,9 @@ def imshow(
118123
In order to update and customize the returned figure, use
119124
`go.Figure.update_traces` or `go.Figure.update_layout`.
120125
"""
126+
args = locals()
127+
apply_default_cascade(args)
128+
121129
img = np.asanyarray(img)
122130
# Cast bools to uint8 (also one byte)
123131
if img.dtype == np.bool:
@@ -134,7 +142,9 @@ def imshow(
134142
colorscale_validator = ColorscaleValidator("colorscale", "imshow")
135143
range_color = range_color or [None, None]
136144
layout["coloraxis1"] = dict(
137-
colorscale=colorscale_validator.validate_coerce(color_continuous_scale),
145+
colorscale=colorscale_validator.validate_coerce(
146+
args["color_continuous_scale"]
147+
),
138148
cmid=color_continuous_midpoint,
139149
cmin=range_color[0],
140150
cmax=range_color[1],
@@ -154,5 +164,14 @@ def imshow(
154164
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
155165
"An image of shape %s was provided" % str(img.shape)
156166
)
167+
168+
layout_patch = dict()
169+
for v in ["title", "height", "width"]:
170+
if args[v]:
171+
layout_patch[v] = args[v]
172+
if "title" not in layout_patch and args["template"].layout.margin.t is None:
173+
layout_patch["margin"] = {"t": 60}
157174
fig = go.Figure(data=trace, layout=layout)
175+
fig.update_layout(layout_patch)
176+
fig.update_layout(template=args["template"], overwrite=True)
158177
return fig

0 commit comments

Comments
 (0)