Skip to content

Commit 941779c

Browse files
generalizing imshow(labels, x, y)
1 parent 67dd774 commit 941779c

File tree

2 files changed

+48
-48
lines changed

2 files changed

+48
-48
lines changed

Diff for: doc/python/imshow.md

+10-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ jupyter:
66
extension: .md
77
format_name: markdown
88
format_version: '1.2'
9-
jupytext_version: 1.3.0
9+
jupytext_version: 1.3.1
1010
kernelspec:
1111
display_name: Python 3
1212
language: python
@@ -20,7 +20,7 @@ jupyter:
2020
name: python
2121
nbconvert_exporter: python
2222
pygments_lexer: ipython3
23-
version: 3.7.3
23+
version: 3.6.8
2424
plotly:
2525
description: How to display image data in Python with Plotly.
2626
display_as: scientific
@@ -93,7 +93,11 @@ fig.show()
9393
import plotly.express as px
9494
import numpy as np
9595
img = np.arange(100).reshape((10, 10))
96-
fig = px.imshow(img, color_continuous_scale='gray')
96+
fig = px.imshow(img, color_continuous_scale='gray', labels=dict(x="yoo", y="yaa", color="hey"),
97+
width=600, height=600,
98+
x=["a","b","c","d","e","f","g","h","i","j"],
99+
y=["a","b","c","d","e","f","g","h","i","j"]
100+
)
97101
fig.show()
98102
```
99103

@@ -120,7 +124,7 @@ import xarray as xr
120124
# Load xarray from dataset included in the xarray tutorial
121125
airtemps = xr.tutorial.open_dataset('air_temperature').air.sel(lon=250.0)
122126
fig = px.imshow(airtemps.T, color_continuous_scale='RdBu_r', origin='lower',
123-
#labels={'colorbar':airtemps.attrs['var_desc']}
127+
labels={'color':airtemps.attrs['var_desc']}
124128
)
125129
fig.show()
126130
```
@@ -133,9 +137,8 @@ For xarrays, by default `px.imshow` does not constrain pixels to be square, sinc
133137
import plotly.express as px
134138
import xarray as xr
135139
airtemps = xr.tutorial.open_dataset('air_temperature').air.isel(time=500)
136-
colorbar_title = airtemps.attrs['var_desc'] + '<br> (%s)'%airtemps.attrs['units']
137-
fig = px.imshow(airtemps, color_continuous_scale='RdBu_r', aspect='equal',
138-
labels={'colorbar':colorbar_title})
140+
colorbar_title = airtemps.attrs['var_desc'] + '<br>(%s)'%airtemps.attrs['units']
141+
fig = px.imshow(airtemps, color_continuous_scale='RdBu_r', aspect='equal')
139142
fig.show()
140143
```
141144

Diff for: packages/python/plotly/plotly/express/_imshow.py

+38-41
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def imshow(
6969
zmax=None,
7070
origin=None,
7171
labels={},
72+
x=None,
73+
y=None,
7274
color_continuous_scale=None,
7375
color_continuous_midpoint=None,
7476
range_color=None,
@@ -167,33 +169,33 @@ def imshow(
167169
"""
168170
args = locals()
169171
apply_default_cascade(args)
170-
img_is_xarray = False
171-
x_label = "x"
172-
y_label = "y"
173-
z_label = ""
174-
if xarray_imported:
175-
if isinstance(img, xarray.DataArray):
176-
y_label, x_label = img.dims[0], img.dims[1]
177-
# np.datetime64 is not handled correctly by go.Heatmap
178-
for ax in [x_label, y_label]:
179-
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
180-
img.coords[ax] = img.coords[ax].astype(str)
172+
labels = labels.copy()
173+
if xarray_imported and isinstance(img, xarray.DataArray):
174+
y_label, x_label = img.dims[0], img.dims[1]
175+
# np.datetime64 is not handled correctly by go.Heatmap
176+
for ax in [x_label, y_label]:
177+
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
178+
img.coords[ax] = img.coords[ax].astype(str)
179+
if x is None:
181180
x = img.coords[x_label]
181+
if y is None:
182182
y = img.coords[y_label]
183-
img_is_xarray = True
184-
if aspect is None:
185-
aspect = "auto"
186-
z_label = xarray.plot.utils.label_from_attrs(img).replace("\n", "<br>")
187-
188-
if labels is not None:
189-
if "x" in labels:
190-
x_label = labels["x"]
191-
if "y" in labels:
192-
y_label = labels["y"]
193-
if "color" in labels:
194-
z_label = labels["color"]
195-
196-
if not img_is_xarray:
183+
if aspect is None:
184+
aspect = "auto"
185+
if labels.get("x", None) is None:
186+
labels["x"] = x_label
187+
if labels.get("y", None) is None:
188+
labels["y"] = y_label
189+
if labels.get("color", None) is None:
190+
labels["color"] = xarray.plot.utils.label_from_attrs(img)
191+
labels["color"] = labels["color"].replace("\n", "<br>")
192+
else:
193+
if labels.get("x", None) is None:
194+
labels["x"] = ""
195+
if labels.get("y", None) is None:
196+
labels["y"] = ""
197+
if labels.get("color", None) is None:
198+
labels["color"] = ""
197199
if aspect is None:
198200
aspect = "equal"
199201

@@ -205,7 +207,7 @@ def imshow(
205207

206208
# For 2d data, use Heatmap trace
207209
if img.ndim == 2:
208-
trace = go.Heatmap(z=img, coloraxis="coloraxis1")
210+
trace = go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")
209211
autorange = True if origin == "lower" else "reversed"
210212
layout = dict(yaxis=dict(autorange=autorange))
211213
if aspect == "equal":
@@ -224,8 +226,9 @@ def imshow(
224226
cmid=color_continuous_midpoint,
225227
cmin=range_color[0],
226228
cmax=range_color[1],
227-
colorbar=dict(title=z_label),
228229
)
230+
if labels["color"]:
231+
layout["coloraxis1"]["colorbar"] = dict(title=labels["color"])
229232

230233
# For 2D+RGB data, use Image trace
231234
elif img.ndim == 3 and img.shape[-1] in [3, 4]:
@@ -250,19 +253,13 @@ def imshow(
250253
layout_patch["margin"] = {"t": 60}
251254
fig = go.Figure(data=trace, layout=layout)
252255
fig.update_layout(layout_patch)
253-
if img.ndim <= 2:
254-
hovertemplate = (
255-
x_label
256-
+ ": %{x} <br>"
257-
+ y_label
258-
+ ": %{y} <br>"
259-
+ z_label
260-
+ " : %{z}<extra></extra>"
261-
)
262-
fig.update_traces(hovertemplate=hovertemplate)
263-
if img_is_xarray:
264-
fig.update_traces(x=x, y=y)
265-
fig.update_xaxes(title_text=x_label)
266-
fig.update_yaxes(title_text=y_label)
256+
fig.update_traces(
257+
hovertemplate="%s: %%{x}<br>%s: %%{y}<br>%s: %%{z}<extra></extra>"
258+
% (labels["x"] or "x", labels["y"] or "y", labels["color"] or "color",)
259+
)
260+
if labels["x"]:
261+
fig.update_xaxes(title_text=labels["x"])
262+
if labels["y"]:
263+
fig.update_yaxes(title_text=labels["y"])
267264
fig.update_layout(template=args["template"], overwrite=True)
268265
return fig

0 commit comments

Comments
 (0)