Skip to content

Use x and y parameters for Image trace in imshow (for RGB or binary_string=True) #2761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 37 additions & 11 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,21 +279,15 @@ def imshow(
labels = labels.copy()
# ----- Define x and y, set labels if img is an xarray -------------------
if xarray_imported and isinstance(img, xarray.DataArray):
if binary_string:
raise ValueError(
"It is not possible to use binary image strings for xarrays."
"Please pass your data as a numpy array instead using"
"`img.values`"
)
y_label, x_label = img.dims[0], img.dims[1]
# np.datetime64 is not handled correctly by go.Heatmap
for ax in [x_label, y_label]:
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
img.coords[ax] = img.coords[ax].astype(str)
if x is None:
x = img.coords[x_label]
x = img.coords[x_label].values
if y is None:
y = img.coords[y_label]
y = img.coords[y_label].values
if aspect is None:
aspect = "auto"
if labels.get("x", None) is None:
Expand Down Expand Up @@ -403,6 +397,27 @@ def imshow(
_vectorize_zvalue(zmin, mode="min"),
_vectorize_zvalue(zmax, mode="max"),
)
x0, y0, dx, dy = (None,) * 4
if x is not None:
x = np.asanyarray(x)
if np.issubdtype(x.dtype, np.number):
x0 = x[0]
dx = x[1] - x[0]
else:
raise ValueError(
"Only numerical values are accepted for the `x` parameter "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is correct but I think a bit opaque for users? in the xarray case they've maybe not specified x (as it's implicit) and they don't know they're using an Image trace (all they know is they set binary_string=True

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I can think of a better error message, but do we agree that we should also error when x is passed explicitly and does not have a numerical dtype? (otherwise it would just be ignored, which I agree is not a good thing).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the error message, is it better now?

"when an Image trace is used."
)
if y is not None:
y = np.asanyarray(y)
if np.issubdtype(y.dtype, np.number):
y0 = y[0]
dy = y[1] - y[0]
else:
raise ValueError(
"Only numerical values are accepted for the `y` parameter "
"when an Image trace is used."
)
if binary_string:
if zmin is None and zmax is None: # no rescaling, faster
img_rescaled = img
Expand All @@ -428,13 +443,24 @@ def imshow(
compression=binary_compression_level,
ext=binary_format,
)
trace = go.Image(source=img_str)
trace = go.Image(source=img_str, x0=x0, y0=y0, dx=dx, dy=dy)
else:
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)
trace = go.Image(
z=img,
zmin=zmin,
zmax=zmax,
colormodel=colormodel,
x0=x0,
y0=y0,
dx=dx,
dy=dy,
)
layout = {}
if origin == "lower":
if origin == "lower" or (dy is not None and dy < 0):
layout["yaxis"] = dict(autorange=True)
if dx is not None and dx < 0:
layout["xaxis"] = dict(autorange="reversed")
else:
raise ValueError(
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from PIL import Image
from io import BytesIO
import base64
import datetime
from plotly.express.imshow_utils import rescale_intensity

img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8)
Expand Down Expand Up @@ -204,6 +205,37 @@ def test_imshow_labels_and_ranges():
with pytest.raises(ValueError):
fig = px.imshow([[1, 2], [3, 4], [5, 6]], x=["a"])

img = np.ones((2, 2), dtype=np.uint8)
fig = px.imshow(img, x=["a", "b"])
assert fig.data[0].x == ("a", "b")

with pytest.raises(ValueError):
img = np.ones((2, 2, 3), dtype=np.uint8)
fig = px.imshow(img, x=["a", "b"])

img = np.ones((2, 2), dtype=np.uint8)
base = datetime.datetime(2000, 1, 1)
fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)])
assert fig.data[0].x == (
datetime.datetime(2000, 1, 1, 0, 0),
datetime.datetime(2000, 1, 1, 1, 0),
)

with pytest.raises(ValueError):
img = np.ones((2, 2, 3), dtype=np.uint8)
base = datetime.datetime(2000, 1, 1)
fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)])


def test_imshow_ranges_image_trace():
fig = px.imshow(img_rgb, x=[1, 11, 21])
assert fig.data[0].dx == 10
assert fig.data[0].x0 == 1
fig = px.imshow(img_rgb, x=[21, 11, 1])
assert fig.data[0].dx == -10
assert fig.data[0].x0 == 21
assert fig.layout.xaxis.autorange == "reversed"


def test_imshow_dataframe():
df = px.data.medals_wide(indexed=False)
Expand Down