1
1
import plotly .graph_objs as go
2
+ from _plotly_utils .basevalidators import ColorscaleValidator
2
3
import numpy as np # is it fine to depend on np here?
3
4
4
5
_float_types = []
@@ -54,7 +55,15 @@ def _infer_zmax_from_type(img):
54
55
return 2 ** 32
55
56
56
57
57
- def imshow (img , zmin = None , zmax = None , origin = None , colorscale = None ):
58
+ def imshow (
59
+ img ,
60
+ zmin = None ,
61
+ zmax = None ,
62
+ origin = None ,
63
+ color_continuous_scale = None ,
64
+ color_continuous_midpoint = None ,
65
+ range_color = None ,
66
+ ):
58
67
"""
59
68
Display an image, i.e. data on a 2D regular raster.
60
69
@@ -74,16 +83,24 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
74
83
zmin and zmax correspond to the min and max values of the datatype for integer
75
84
datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
76
85
a multichannel image of floats, the max of the image is computed and zmax is the
77
- smallest power of 256 (1, 255, 65535) greater than this max value,
86
+ smallest power of 256 (1, 255, 65535) greater than this max value,
78
87
with a 5% tolerance. For a single-channel image, the max of the image is used.
79
88
80
89
origin : str, 'upper' or 'lower' (default 'upper')
81
90
position of the [0, 0] pixel of the image array, in the upper left or lower left
82
91
corner. The convention 'upper' is typically used for matrices and images.
83
92
84
- colorscale : str
85
- colormap used to map scalar data to colors (for a 2D image). This parameter is not used for
86
- RGB or RGBA images.
93
+ color_continuous_scale : str or list of str
94
+ colormap used to map scalar data to colors (for a 2D image). This parameter is
95
+ not used for RGB or RGBA images.
96
+
97
+ color_continuous_midpoint : number
98
+ If set, computes the bounds of the continuous color scale to have the desired
99
+ midpoint.
100
+
101
+ range_color : list of two numbers
102
+ If provided, overrides auto-scaling on the continuous color scale, including
103
+ overriding `color_continuous_midpoint`.
87
104
88
105
Returns
89
106
-------
@@ -108,14 +125,21 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
108
125
109
126
# For 2d data, use Heatmap trace
110
127
if img .ndim == 2 :
111
- if colorscale is None :
112
- colorscale = "gray"
113
- trace = go .Heatmap (z = img , zmin = zmin , zmax = zmax , colorscale = colorscale )
128
+ trace = go .Heatmap (z = img , zmin = zmin , zmax = zmax , coloraxis = "coloraxis1" )
114
129
autorange = True if origin == "lower" else "reversed"
115
130
layout = dict (
116
131
xaxis = dict (scaleanchor = "y" , constrain = "domain" ),
117
132
yaxis = dict (autorange = autorange , constrain = "domain" ),
118
133
)
134
+ colorscale_validator = ColorscaleValidator ("colorscale" , "imshow" )
135
+ range_color = range_color or [None , None ]
136
+ layout ["coloraxis1" ] = dict (
137
+ colorscale = colorscale_validator .validate_coerce (color_continuous_scale ),
138
+ cmid = color_continuous_midpoint ,
139
+ cmin = range_color [0 ],
140
+ cmax = range_color [1 ],
141
+ )
142
+
119
143
# For 2D+RGB data, use Image trace
120
144
elif img .ndim == 3 and img .shape [- 1 ] in [3 , 4 ]:
121
145
if zmax is None and img .dtype is not np .uint8 :
@@ -127,7 +151,7 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
127
151
layout ["yaxis" ] = dict (autorange = True )
128
152
else :
129
153
raise ValueError (
130
- "px.imshow only accepts 2D grayscale , RGB or RGBA images. "
154
+ "px.imshow only accepts 2D single-channel , RGB or RGBA images. "
131
155
"An image of shape %s was provided" % str (img .shape )
132
156
)
133
157
fig = go .Figure (data = trace , layout = layout )
0 commit comments