From fa6bb0f4511297ca2a6388aee8247814b41a524a Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 10 Sep 2020 09:52:06 +0200 Subject: [PATCH 1/5] take x and y into account when using Image trace --- packages/python/plotly/plotly/express/_imshow.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 218a824e154..3b46784ee60 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -403,6 +403,13 @@ def imshow( _vectorize_zvalue(zmin, mode="min"), _vectorize_zvalue(zmax, mode="max"), ) + x0, y0, dx, dy = (None,) * 4 + if x is not None: + x0 = x[0] + dx = x[1] - x[0] + if y is not None: + y0 = y[0] + dy = y[1] - y[0] if binary_string: if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img @@ -428,10 +435,10 @@ def imshow( compression=binary_compression_level, ext=binary_format, ) - trace = go.Image(source=img_str) + trace = go.Image(source=img_str, x0=x0, y0=y0, dx=dx, dy=dy) else: colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" - trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel) + trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel, x0=x0, y0=y0, dx=dx, dy=dy) layout = {} if origin == "lower": layout["yaxis"] = dict(autorange=True) From 14fc33af583444e654b994e2d9a65de00d784f6a Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 10 Sep 2020 14:44:08 +0200 Subject: [PATCH 2/5] x and y parameters are now used for Image trace in imshow --- .../python/plotly/plotly/express/_imshow.py | 37 ++++++++++++------- .../tests/test_core/test_px/test_imshow.py | 10 +++++ 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 3b46784ee60..da43934212e 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -279,21 +279,15 @@ def imshow( labels = labels.copy() # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): - if binary_string: - raise ValueError( - "It is not possible to use binary image strings for xarrays." - "Please pass your data as a numpy array instead using" - "`img.values`" - ) y_label, x_label = img.dims[0], img.dims[1] # np.datetime64 is not handled correctly by go.Heatmap for ax in [x_label, y_label]: if np.issubdtype(img.coords[ax].dtype, np.datetime64): img.coords[ax] = img.coords[ax].astype(str) if x is None: - x = img.coords[x_label] + x = img.coords[x_label].values if y is None: - y = img.coords[y_label] + y = img.coords[y_label].values if aspect is None: aspect = "auto" if labels.get("x", None) is None: @@ -405,11 +399,15 @@ def imshow( ) x0, y0, dx, dy = (None,) * 4 if x is not None: - x0 = x[0] - dx = x[1] - x[0] + x = np.asanyarray(x) + if np.issubdtype(x.dtype, np.number): + x0 = x[0] + dx = x[1] - x[0] if y is not None: - y0 = y[0] - dy = y[1] - y[0] + y = np.asanyarray(y) + if np.issubdtype(y.dtype, np.number): + y0 = y[0] + dy = y[1] - y[0] if binary_string: if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img @@ -438,10 +436,21 @@ def imshow( trace = go.Image(source=img_str, x0=x0, y0=y0, dx=dx, dy=dy) else: colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" - trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel, x0=x0, y0=y0, dx=dx, dy=dy) + trace = go.Image( + z=img, + zmin=zmin, + zmax=zmax, + colormodel=colormodel, + x0=x0, + y0=y0, + dx=dx, + dy=dy, + ) layout = {} - if origin == "lower": + if origin == "lower" or (dy is not None and dy < 0): layout["yaxis"] = dict(autorange=True) + if dx is not None and dx < 0: + layout["xaxis"] = dict(autorange="reversed") else: raise ValueError( "px.imshow only accepts 2D single-channel, RGB or RGBA images. " diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 84e39c78330..02429dffb1e 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -205,6 +205,16 @@ def test_imshow_labels_and_ranges(): fig = px.imshow([[1, 2], [3, 4], [5, 6]], x=["a"]) +def test_imshow_ranges_image_trace(): + fig = px.imshow(img_rgb, x=[1, 11, 21]) + assert fig.data[0].dx == 10 + assert fig.data[0].x0 == 1 + fig = px.imshow(img_rgb, x=[21, 11, 1]) + assert fig.data[0].dx == -10 + assert fig.data[0].x0 == 21 + assert fig.layout.xaxis.autorange == "reversed" + + def test_imshow_dataframe(): df = px.data.medals_wide(indexed=False) fig = px.imshow(df) From f7ec141759e67726806574f62672748c80c4c171 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Nov 2020 16:09:25 +0100 Subject: [PATCH 3/5] raise ValueError when x and y don't have numerical dtype for Image trace --- .../python/plotly/plotly/express/_imshow.py | 10 +++++++++ .../tests/test_core/test_px/test_imshow.py | 22 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index da43934212e..478e362f5ef 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -403,11 +403,21 @@ def imshow( if np.issubdtype(x.dtype, np.number): x0 = x[0] dx = x[1] - x[0] + else: + raise ValueError( + "Only numerical values are accepted for the `x` parameter " + "when an Image trace is used." + ) if y is not None: y = np.asanyarray(y) if np.issubdtype(y.dtype, np.number): y0 = y[0] dy = y[1] - y[0] + else: + raise ValueError( + "Only numerical values are accepted for the `y` parameter " + "when an Image trace is used." + ) if binary_string: if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 02429dffb1e..313267aacbd 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -5,6 +5,7 @@ from PIL import Image from io import BytesIO import base64 +import datetime from plotly.express.imshow_utils import rescale_intensity img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8) @@ -204,6 +205,27 @@ def test_imshow_labels_and_ranges(): with pytest.raises(ValueError): fig = px.imshow([[1, 2], [3, 4], [5, 6]], x=["a"]) + img = np.ones((2, 2), dtype=np.uint8) + fig = px.imshow(img, x=["a", "b"]) + assert fig.data[0].x == ("a", "b") + + with pytest.raises(ValueError): + img = np.ones((2, 2, 3), dtype=np.uint8) + fig = px.imshow(img, x=["a", "b"]) + + img = np.ones((2, 2), dtype=np.uint8) + base = datetime.datetime(2000, 1, 1) + fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)]) + assert fig.data[0].x == ( + datetime.datetime(2000, 1, 1, 0, 0), + datetime.datetime(2000, 1, 1, 1, 0), + ) + + with pytest.raises(ValueError): + img = np.ones((2, 2, 3), dtype=np.uint8) + base = datetime.datetime(2000, 1, 1) + fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)]) + def test_imshow_ranges_image_trace(): fig = px.imshow(img_rgb, x=[1, 11, 21]) From c7c66dc3c09139a097a72fd9b288bca4c2a4daf3 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Nov 2020 16:51:16 +0100 Subject: [PATCH 4/5] better error message --- .../python/plotly/plotly/express/_imshow.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 1d3132ef2d4..bf036959def 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -204,8 +204,10 @@ def imshow( args = locals() apply_default_cascade(args) labels = labels.copy() + img_is_xarray = False # ----- Define x and y, set labels if img is an xarray ------------------- if xarray_imported and isinstance(img, xarray.DataArray): + img_is_xarray = True y_label, x_label = img.dims[0], img.dims[1] # np.datetime64 is not handled correctly by go.Heatmap for ax in [x_label, y_label]: @@ -325,26 +327,33 @@ def imshow( _vectorize_zvalue(zmax, mode="max"), ) x0, y0, dx, dy = (None,) * 4 + error_msg_xarray = ( + "Non-numerical coordinates were passed with xarray `img`, but " + "the Image trace cannot handle it. Please use `binary_string=False` " + "for 2D data or pass instead the numpy array `img.values` to `px.imshow`." + ) if x is not None: x = np.asanyarray(x) if np.issubdtype(x.dtype, np.number): x0 = x[0] dx = x[1] - x[0] else: - raise ValueError( + error_msg = error_msg_xarray if img_is_xarray else ( "Only numerical values are accepted for the `x` parameter " "when an Image trace is used." - ) + ) + raise ValueError(error_msg) if y is not None: y = np.asanyarray(y) if np.issubdtype(y.dtype, np.number): y0 = y[0] dy = y[1] - y[0] else: - raise ValueError( + error_msg = error_msg_xarray if img_is_xarray else ( "Only numerical values are accepted for the `y` parameter " "when an Image trace is used." - ) + ) + raise ValueError(error_msg) if binary_string: if zmin is None and zmax is None: # no rescaling, faster img_rescaled = img From fc0c86c0b04a8b02c00cf9904873d629a6fef0cd Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Nov 2020 16:56:26 +0100 Subject: [PATCH 5/5] black --- .../python/plotly/plotly/express/_imshow.py | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index bf036959def..88713e54368 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -328,20 +328,24 @@ def imshow( ) x0, y0, dx, dy = (None,) * 4 error_msg_xarray = ( - "Non-numerical coordinates were passed with xarray `img`, but " - "the Image trace cannot handle it. Please use `binary_string=False` " - "for 2D data or pass instead the numpy array `img.values` to `px.imshow`." - ) + "Non-numerical coordinates were passed with xarray `img`, but " + "the Image trace cannot handle it. Please use `binary_string=False` " + "for 2D data or pass instead the numpy array `img.values` to `px.imshow`." + ) if x is not None: x = np.asanyarray(x) if np.issubdtype(x.dtype, np.number): x0 = x[0] dx = x[1] - x[0] else: - error_msg = error_msg_xarray if img_is_xarray else ( - "Only numerical values are accepted for the `x` parameter " - "when an Image trace is used." - ) + error_msg = ( + error_msg_xarray + if img_is_xarray + else ( + "Only numerical values are accepted for the `x` parameter " + "when an Image trace is used." + ) + ) raise ValueError(error_msg) if y is not None: y = np.asanyarray(y) @@ -349,10 +353,14 @@ def imshow( y0 = y[0] dy = y[1] - y[0] else: - error_msg = error_msg_xarray if img_is_xarray else ( - "Only numerical values are accepted for the `y` parameter " - "when an Image trace is used." - ) + error_msg = ( + error_msg_xarray + if img_is_xarray + else ( + "Only numerical values are accepted for the `y` parameter " + "when an Image trace is used." + ) + ) raise ValueError(error_msg) if binary_string: if zmin is None and zmax is None: # no rescaling, faster