Skip to content

Commit d4a4621

Browse files
use three px-standard colorscale kwargs
1 parent 96967d7 commit d4a4621

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

Diff for: packages/python/plotly/plotly/express/_imshow.py

+33-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import plotly.graph_objs as go
2+
from _plotly_utils.basevalidators import ColorscaleValidator
23
import numpy as np # is it fine to depend on np here?
34

45
_float_types = []
@@ -54,7 +55,15 @@ def _infer_zmax_from_type(img):
5455
return 2 ** 32
5556

5657

57-
def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
58+
def imshow(
59+
img,
60+
zmin=None,
61+
zmax=None,
62+
origin=None,
63+
color_continuous_scale=None,
64+
color_continuous_midpoint=None,
65+
range_color=None,
66+
):
5867
"""
5968
Display an image, i.e. data on a 2D regular raster.
6069
@@ -74,16 +83,24 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
7483
zmin and zmax correspond to the min and max values of the datatype for integer
7584
datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
7685
a multichannel image of floats, the max of the image is computed and zmax is the
77-
smallest power of 256 (1, 255, 65535) greater than this max value,
86+
smallest power of 256 (1, 255, 65535) greater than this max value,
7887
with a 5% tolerance. For a single-channel image, the max of the image is used.
7988
8089
origin : str, 'upper' or 'lower' (default 'upper')
8190
position of the [0, 0] pixel of the image array, in the upper left or lower left
8291
corner. The convention 'upper' is typically used for matrices and images.
8392
84-
colorscale : str
85-
colormap used to map scalar data to colors (for a 2D image). This parameter is not used for
86-
RGB or RGBA images.
93+
color_continuous_scale : str or list of str
94+
colormap used to map scalar data to colors (for a 2D image). This parameter is
95+
not used for RGB or RGBA images.
96+
97+
color_continuous_midpoint : number
98+
If set, computes the bounds of the continuous color scale to have the desired
99+
midpoint.
100+
101+
range_color : list of two numbers
102+
If provided, overrides auto-scaling on the continuous color scale, including
103+
overriding `color_continuous_midpoint`.
87104
88105
Returns
89106
-------
@@ -108,14 +125,21 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
108125

109126
# For 2d data, use Heatmap trace
110127
if img.ndim == 2:
111-
if colorscale is None:
112-
colorscale = "gray"
113-
trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale)
128+
trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, coloraxis="coloraxis1")
114129
autorange = True if origin == "lower" else "reversed"
115130
layout = dict(
116131
xaxis=dict(scaleanchor="y", constrain="domain"),
117132
yaxis=dict(autorange=autorange, constrain="domain"),
118133
)
134+
colorscale_validator = ColorscaleValidator("colorscale", "imshow")
135+
range_color = range_color or [None, None]
136+
layout["coloraxis1"] = dict(
137+
colorscale=colorscale_validator.validate_coerce(color_continuous_scale),
138+
cmid=color_continuous_midpoint,
139+
cmin=range_color[0],
140+
cmax=range_color[1],
141+
)
142+
119143
# For 2D+RGB data, use Image trace
120144
elif img.ndim == 3 and img.shape[-1] in [3, 4]:
121145
if zmax is None and img.dtype is not np.uint8:
@@ -127,7 +151,7 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
127151
layout["yaxis"] = dict(autorange=True)
128152
else:
129153
raise ValueError(
130-
"px.imshow only accepts 2D grayscale, RGB or RGBA images. "
154+
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
131155
"An image of shape %s was provided" % str(img.shape)
132156
)
133157
fig = go.Figure(data=trace, layout=layout)

0 commit comments

Comments
 (0)