Skip to content

Commit 15ce789

Browse files
committed
first version of imshow
1 parent b2151ef commit 15ce789

File tree

2 files changed

+200
-0
lines changed

2 files changed

+200
-0
lines changed

Diff for: packages/python/plotly/plotly/express/_imshow.py

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import plotly.graph_objs as go
2+
import numpy as np # is it fine to depend on np here?
3+
4+
_float_types = []
5+
6+
# Adapted from skimage.util.dtype
7+
_integer_types = (
8+
np.byte,
9+
np.ubyte, # 8 bits
10+
np.short,
11+
np.ushort, # 16 bits
12+
np.intc,
13+
np.uintc, # 16 or 32 or 64 bits
14+
np.int_,
15+
np.uint, # 32 or 64 bits
16+
np.longlong,
17+
np.ulonglong,
18+
) # 64 bits
19+
_integer_ranges = {t: (np.iinfo(t).min, np.iinfo(t).max) for t in _integer_types}
20+
21+
22+
def _vectorize_zvalue(z):
23+
if z is None:
24+
return z
25+
elif np.isscalar(z):
26+
return [z] * 3 + [1]
27+
elif len(z) == 1:
28+
return list(z) * 3 + [1]
29+
elif len(z) == 3:
30+
return list(z) + [1]
31+
elif len(z) == 4:
32+
return z
33+
else:
34+
raise ValueError(
35+
"zmax can be a scalar, or an iterable of length 1, 3 or 4. "
36+
"A value of %s was passed for zmax." % str(z)
37+
)
38+
39+
40+
def _infer_zmax_from_type(img):
41+
dt = img.dtype.type
42+
if dt in _integer_types:
43+
return _integer_ranges[dt][1]
44+
else:
45+
return img[np.isfinite(img)].max()
46+
47+
48+
def imshow(
49+
img, zmin=None, zmax=None, origin=None, colorscale=None, showticks=True, **kwargs
50+
):
51+
"""
52+
Display an image, i.e. data on a 2D regular raster.
53+
54+
Parameters
55+
----------
56+
57+
img: array-like image
58+
The image data. Supported array shapes are
59+
60+
- (M, N): an image with scalar data. The data is visualized
61+
using a colormap.
62+
- (M, N, 3): an image with RGB values.
63+
- (M, N, 4): an image with RGBA values, i.e. including transparency.
64+
65+
zmin, zmax : scalar or iterable, optional
66+
zmin and zmax define the scalar range that the colormap covers. By default,
67+
zmin and zmax correspond to the min and max values of the datatype for integer
68+
datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.), and
69+
to the min and max values of the image for an image of floats.
70+
71+
origin : str, 'upper' or 'lower' (default 'upper')
72+
position of the [0, 0] pixel of the image array, in the upper left or lower left
73+
corner. The convention 'upper' is typically used for matrices and images.
74+
75+
colorscale : str
76+
colormap used to map scalar data to colors (for a 2D image). This parameter is not used for
77+
RGB or RGBA images.
78+
79+
showticks : bool, default True
80+
if False, no tick labels are shown for pixel indices.
81+
82+
** kwargs : additional arguments to be passed to the Heatmap (grayscale) or Image (RGB) trace.
83+
84+
Returns
85+
-------
86+
fig : graph_objects.Figure containing the displayed image
87+
88+
See also
89+
--------
90+
91+
graph_objects.Image : image trace
92+
graph_objects.Heatmap : heatmap trace
93+
"""
94+
img = np.asanyarray(img)
95+
# Cast bools to uint8 (also one byte)
96+
if img.dtype == np.bool:
97+
img = 255 * img.astype(np.uint8)
98+
99+
# For 2d data, use Heatmap trace
100+
if img.ndim == 2:
101+
if colorscale is None:
102+
colorscale = "gray"
103+
trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale, **kwargs)
104+
autorange = True if origin == "lower" else "reversed"
105+
layout = dict(
106+
xaxis=dict(scaleanchor="y", constrain="domain"),
107+
yaxis=dict(autorange=autorange, constrain="domain"),
108+
)
109+
# For 2D+RGB data, use Image trace
110+
elif img.ndim == 3 and img.shape[-1] in [3, 4]:
111+
if zmax is None and img.dtype is not np.uint8:
112+
zmax = _infer_zmax_from_type(img)
113+
zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax)
114+
trace = go.Image(z=img, zmin=zmin, zmax=zmax, **kwargs)
115+
layout = {}
116+
if origin == "lower":
117+
layout["yaxis"] = dict(autorange=True)
118+
else:
119+
raise ValueError(
120+
"px.imshow only accepts 2D grayscale, RGB or RGBA images. "
121+
"An image of shape %s was provided" % str(img.shape)
122+
)
123+
fig = go.Figure(data=trace, layout=layout)
124+
if not showticks:
125+
fig.update_xaxes(showticklabels=False)
126+
fig.update_yaxes(showticklabels=False)
127+
return fig
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import plotly.express as px
2+
import numpy as np
3+
import pytest
4+
5+
img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8)
6+
img_gray = np.arange(100).reshape((10, 10))
7+
8+
9+
def test_rgb_uint8():
10+
fig = px.imshow(img_rgb)
11+
assert fig.data[0]["zmax"] == (255, 255, 255, 1)
12+
13+
14+
def test_vmax():
15+
for zmax in [
16+
100,
17+
[100],
18+
(100,),
19+
[100, 100, 100],
20+
(100, 100, 100),
21+
(100, 100, 100, 1),
22+
]:
23+
fig = px.imshow(img_rgb, zmax=zmax)
24+
assert fig.data[0]["zmax"] == (100, 100, 100, 1)
25+
26+
27+
def test_automatic_zmax_from_dtype():
28+
dtypes_dict = {
29+
np.uint8: 2 ** 8 - 1,
30+
np.uint16: 2 ** 16 - 1,
31+
np.float: 1,
32+
np.bool: 255,
33+
}
34+
for key, val in dtypes_dict.items():
35+
img = np.array([0, 1], dtype=key)
36+
img = np.dstack((img,) * 3)
37+
fig = px.imshow(img)
38+
assert fig.data[0]["zmax"] == (val, val, val, 1)
39+
40+
41+
def test_origin():
42+
for img in [img_rgb, img_gray]:
43+
fig = px.imshow(img, origin="lower")
44+
assert fig.layout.yaxis.autorange == True
45+
fig = px.imshow(img_rgb)
46+
assert fig.layout.yaxis.autorange is None
47+
fig = px.imshow(img_gray)
48+
assert fig.layout.yaxis.autorange == "reversed"
49+
50+
51+
def test_colorscale():
52+
fig = px.imshow(img_gray)
53+
assert fig.data[0].colorscale[0] == (0.0, "rgb(0, 0, 0)")
54+
fig = px.imshow(img_gray, colorscale="Viridis")
55+
assert fig.data[0].colorscale[0] == (0.0, "#440154")
56+
57+
58+
def test_wrong_dimensions():
59+
imgs = [1, np.ones((5,) * 3), np.ones((5,) * 4)]
60+
for img in imgs:
61+
with pytest.raises(ValueError) as err_msg:
62+
fig = px.imshow(img)
63+
64+
65+
def test_nan_inf_data():
66+
imgs = [np.ones((20, 20)), 255 * np.ones((20, 20), dtype=np.uint8)]
67+
zmaxs = [1, 255]
68+
for zmax, img in zip(zmaxs, imgs):
69+
img[0] = 0
70+
img[10:12] = np.nan
71+
# the case of 2d/heatmap is handled gracefully by the JS trace but I don't know how to check it
72+
fig = px.imshow(np.dstack((img,) * 3))
73+
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1)

0 commit comments

Comments
 (0)