Skip to content

Commit cf644e5

Browse files
committed
added doc
1 parent ba65990 commit cf644e5

File tree

2 files changed

+23
-14
lines changed

2 files changed

+23
-14
lines changed

doc/python/imshow.md

+4-2
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ fig.show()
403403

404404
*Introduced in plotly 4.11*
405405

406-
For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axes the image is sliced through to make the facets. With `facet_col_wrap` , one can set the maximum number of columns.
406+
For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axes the image is sliced through to make the facets. With `facet_col_wrap` , one can set the maximum number of columns. For image datasets passed as xarrays, it is also possible to give an axis name as a string for `facet_col`.
407407

408408
It is recommended to use `binary_string=True` for facetted plots of images in order to keep a small figure size and a short rendering time.
409409

@@ -455,12 +455,14 @@ fig.show()
455455

456456
*Introduced in plotly 4.11*
457457

458+
For xarray datasets, one can pass either an axis number or an axis name to `animation_frame`. Axis names and coordinates are automatically used for the labels, ticks and animation controls of the figure.
459+
458460
```python
459461
import plotly.express as px
460462
import xarray as xr
461463
# Load xarray from dataset included in the xarray tutorial
462464
ds = xr.tutorial.open_dataset('air_temperature').air[:20]
463-
fig = px.imshow(ds, animation_frame='lat', color_continuous_scale='RdBu_r')
465+
fig = px.imshow(ds, animation_frame='time', zmin=220, zmax=300, color_continuous_scale='RdBu_r')
464466
fig.show()
465467
```
466468

packages/python/plotly/plotly/express/_imshow.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -288,23 +288,23 @@ def imshow(
288288
args = locals()
289289
apply_default_cascade(args)
290290
labels = labels.copy()
291-
col_labels = []
291+
nslices = 1
292292
if facet_col is not None:
293293
if isinstance(facet_col, str):
294294
facet_col = img.dims.index(facet_col)
295295
nslices = img.shape[facet_col]
296296
ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices
297297
nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
298-
col_labels = ["plane = %d" % i for i in range(nslices)]
299298
else:
300299
nrows = 1
301300
ncols = 1
302301
if animation_frame is not None:
303302
if isinstance(animation_frame, str):
304303
animation_frame = img.dims.index(animation_frame)
304+
nslices = img.shape[animation_frame]
305305
slice_through = (facet_col is not None) or (animation_frame is not None)
306-
plane_label = None
307-
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
306+
slice_label = None
307+
slices = range(nslices)
308308
# ----- Define x and y, set labels if img is an xarray -------------------
309309
if xarray_imported and isinstance(img, xarray.DataArray):
310310
# if binary_string:
@@ -314,13 +314,12 @@ def imshow(
314314
# "`img.values`"
315315
# )
316316
dims = list(img.dims)
317-
print(dims)
318317
if slice_through:
319318
slice_index = facet_col if facet_col is not None else animation_frame
319+
slices = img.coords[img.dims[slice_index]].values
320320
_ = dims.pop(slice_index)
321-
plane_label = img.dims[slice_index]
321+
slice_label = img.dims[slice_index]
322322
y_label, x_label = dims[0], dims[1]
323-
print(y_label, x_label)
324323
# np.datetime64 is not handled correctly by go.Heatmap
325324
for ax in [x_label, y_label]:
326325
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
@@ -335,8 +334,8 @@ def imshow(
335334
labels["x"] = x_label
336335
if labels.get("y", None) is None:
337336
labels["y"] = y_label
338-
if labels.get("plane", None) is None:
339-
labels["plane"] = plane_label
337+
if labels.get("slice", None) is None:
338+
labels["slice"] = slice_label
340339
if labels.get("color", None) is None:
341340
labels["color"] = xarray.plot.utils.label_from_attrs(img)
342341
labels["color"] = labels["color"].replace("\n", "<br>")
@@ -378,7 +377,7 @@ def imshow(
378377
img = np.moveaxis(img, animation_frame, 0)
379378
animation_frame = True
380379
args["animation_frame"] = (
381-
"plane" if labels.get("plane") is None else labels["plane"]
380+
"slice" if labels.get("slice") is None else labels["slice"]
382381
)
383382

384383
# Default behaviour of binary_string: True for RGB images, False for 2D
@@ -531,6 +530,14 @@ def imshow(
531530
% str(img.shape)
532531
)
533532

533+
# Now build figure
534+
col_labels = []
535+
if facet_col is not None:
536+
slice_label = "slice" if labels.get("slice") is None else labels["slice"]
537+
if slices is None:
538+
slices = range(nslices)
539+
col_labels = ["%s = %d" % (slice_label, i) for i in slices]
540+
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
534541
layout_patch = dict()
535542
for attr_name in ["height", "width"]:
536543
if args[attr_name]:
@@ -541,11 +548,11 @@ def imshow(
541548
layout_patch["margin"] = {"t": 60}
542549

543550
frame_list = []
544-
for index, trace in enumerate(traces):
551+
for index, (slice_index, trace) in enumerate(zip(slices, traces)):
545552
if facet_col or index == 0:
546553
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
547554
if animation_frame:
548-
frame_list.append(dict(data=trace, layout=layout, name=str(index)))
555+
frame_list.append(dict(data=trace, layout=layout, name=str(slice_index)))
549556
if animation_frame:
550557
fig.frames = frame_list
551558
fig.update_layout(layout)

0 commit comments

Comments
 (0)