@@ -46,6 +46,8 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
46
46
>>> df = DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
47
47
>>> scatter_matrix(df, alpha=0.2)
48
48
"""
49
+ from matplotlib .artist import setp
50
+
49
51
df = frame ._get_numeric_data ()
50
52
n = df .columns .size
51
53
fig , axes = _subplots (nrows = n , ncols = n , figsize = figsize , ax = ax ,
@@ -60,76 +62,74 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
60
62
61
63
for i , a in zip (range (n ), df .columns ):
62
64
for j , b in zip (range (n ), df .columns ):
65
+ ax = axes [i , j ]
66
+
63
67
if i == j :
64
68
values = df [a ].values [mask [a ].values ]
65
69
66
70
# Deal with the diagonal by drawing a histogram there.
67
71
if diagonal == 'hist' :
68
- axes [ i , j ] .hist (values )
72
+ ax .hist (values )
69
73
elif diagonal in ('kde' , 'density' ):
70
74
from scipy .stats import gaussian_kde
71
75
y = values
72
76
gkde = gaussian_kde (y )
73
77
ind = np .linspace (y .min (), y .max (), 1000 )
74
- axes [ i , j ] .plot (ind , gkde .evaluate (ind ), ** kwds )
78
+ ax .plot (ind , gkde .evaluate (ind ), ** kwds )
75
79
else :
76
80
common = (mask [a ] & mask [b ]).values
77
81
78
- axes [ i , j ] .scatter (df [b ][common ], df [a ][common ],
82
+ ax .scatter (df [b ][common ], df [a ][common ],
79
83
marker = marker , alpha = alpha , ** kwds )
80
84
81
- axes [i , j ].set_xlabel ('' )
82
- axes [i , j ].set_ylabel ('' )
83
- axes [i , j ].set_xticklabels ([])
84
- axes [i , j ].set_yticklabels ([])
85
- ticks = df .index
86
-
87
- is_datetype = ticks .inferred_type in ('datetime' , 'date' ,
88
- 'datetime64' )
85
+ ax .set_xlabel ('' )
86
+ ax .set_ylabel ('' )
89
87
90
- if ticks .is_numeric () or is_datetype :
91
- """
92
- Matplotlib supports numeric values or datetime objects as
93
- xaxis values. Taking LBYL approach here, by the time
94
- matplotlib raises exception when using non numeric/datetime
95
- values for xaxis, several actions are already taken by plt.
96
- """
97
- ticks = ticks ._mpl_repr ()
88
+ ax .xaxis .set_visible (False )
89
+ ax .yaxis .set_visible (False )
98
90
99
91
# setup labels
100
92
if i == 0 and j % 2 == 1 :
101
- axes [i , j ].set_xlabel (b , visible = True )
102
- #axes[i, j].xaxis.set_visible(True)
103
- axes [i , j ].set_xlabel (b )
104
- axes [i , j ].set_xticklabels (ticks )
105
- axes [i , j ].xaxis .set_ticks_position ('top' )
106
- axes [i , j ].xaxis .set_label_position ('top' )
107
- if i == n - 1 and j % 2 == 0 :
108
- axes [i , j ].set_xlabel (b , visible = True )
109
- #axes[i, j].xaxis.set_visible(True)
110
- axes [i , j ].set_xlabel (b )
111
- axes [i , j ].set_xticklabels (ticks )
112
- axes [i , j ].xaxis .set_ticks_position ('bottom' )
113
- axes [i , j ].xaxis .set_label_position ('bottom' )
114
- if j == 0 and i % 2 == 0 :
115
- axes [i , j ].set_ylabel (a , visible = True )
116
- #axes[i, j].yaxis.set_visible(True)
117
- axes [i , j ].set_ylabel (a )
118
- axes [i , j ].set_yticklabels (ticks )
119
- axes [i , j ].yaxis .set_ticks_position ('left' )
120
- axes [i , j ].yaxis .set_label_position ('left' )
121
- if j == n - 1 and i % 2 == 1 :
122
- axes [i , j ].set_ylabel (a , visible = True )
123
- #axes[i, j].yaxis.set_visible(True)
124
- axes [i , j ].set_ylabel (a )
125
- axes [i , j ].set_yticklabels (ticks )
126
- axes [i , j ].yaxis .set_ticks_position ('right' )
127
- axes [i , j ].yaxis .set_label_position ('right' )
128
-
129
- axes [i , j ].grid (b = grid )
93
+ ax .set_xlabel (b , visible = True )
94
+ ax .xaxis .set_visible (True )
95
+ ax .set_xlabel (b )
96
+ ax .xaxis .set_ticks_position ('top' )
97
+ ax .xaxis .set_label_position ('top' )
98
+ setp (ax .get_xticklabels (), rotation = 90 )
99
+ elif i == n - 1 and j % 2 == 0 :
100
+ ax .set_xlabel (b , visible = True )
101
+ ax .xaxis .set_visible (True )
102
+ ax .set_xlabel (b )
103
+ ax .xaxis .set_ticks_position ('bottom' )
104
+ ax .xaxis .set_label_position ('bottom' )
105
+ setp (ax .get_xticklabels (), rotation = 90 )
106
+ elif j == 0 and i % 2 == 0 :
107
+ ax .set_ylabel (a , visible = True )
108
+ ax .yaxis .set_visible (True )
109
+ ax .set_ylabel (a )
110
+ ax .yaxis .set_ticks_position ('left' )
111
+ ax .yaxis .set_label_position ('left' )
112
+ elif j == n - 1 and i % 2 == 1 :
113
+ ax .set_ylabel (a , visible = True )
114
+ ax .yaxis .set_visible (True )
115
+ ax .set_ylabel (a )
116
+ ax .yaxis .set_ticks_position ('right' )
117
+ ax .yaxis .set_label_position ('right' )
118
+
119
+ # ax.grid(b=grid)
120
+
121
+ axes [0 , 0 ].yaxis .set_visible (False )
122
+ axes [n - 1 , n - 1 ].xaxis .set_visible (False )
123
+ axes [n - 1 , n - 1 ].yaxis .set_visible (False )
124
+ axes [0 , n - 1 ].yaxis .tick_right ()
125
+
126
+ for ax in axes .flat :
127
+ setp (ax .get_xticklabels (), fontsize = 8 )
128
+ setp (ax .get_yticklabels (), fontsize = 8 )
130
129
131
130
return axes
132
131
132
+
133
133
def _gca ():
134
134
import matplotlib .pyplot as plt
135
135
return plt .gca ()
0 commit comments