@@ -69,6 +69,8 @@ def imshow(
69
69
zmax = None ,
70
70
origin = None ,
71
71
labels = {},
72
+ x = None ,
73
+ y = None ,
72
74
color_continuous_scale = None ,
73
75
color_continuous_midpoint = None ,
74
76
range_color = None ,
@@ -167,33 +169,33 @@ def imshow(
167
169
"""
168
170
args = locals ()
169
171
apply_default_cascade (args )
170
- img_is_xarray = False
171
- x_label = "x"
172
- y_label = "y"
173
- z_label = ""
174
- if xarray_imported :
175
- if isinstance (img , xarray .DataArray ):
176
- y_label , x_label = img .dims [0 ], img .dims [1 ]
177
- # np.datetime64 is not handled correctly by go.Heatmap
178
- for ax in [x_label , y_label ]:
179
- if np .issubdtype (img .coords [ax ].dtype , np .datetime64 ):
180
- img .coords [ax ] = img .coords [ax ].astype (str )
172
+ labels = labels .copy ()
173
+ if xarray_imported and isinstance (img , xarray .DataArray ):
174
+ y_label , x_label = img .dims [0 ], img .dims [1 ]
175
+ # np.datetime64 is not handled correctly by go.Heatmap
176
+ for ax in [x_label , y_label ]:
177
+ if np .issubdtype (img .coords [ax ].dtype , np .datetime64 ):
178
+ img .coords [ax ] = img .coords [ax ].astype (str )
179
+ if x is None :
181
180
x = img .coords [x_label ]
181
+ if y is None :
182
182
y = img .coords [y_label ]
183
- img_is_xarray = True
184
- if aspect is None :
185
- aspect = "auto"
186
- z_label = xarray .plot .utils .label_from_attrs (img ).replace ("\n " , "<br>" )
187
-
188
- if labels is not None :
189
- if "x" in labels :
190
- x_label = labels ["x" ]
191
- if "y" in labels :
192
- y_label = labels ["y" ]
193
- if "color" in labels :
194
- z_label = labels ["color" ]
195
-
196
- if not img_is_xarray :
183
+ if aspect is None :
184
+ aspect = "auto"
185
+ if labels .get ("x" , None ) is None :
186
+ labels ["x" ] = x_label
187
+ if labels .get ("y" , None ) is None :
188
+ labels ["y" ] = y_label
189
+ if labels .get ("color" , None ) is None :
190
+ labels ["color" ] = xarray .plot .utils .label_from_attrs (img )
191
+ labels ["color" ] = labels ["color" ].replace ("\n " , "<br>" )
192
+ else :
193
+ if labels .get ("x" , None ) is None :
194
+ labels ["x" ] = ""
195
+ if labels .get ("y" , None ) is None :
196
+ labels ["y" ] = ""
197
+ if labels .get ("color" , None ) is None :
198
+ labels ["color" ] = ""
197
199
if aspect is None :
198
200
aspect = "equal"
199
201
@@ -205,7 +207,7 @@ def imshow(
205
207
206
208
# For 2d data, use Heatmap trace
207
209
if img .ndim == 2 :
208
- trace = go .Heatmap (z = img , coloraxis = "coloraxis1" )
210
+ trace = go .Heatmap (x = x , y = y , z = img , coloraxis = "coloraxis1" )
209
211
autorange = True if origin == "lower" else "reversed"
210
212
layout = dict (yaxis = dict (autorange = autorange ))
211
213
if aspect == "equal" :
@@ -224,8 +226,9 @@ def imshow(
224
226
cmid = color_continuous_midpoint ,
225
227
cmin = range_color [0 ],
226
228
cmax = range_color [1 ],
227
- colorbar = dict (title = z_label ),
228
229
)
230
+ if labels ["color" ]:
231
+ layout ["coloraxis1" ]["colorbar" ] = dict (title = labels ["color" ])
229
232
230
233
# For 2D+RGB data, use Image trace
231
234
elif img .ndim == 3 and img .shape [- 1 ] in [3 , 4 ]:
@@ -250,19 +253,13 @@ def imshow(
250
253
layout_patch ["margin" ] = {"t" : 60 }
251
254
fig = go .Figure (data = trace , layout = layout )
252
255
fig .update_layout (layout_patch )
253
- if img .ndim <= 2 :
254
- hovertemplate = (
255
- x_label
256
- + ": %{x} <br>"
257
- + y_label
258
- + ": %{y} <br>"
259
- + z_label
260
- + " : %{z}<extra></extra>"
261
- )
262
- fig .update_traces (hovertemplate = hovertemplate )
263
- if img_is_xarray :
264
- fig .update_traces (x = x , y = y )
265
- fig .update_xaxes (title_text = x_label )
266
- fig .update_yaxes (title_text = y_label )
256
+ fig .update_traces (
257
+ hovertemplate = "%s: %%{x}<br>%s: %%{y}<br>%s: %%{z}<extra></extra>"
258
+ % (labels ["x" ] or "x" , labels ["y" ] or "y" , labels ["color" ] or "color" ,)
259
+ )
260
+ if labels ["x" ]:
261
+ fig .update_xaxes (title_text = labels ["x" ])
262
+ if labels ["y" ]:
263
+ fig .update_yaxes (title_text = labels ["y" ])
267
264
fig .update_layout (template = args ["template" ], overwrite = True )
268
265
return fig
0 commit comments