@@ -286,7 +286,7 @@ def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=(8.0, 6.0)
286
286
axes_num : number
287
287
expected number of axes. Unnecessary axes should be set to invisible.
288
288
layout : tuple
289
- expected layout
289
+ expected layout, (expected number of rows , columns)
290
290
figsize : tuple
291
291
expected figsize. default is matplotlib default
292
292
"""
@@ -299,17 +299,22 @@ def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=(8.0, 6.0)
299
299
self .assertTrue (len (ax .get_children ()) > 0 )
300
300
301
301
if layout is not None :
302
- if isinstance (axes , list ):
303
- self .assertEqual ((len (axes ), ), layout )
304
- elif isinstance (axes , np .ndarray ):
305
- self .assertEqual (axes .shape , layout )
306
- else :
307
- # in case of AxesSubplot
308
- self .assertEqual ((1 , ), layout )
302
+ result = self ._get_axes_layout (plotting ._flatten (axes ))
303
+ self .assertEqual (result , layout )
309
304
310
305
self .assert_numpy_array_equal (np .round (visible_axes [0 ].figure .get_size_inches ()),
311
306
np .array (figsize ))
312
307
308
+ def _get_axes_layout (self , axes ):
309
+ x_set = set ()
310
+ y_set = set ()
311
+ for ax in axes :
312
+ # check axes coordinates to estimate layout
313
+ points = ax .get_position ().get_points ()
314
+ x_set .add (points [0 ][0 ])
315
+ y_set .add (points [0 ][1 ])
316
+ return (len (y_set ), len (x_set ))
317
+
313
318
def _flatten_visible (self , axes ):
314
319
"""
315
320
Flatten axes, and filter only visible
@@ -401,14 +406,14 @@ def test_plot(self):
401
406
402
407
# GH 6951
403
408
ax = _check_plot_works (self .ts .plot , subplots = True )
404
- self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , ))
409
+ self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , 1 ))
405
410
406
411
@slow
407
412
def test_plot_figsize_and_title (self ):
408
413
# figsize and title
409
414
ax = self .series .plot (title = 'Test' , figsize = (16 , 8 ))
410
415
self ._check_text_labels (ax .title , 'Test' )
411
- self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , ), figsize = (16 , 8 ))
416
+ self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , 1 ), figsize = (16 , 8 ))
412
417
413
418
def test_ts_area_lim (self ):
414
419
ax = self .ts .plot (kind = 'area' , stacked = False )
@@ -556,10 +561,10 @@ def test_hist_layout_with_by(self):
556
561
df = self .hist_df
557
562
558
563
axes = _check_plot_works (df .height .hist , by = df .gender , layout = (2 , 1 ))
559
- self ._check_axes_shape (axes , axes_num = 2 , layout = (2 , ), figsize = (10 , 5 ))
564
+ self ._check_axes_shape (axes , axes_num = 2 , layout = (2 , 1 ), figsize = (10 , 5 ))
560
565
561
566
axes = _check_plot_works (df .height .hist , by = df .category , layout = (4 , 1 ))
562
- self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , ), figsize = (10 , 5 ))
567
+ self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , 1 ), figsize = (10 , 5 ))
563
568
564
569
axes = _check_plot_works (df .height .hist , by = df .classroom , layout = (2 , 2 ))
565
570
self ._check_axes_shape (axes , axes_num = 3 , layout = (2 , 2 ), figsize = (10 , 5 ))
@@ -757,9 +762,9 @@ def test_plot(self):
757
762
df = self .tdf
758
763
_check_plot_works (df .plot , grid = False )
759
764
axes = _check_plot_works (df .plot , subplots = True )
760
- self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , ))
765
+ self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , 1 ))
761
766
_check_plot_works (df .plot , subplots = True , use_index = False )
762
- self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , ))
767
+ self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , 1 ))
763
768
764
769
df = DataFrame ({'x' : [1 , 2 ], 'y' : [3 , 4 ]})
765
770
with tm .assertRaises (TypeError ):
@@ -774,7 +779,7 @@ def test_plot(self):
774
779
_check_plot_works (df .plot , ylim = (- 100 , 100 ), xlim = (- 100 , 100 ))
775
780
776
781
axes = _check_plot_works (df .plot , subplots = True , title = 'blah' )
777
- self ._check_axes_shape (axes , axes_num = 3 , layout = (3 , ))
782
+ self ._check_axes_shape (axes , axes_num = 3 , layout = (3 , 1 ))
778
783
779
784
_check_plot_works (df .plot , title = 'blah' )
780
785
@@ -804,7 +809,7 @@ def test_plot(self):
804
809
# Test with single column
805
810
df = DataFrame ({'x' : np .random .rand (10 )})
806
811
axes = _check_plot_works (df .plot , kind = 'bar' , subplots = True )
807
- self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , ))
812
+ self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , 1 ))
808
813
809
814
def test_nonnumeric_exclude (self ):
810
815
df = DataFrame ({'A' : ["x" , "y" , "z" ], 'B' : [1 , 2 , 3 ]})
@@ -846,7 +851,7 @@ def test_plot_xy(self):
846
851
# figsize and title
847
852
ax = df .plot (x = 1 , y = 2 , title = 'Test' , figsize = (16 , 8 ))
848
853
self ._check_text_labels (ax .title , 'Test' )
849
- self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , ), figsize = (16. , 8. ))
854
+ self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , 1 ), figsize = (16. , 8. ))
850
855
851
856
# columns.inferred_type == 'mixed'
852
857
# TODO add MultiIndex test
@@ -913,7 +918,7 @@ def test_subplots(self):
913
918
914
919
for kind in ['bar' , 'barh' , 'line' , 'area' ]:
915
920
axes = df .plot (kind = kind , subplots = True , sharex = True , legend = True )
916
- self ._check_axes_shape (axes , axes_num = 3 , layout = (3 , ))
921
+ self ._check_axes_shape (axes , axes_num = 3 , layout = (3 , 1 ))
917
922
918
923
for ax , column in zip (axes , df .columns ):
919
924
self ._check_legend_labels (ax , labels = [com .pprint_thing (column )])
@@ -1081,7 +1086,7 @@ def test_bar_linewidth(self):
1081
1086
1082
1087
# subplots
1083
1088
axes = df .plot (kind = 'bar' , linewidth = 2 , subplots = True )
1084
- self ._check_axes_shape (axes , axes_num = 5 , layout = (5 , ))
1089
+ self ._check_axes_shape (axes , axes_num = 5 , layout = (5 , 1 ))
1085
1090
for ax in axes :
1086
1091
for r in ax .patches :
1087
1092
self .assertEqual (r .get_linewidth (), 2 )
@@ -1179,7 +1184,7 @@ def test_plot_scatter(self):
1179
1184
1180
1185
# GH 6951
1181
1186
axes = df .plot (x = 'x' , y = 'y' , kind = 'scatter' , subplots = True )
1182
- self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , ))
1187
+ self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , 1 ))
1183
1188
1184
1189
@slow
1185
1190
def test_plot_bar (self ):
@@ -1486,7 +1491,7 @@ def test_kde(self):
1486
1491
self ._check_legend_labels (ax , labels = expected )
1487
1492
1488
1493
axes = _check_plot_works (df .plot , kind = 'kde' , subplots = True )
1489
- self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , ))
1494
+ self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , 1 ))
1490
1495
1491
1496
axes = df .plot (kind = 'kde' , logy = True , subplots = True )
1492
1497
self ._check_ax_scales (axes , yaxis = 'log' )
@@ -1949,8 +1954,7 @@ def test_hexbin_basic(self):
1949
1954
# hexbin should have 2 axes in the figure, 1 for plotting and another is colorbar
1950
1955
self .assertEqual (len (axes [0 ].figure .axes ), 2 )
1951
1956
# return value is single axes
1952
- self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , ))
1953
-
1957
+ self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , 1 ))
1954
1958
1955
1959
@slow
1956
1960
def test_hexbin_with_c (self ):
@@ -2193,31 +2197,31 @@ class TestDataFrameGroupByPlots(TestPlotBase):
2193
2197
def test_boxplot (self ):
2194
2198
grouped = self .hist_df .groupby (by = 'gender' )
2195
2199
box = _check_plot_works (grouped .boxplot , return_type = 'dict' )
2196
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 )
2200
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 , layout = ( 1 , 2 ) )
2197
2201
2198
2202
box = _check_plot_works (grouped .boxplot , subplots = False ,
2199
2203
return_type = 'dict' )
2200
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 )
2204
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 , layout = ( 1 , 2 ) )
2201
2205
2202
2206
tuples = lzip (string .ascii_letters [:10 ], range (10 ))
2203
2207
df = DataFrame (np .random .rand (10 , 3 ),
2204
2208
index = MultiIndex .from_tuples (tuples ))
2205
2209
2206
2210
grouped = df .groupby (level = 1 )
2207
2211
box = _check_plot_works (grouped .boxplot , return_type = 'dict' )
2208
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 10 )
2212
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 10 , layout = ( 4 , 3 ) )
2209
2213
2210
2214
box = _check_plot_works (grouped .boxplot , subplots = False ,
2211
2215
return_type = 'dict' )
2212
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 10 )
2216
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 10 , layout = ( 4 , 3 ) )
2213
2217
2214
2218
grouped = df .unstack (level = 1 ).groupby (level = 0 , axis = 1 )
2215
2219
box = _check_plot_works (grouped .boxplot , return_type = 'dict' )
2216
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2220
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 2 , 2 ) )
2217
2221
2218
2222
box = _check_plot_works (grouped .boxplot , subplots = False ,
2219
2223
return_type = 'dict' )
2220
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2224
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 2 , 2 ) )
2221
2225
2222
2226
def test_series_plot_color_kwargs (self ):
2223
2227
# GH1890
@@ -2327,35 +2331,35 @@ def test_grouped_box_layout(self):
2327
2331
2328
2332
box = _check_plot_works (df .groupby ('gender' ).boxplot , column = 'height' ,
2329
2333
return_type = 'dict' )
2330
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 )
2334
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 , layout = ( 1 , 2 ) )
2331
2335
2332
2336
box = _check_plot_works (df .groupby ('category' ).boxplot , column = 'height' ,
2333
2337
return_type = 'dict' )
2334
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 4 )
2338
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 4 , layout = ( 2 , 2 ) )
2335
2339
2336
2340
# GH 6769
2337
2341
box = _check_plot_works (df .groupby ('classroom' ).boxplot ,
2338
2342
column = 'height' , return_type = 'dict' )
2339
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2343
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 2 , 2 ) )
2340
2344
2341
2345
box = df .boxplot (column = ['height' , 'weight' , 'category' ], by = 'gender' )
2342
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2346
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 2 , 2 ) )
2343
2347
2344
2348
box = df .groupby ('classroom' ).boxplot (
2345
2349
column = ['height' , 'weight' , 'category' ], return_type = 'dict' )
2346
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2350
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 2 , 2 ) )
2347
2351
2348
2352
box = _check_plot_works (df .groupby ('category' ).boxplot , column = 'height' ,
2349
2353
layout = (3 , 2 ), return_type = 'dict' )
2350
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 4 )
2354
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 4 , layout = ( 3 , 2 ) )
2351
2355
2352
2356
box = df .boxplot (column = ['height' , 'weight' , 'category' ], by = 'gender' , layout = (4 , 1 ))
2353
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2357
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 4 , 1 ) )
2354
2358
2355
2359
box = df .groupby ('classroom' ).boxplot (
2356
2360
column = ['height' , 'weight' , 'category' ], layout = (1 , 4 ),
2357
2361
return_type = 'dict' )
2358
- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2362
+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 1 , 4 ) )
2359
2363
2360
2364
@slow
2361
2365
def test_grouped_hist_layout (self ):
@@ -2367,10 +2371,10 @@ def test_grouped_hist_layout(self):
2367
2371
layout = (1 , 3 ))
2368
2372
2369
2373
axes = _check_plot_works (df .hist , column = 'height' , by = df .gender , layout = (2 , 1 ))
2370
- self ._check_axes_shape (axes , axes_num = 2 , layout = (2 , ), figsize = (10 , 5 ))
2374
+ self ._check_axes_shape (axes , axes_num = 2 , layout = (2 , 1 ), figsize = (10 , 5 ))
2371
2375
2372
2376
axes = _check_plot_works (df .hist , column = 'height' , by = df .category , layout = (4 , 1 ))
2373
- self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , ), figsize = (10 , 5 ))
2377
+ self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , 1 ), figsize = (10 , 5 ))
2374
2378
2375
2379
axes = _check_plot_works (df .hist , column = 'height' , by = df .category ,
2376
2380
layout = (4 , 2 ), figsize = (12 , 8 ))
0 commit comments