@@ -490,29 +490,34 @@ def test_subplots(self):
490
490
df = DataFrame (np .random .rand (10 , 3 ),
491
491
index = list (string .ascii_letters [:10 ]))
492
492
493
- axes = df .plot (subplots = True , sharex = True , legend = True )
493
+ for kind in ['bar' , 'barh' , 'line' ]:
494
+ axes = df .plot (kind = kind , subplots = True , sharex = True , legend = True )
494
495
495
- for ax in axes :
496
- self .assertIsNotNone (ax .get_legend ())
497
-
498
- axes = df .plot (subplots = True , sharex = True )
499
- for ax in axes [:- 2 ]:
500
- [self .assert_ (not label .get_visible ())
501
- for label in ax .get_xticklabels ()]
502
- [self .assert_ (label .get_visible ())
503
- for label in ax .get_yticklabels ()]
496
+ for ax , column in zip (axes , df .columns ):
497
+ self ._check_legend_labels (ax , [column ])
504
498
505
- [self .assert_ (label .get_visible ())
506
- for label in axes [- 1 ].get_xticklabels ()]
507
- [self .assert_ (label .get_visible ())
508
- for label in axes [- 1 ].get_yticklabels ()]
499
+ axes = df .plot (kind = kind , subplots = True , sharex = True )
500
+ for ax in axes [:- 2 ]:
501
+ [self .assert_ (not label .get_visible ())
502
+ for label in ax .get_xticklabels ()]
503
+ [self .assert_ (label .get_visible ())
504
+ for label in ax .get_yticklabels ()]
509
505
510
- axes = df .plot (subplots = True , sharex = False )
511
- for ax in axes :
512
506
[self .assert_ (label .get_visible ())
513
- for label in ax .get_xticklabels ()]
507
+ for label in axes [ - 1 ] .get_xticklabels ()]
514
508
[self .assert_ (label .get_visible ())
515
- for label in ax .get_yticklabels ()]
509
+ for label in axes [- 1 ].get_yticklabels ()]
510
+
511
+ axes = df .plot (kind = kind , subplots = True , sharex = False )
512
+ for ax in axes :
513
+ [self .assert_ (label .get_visible ())
514
+ for label in ax .get_xticklabels ()]
515
+ [self .assert_ (label .get_visible ())
516
+ for label in ax .get_yticklabels ()]
517
+
518
+ axes = df .plot (kind = kind , subplots = True , legend = False )
519
+ for ax in axes :
520
+ self .assertTrue (ax .get_legend () is None )
516
521
517
522
@slow
518
523
def test_bar_colors (self ):
@@ -873,7 +878,7 @@ def test_kde(self):
873
878
_check_plot_works (df .plot , kind = 'kde' )
874
879
_check_plot_works (df .plot , kind = 'kde' , subplots = True )
875
880
ax = df .plot (kind = 'kde' )
876
- self .assertIsNotNone (ax . get_legend () )
881
+ self ._check_legend_labels (ax , df . columns )
877
882
axes = df .plot (kind = 'kde' , logy = True , subplots = True )
878
883
for ax in axes :
879
884
self .assertEqual (ax .get_yscale (), 'log' )
@@ -1046,6 +1051,64 @@ def test_plot_int_columns(self):
1046
1051
df = DataFrame (randn (100 , 4 )).cumsum ()
1047
1052
_check_plot_works (df .plot , legend = True )
1048
1053
1054
+ def _check_legend_labels (self , ax , labels ):
1055
+ import pandas .core .common as com
1056
+ labels = [com .pprint_thing (l ) for l in labels ]
1057
+ self .assertTrue (ax .get_legend () is not None )
1058
+ legend_labels = [t .get_text () for t in ax .get_legend ().get_texts ()]
1059
+ self .assertEqual (labels , legend_labels )
1060
+
1061
+ @slow
1062
+ def test_df_legend_labels (self ):
1063
+ kinds = 'line' , 'bar' , 'barh' , 'kde' , 'density'
1064
+ df = DataFrame (randn (3 , 3 ), columns = ['a' , 'b' , 'c' ])
1065
+ df2 = DataFrame (randn (3 , 3 ), columns = ['d' , 'e' , 'f' ])
1066
+ df3 = DataFrame (randn (3 , 3 ), columns = ['g' , 'h' , 'i' ])
1067
+ df4 = DataFrame (randn (3 , 3 ), columns = ['j' , 'k' , 'l' ])
1068
+
1069
+ for kind in kinds :
1070
+ ax = df .plot (kind = kind , legend = True )
1071
+ self ._check_legend_labels (ax , df .columns )
1072
+
1073
+ ax = df2 .plot (kind = kind , legend = False , ax = ax )
1074
+ self ._check_legend_labels (ax , df .columns )
1075
+
1076
+ ax = df3 .plot (kind = kind , legend = True , ax = ax )
1077
+ self ._check_legend_labels (ax , df .columns + df3 .columns )
1078
+
1079
+ ax = df4 .plot (kind = kind , legend = 'reverse' , ax = ax )
1080
+ expected = list (df .columns + df3 .columns ) + list (reversed (df4 .columns ))
1081
+ self ._check_legend_labels (ax , expected )
1082
+
1083
+ # Secondary Y
1084
+ ax = df .plot (legend = True , secondary_y = 'b' )
1085
+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' ])
1086
+ ax = df2 .plot (legend = False , ax = ax )
1087
+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' ])
1088
+ ax = df3 .plot (kind = 'bar' , legend = True , secondary_y = 'h' , ax = ax )
1089
+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' , 'g' , 'h (right)' , 'i' ])
1090
+
1091
+ # Time Series
1092
+ ind = date_range ('1/1/2014' , periods = 3 )
1093
+ df = DataFrame (randn (3 , 3 ), columns = ['a' , 'b' , 'c' ], index = ind )
1094
+ df2 = DataFrame (randn (3 , 3 ), columns = ['d' , 'e' , 'f' ], index = ind )
1095
+ df3 = DataFrame (randn (3 , 3 ), columns = ['g' , 'h' , 'i' ], index = ind )
1096
+ ax = df .plot (legend = True , secondary_y = 'b' )
1097
+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' ])
1098
+ ax = df2 .plot (legend = False , ax = ax )
1099
+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' ])
1100
+ ax = df3 .plot (legend = True , ax = ax )
1101
+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' , 'g' , 'h' , 'i' ])
1102
+
1103
+ # scatter
1104
+ ax = df .plot (kind = 'scatter' , x = 'a' , y = 'b' , label = 'data1' )
1105
+ self ._check_legend_labels (ax , ['data1' ])
1106
+ ax = df2 .plot (kind = 'scatter' , x = 'd' , y = 'e' , legend = False ,
1107
+ label = 'data2' , ax = ax )
1108
+ self ._check_legend_labels (ax , ['data1' ])
1109
+ ax = df3 .plot (kind = 'scatter' , x = 'g' , y = 'h' , label = 'data3' , ax = ax )
1110
+ self ._check_legend_labels (ax , ['data1' , 'data3' ])
1111
+
1049
1112
def test_legend_name (self ):
1050
1113
multi = DataFrame (randn (4 , 4 ),
1051
1114
columns = [np .array (['a' , 'a' , 'b' , 'b' ]),
@@ -1056,6 +1119,20 @@ def test_legend_name(self):
1056
1119
leg_title = ax .legend_ .get_title ()
1057
1120
self .assertEqual (leg_title .get_text (), 'group,individual' )
1058
1121
1122
+ df = DataFrame (randn (5 , 5 ))
1123
+ ax = df .plot (legend = True , ax = ax )
1124
+ leg_title = ax .legend_ .get_title ()
1125
+ self .assertEqual (leg_title .get_text (), 'group,individual' )
1126
+
1127
+ df .columns .name = 'new'
1128
+ ax = df .plot (legend = False , ax = ax )
1129
+ leg_title = ax .legend_ .get_title ()
1130
+ self .assertEqual (leg_title .get_text (), 'group,individual' )
1131
+
1132
+ ax = df .plot (legend = True , ax = ax )
1133
+ leg_title = ax .legend_ .get_title ()
1134
+ self .assertEqual (leg_title .get_text (), 'new' )
1135
+
1059
1136
def _check_plot_fails (self , f , * args , ** kwargs ):
1060
1137
with tm .assertRaises (Exception ):
1061
1138
f (* args , ** kwargs )
0 commit comments