@@ -1022,7 +1022,10 @@ def _post_plot_logic(self):
1022
1022
def _adorn_subplots (self ):
1023
1023
to_adorn = self .axes
1024
1024
1025
- # todo: sharex, sharey handling?
1025
+ if len (self .axes ) > 0 :
1026
+ all_axes = self ._get_axes ()
1027
+ nrows , ncols = self ._get_axes_layout ()
1028
+ _handle_shared_axes (all_axes , len (all_axes ), len (all_axes ), nrows , ncols , self .sharex , self .sharey )
1026
1029
1027
1030
for ax in to_adorn :
1028
1031
if self .yticks is not None :
@@ -1375,6 +1378,19 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
1375
1378
errors [kw ] = err
1376
1379
return errors
1377
1380
1381
+ def _get_axes (self ):
1382
+ return self .axes [0 ].get_figure ().get_axes ()
1383
+
1384
+ def _get_axes_layout (self ):
1385
+ axes = self ._get_axes ()
1386
+ x_set = set ()
1387
+ y_set = set ()
1388
+ for ax in axes :
1389
+ # check axes coordinates to estimate layout
1390
+ points = ax .get_position ().get_points ()
1391
+ x_set .add (points [0 ][0 ])
1392
+ y_set .add (points [0 ][1 ])
1393
+ return (len (y_set ), len (x_set ))
1378
1394
1379
1395
class ScatterPlot (MPLPlot ):
1380
1396
_layout_type = 'single'
@@ -3231,6 +3247,28 @@ def _subplots(naxes=None, sharex=False, sharey=False, squeeze=True,
3231
3247
ax = fig .add_subplot (nrows , ncols , i + 1 , ** kwds )
3232
3248
axarr [i ] = ax
3233
3249
3250
+ _handle_shared_axes (axarr , nplots , naxes , nrows , ncols , sharex , sharey )
3251
+
3252
+ if naxes != nplots :
3253
+ for ax in axarr [naxes :]:
3254
+ ax .set_visible (False )
3255
+
3256
+ if squeeze :
3257
+ # Reshape the array to have the final desired dimension (nrow,ncol),
3258
+ # though discarding unneeded dimensions that equal 1. If we only have
3259
+ # one subplot, just return it instead of a 1-element array.
3260
+ if nplots == 1 :
3261
+ axes = axarr [0 ]
3262
+ else :
3263
+ axes = axarr .reshape (nrows , ncols ).squeeze ()
3264
+ else :
3265
+ # returned axis array will be always 2-d, even if nrows=ncols=1
3266
+ axes = axarr .reshape (nrows , ncols )
3267
+
3268
+ return fig , axes
3269
+
3270
+
3271
+ def _handle_shared_axes (axarr , nplots , naxes , nrows , ncols , sharex , sharey ):
3234
3272
if nplots > 1 :
3235
3273
3236
3274
if sharex and nrows > 1 :
@@ -3241,8 +3279,11 @@ def _subplots(naxes=None, sharex=False, sharey=False, squeeze=True,
3241
3279
# set_visible will not be effective if
3242
3280
# minor axis has NullLocator and NullFormattor (default)
3243
3281
import matplotlib .ticker as ticker
3244
- ax .xaxis .set_minor_locator (ticker .AutoLocator ())
3245
- ax .xaxis .set_minor_formatter (ticker .FormatStrFormatter ('' ))
3282
+
3283
+ if isinstance (ax .xaxis .get_minor_locator (), ticker .NullLocator ):
3284
+ ax .xaxis .set_minor_locator (ticker .AutoLocator ())
3285
+ if isinstance (ax .xaxis .get_minor_formatter (), ticker .NullFormatter ):
3286
+ ax .xaxis .set_minor_formatter (ticker .FormatStrFormatter ('' ))
3246
3287
for label in ax .get_xticklabels (minor = True ):
3247
3288
label .set_visible (False )
3248
3289
except Exception : # pragma no cover
@@ -3255,32 +3296,16 @@ def _subplots(naxes=None, sharex=False, sharey=False, squeeze=True,
3255
3296
label .set_visible (False )
3256
3297
try :
3257
3298
import matplotlib .ticker as ticker
3258
- ax .yaxis .set_minor_locator (ticker .AutoLocator ())
3259
- ax .yaxis .set_minor_formatter (ticker .FormatStrFormatter ('' ))
3299
+ if isinstance (ax .yaxis .get_minor_locator (), ticker .NullLocator ):
3300
+ ax .yaxis .set_minor_locator (ticker .AutoLocator ())
3301
+ if isinstance (ax .yaxis .get_minor_formatter (), ticker .NullFormatter ):
3302
+ ax .yaxis .set_minor_formatter (ticker .FormatStrFormatter ('' ))
3260
3303
for label in ax .get_yticklabels (minor = True ):
3261
3304
label .set_visible (False )
3262
3305
except Exception : # pragma no cover
3263
3306
pass
3264
3307
ax .yaxis .get_label ().set_visible (False )
3265
3308
3266
- if naxes != nplots :
3267
- for ax in axarr [naxes :]:
3268
- ax .set_visible (False )
3269
-
3270
- if squeeze :
3271
- # Reshape the array to have the final desired dimension (nrow,ncol),
3272
- # though discarding unneeded dimensions that equal 1. If we only have
3273
- # one subplot, just return it instead of a 1-element array.
3274
- if nplots == 1 :
3275
- axes = axarr [0 ]
3276
- else :
3277
- axes = axarr .reshape (nrows , ncols ).squeeze ()
3278
- else :
3279
- # returned axis array will be always 2-d, even if nrows=ncols=1
3280
- axes = axarr .reshape (nrows , ncols )
3281
-
3282
- return fig , axes
3283
-
3284
3309
3285
3310
def _flatten (axes ):
3286
3311
if not com .is_list_like (axes ):
0 commit comments