diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 9c12ae575c6..88713e54368 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -204,23 +204,19 @@ def imshow( args = locals() apply_default_cascade(args) labels = labels.copy() + img_is_xarray = False # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): - if binary_string: - raise ValueError( - "It is not possible to use binary image strings for xarrays." - "Please pass your data as a numpy array instead using" - "`img.values`" - ) + img_is_xarray = True y_label, x_label = img.dims[0], img.dims[1] # np.datetime64 is not handled correctly by go.Heatmap for ax in [x_label, y_label]: if np.issubdtype(img.coords[ax].dtype, np.datetime64): img.coords[ax] = img.coords[ax].astype(str) if x is None: - x = img.coords[x_label] + x = img.coords[x_label].values if y is None: - y = img.coords[y_label] + y = img.coords[y_label].values if aspect is None: aspect = "auto" if labels.get("x", None) is None: @@ -330,6 +326,42 @@ def imshow( _vectorize_zvalue(zmin, mode="min"), _vectorize_zvalue(zmax, mode="max"), ) + x0, y0, dx, dy = (None,) * 4 + error_msg_xarray = ( + "Non-numerical coordinates were passed with xarray `img`, but " + "the Image trace cannot handle it. Please use `binary_string=False` " + "for 2D data or pass instead the numpy array `img.values` to `px.imshow`." + ) + if x is not None: + x = np.asanyarray(x) + if np.issubdtype(x.dtype, np.number): + x0 = x[0] + dx = x[1] - x[0] + else: + error_msg = ( + error_msg_xarray + if img_is_xarray + else ( + "Only numerical values are accepted for the `x` parameter " + "when an Image trace is used." + ) + ) + raise ValueError(error_msg) + if y is not None: + y = np.asanyarray(y) + if np.issubdtype(y.dtype, np.number): + y0 = y[0] + dy = y[1] - y[0] + else: + error_msg = ( + error_msg_xarray + if img_is_xarray + else ( + "Only numerical values are accepted for the `y` parameter " + "when an Image trace is used." + ) + ) + raise ValueError(error_msg) if binary_string: if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img @@ -355,13 +387,24 @@ def imshow( compression=binary_compression_level, ext=binary_format, ) - trace = go.Image(source=img_str) + trace = go.Image(source=img_str, x0=x0, y0=y0, dx=dx, dy=dy) else: colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" - trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel) + trace = go.Image( + z=img, + zmin=zmin, + zmax=zmax, + colormodel=colormodel, + x0=x0, + y0=y0, + dx=dx, + dy=dy, + ) layout = {} - if origin == "lower": + if origin == "lower" or (dy is not None and dy < 0): layout["yaxis"] = dict(autorange=True) + if dx is not None and dx < 0: + layout["xaxis"] = dict(autorange="reversed") else: raise ValueError( "px.imshow only accepts 2D single-channel, RGB or RGBA images. " diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 84e39c78330..313267aacbd 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -5,6 +5,7 @@ from PIL import Image from io import BytesIO import base64 +import datetime from plotly.express.imshow_utils import rescale_intensity img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8) @@ -204,6 +205,37 @@ def test_imshow_labels_and_ranges(): with pytest.raises(ValueError): fig = px.imshow([[1, 2], [3, 4], [5, 6]], x=["a"]) + img = np.ones((2, 2), dtype=np.uint8) + fig = px.imshow(img, x=["a", "b"]) + assert fig.data[0].x == ("a", "b") + + with pytest.raises(ValueError): + img = np.ones((2, 2, 3), dtype=np.uint8) + fig = px.imshow(img, x=["a", "b"]) + + img = np.ones((2, 2), dtype=np.uint8) + base = datetime.datetime(2000, 1, 1) + fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)]) + assert fig.data[0].x == ( + datetime.datetime(2000, 1, 1, 0, 0), + datetime.datetime(2000, 1, 1, 1, 0), + ) + + with pytest.raises(ValueError): + img = np.ones((2, 2, 3), dtype=np.uint8) + base = datetime.datetime(2000, 1, 1) + fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)]) + + +def test_imshow_ranges_image_trace(): + fig = px.imshow(img_rgb, x=[1, 11, 21]) + assert fig.data[0].dx == 10 + assert fig.data[0].x0 == 1 + fig = px.imshow(img_rgb, x=[21, 11, 1]) + assert fig.data[0].dx == -10 + assert fig.data[0].x0 == 21 + assert fig.layout.xaxis.autorange == "reversed" + def test_imshow_dataframe(): df = px.data.medals_wide(indexed=False)