Skip to content

Commit 13f0d15

Browse files
committed
BUG: address GH #106, and various ols unit tests
1 parent f3a183f commit 13f0d15

File tree

4 files changed

+77
-23
lines changed

4 files changed

+77
-23
lines changed

pandas/stats/interface.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from pandas.core.api import Series
1+
from pandas.core.api import (Series, DataFrame, WidePanel, LongPanel,
2+
MultiIndex)
23

34
from pandas.stats.ols import OLS, MovingOLS
45
from pandas.stats.plm import PanelOLS, MovingPanelOLS, NonPooledPanelOLS
@@ -91,27 +92,32 @@ def ols(**kwargs):
9192
if window_type != common.FULL_SAMPLE:
9293
kwargs['window_type'] = common._get_window_type_name(window_type)
9394

94-
y = kwargs.get('y')
95+
x = kwargs.get('x')
96+
if isinstance(x, dict):
97+
if isinstance(x.values()[0], DataFrame):
98+
x = WidePanel(x)
99+
else:
100+
x = DataFrame(x)
101+
95102
if window_type == common.FULL_SAMPLE:
96-
# HACK (!)
97103
for rolling_field in ('window_type', 'window', 'min_periods'):
98104
if rolling_field in kwargs:
99105
del kwargs[rolling_field]
100106

101-
if isinstance(y, Series):
102-
klass = OLS
103-
else:
107+
if isinstance(x, (WidePanel, LongPanel)):
104108
if pool == False:
105109
klass = NonPooledPanelOLS
106110
else:
107111
klass = PanelOLS
108-
else:
109-
if isinstance(y, Series):
110-
klass = MovingOLS
111112
else:
113+
klass = OLS
114+
else:
115+
if isinstance(x, (WidePanel, LongPanel)):
112116
if pool == False:
113117
klass = NonPooledPanelOLS
114118
else:
115119
klass = MovingPanelOLS
120+
else:
121+
klass = MovingOLS
116122

117123
return klass(**kwargs)

pandas/stats/ols.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pandas.stats.math as math
1717
import pandas.stats.moments as moments
1818

19-
_FP_ERR = 1e-13
19+
_FP_ERR = 1e-8
2020

2121
class OLS(object):
2222
"""
@@ -242,7 +242,6 @@ def p_value(self):
242242
def _r2_raw(self):
243243
"""Returns the raw r-squared values."""
244244
has_intercept = np.abs(self._resid_raw.sum()) < _FP_ERR
245-
246245
if self._intercept:
247246
return 1 - self.sm_ols.ssr / self.sm_ols.centered_tss
248247
else:
@@ -1176,7 +1175,8 @@ def _filter_data(lhs, rhs):
11761175
Cleaned lhs and rhs
11771176
"""
11781177
if not isinstance(lhs, Series):
1179-
raise Exception('lhs must be a Series')
1178+
assert(len(lhs) == len(rhs))
1179+
lhs = Series(lhs, index=rhs.index)
11801180

11811181
rhs = _combine_rhs(rhs)
11821182

pandas/stats/plm.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(self, y, x, weights=None,
7171
dropped_dummies=None, verbose=False, nw_overlap=False):
7272
self._x_orig = x
7373
self._y_orig = y
74+
7475
self._weights = weights
7576
self._intercept = intercept
7677
self._nw_lags = nw_lags
@@ -171,7 +172,13 @@ def _filter_data(self):
171172
filtered = data.to_long()
172173

173174
# Filter all data together using to_long
174-
data['__y__'] = self._y_orig
175+
176+
# convert to DataFrame
177+
y = self._y_orig
178+
if isinstance(y, Series):
179+
y = y.unstack()
180+
181+
data['__y__'] = y
175182
data_long = data.to_long()
176183

177184
x_filt = filtered.filter(x_names)

pandas/stats/tests/test_ols.py

+51-10
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
import numpy as np
1212

1313
from pandas.core.panel import LongPanel
14-
from pandas.core.api import DataFrame, Index, Series
14+
from pandas.core.api import DataFrame, Index, Series, notnull
1515
from pandas.stats.api import ols
16-
from pandas.stats.plm import NonPooledPanelOLS
16+
from pandas.stats.plm import NonPooledPanelOLS, PanelOLS
1717
from pandas.util.testing import (assert_almost_equal, assert_series_equal,
1818
assert_frame_equal)
19-
import pandas.util.testing as testing
19+
import pandas.util.testing as tm
2020

2121
from common import BaseTest
2222

@@ -40,10 +40,6 @@ def _compare_moving_ols(model1, model2):
4040

4141
class TestOLS(BaseTest):
4242

43-
FIELDS = ['beta', 'df', 'df_model', 'df_resid', 'f_stat', 'p_value',
44-
'r2', 'r2_adj', 'rmse', 'std_err', 't_stat',
45-
'var_beta']
46-
4743
# TODO: Add tests for OLS y predict
4844
# TODO: Right now we just check for consistency between full-sample and
4945
# rolling/expanding results of the panel OLS. We should also cross-check
@@ -140,6 +136,10 @@ def checkMovingOLS(self, window_type, x, y, **kwds):
140136

141137
_check_non_raw_results(moving)
142138

139+
FIELDS = ['beta', 'df', 'df_model', 'df_resid', 'f_stat', 'p_value',
140+
'r2', 'r2_adj', 'rmse', 'std_err', 't_stat',
141+
'var_beta']
142+
143143
def compare(self, static, moving, event_index=None,
144144
result_index=None):
145145

@@ -169,7 +169,7 @@ def compare(self, static, moving, event_index=None,
169169
assert_almost_equal(ref, res)
170170

171171
def test_f_test(self):
172-
x = testing.makeTimeDataFrame()
172+
x = tm.makeTimeDataFrame()
173173
y = x.pop('A')
174174

175175
model = ols(y=y, x=x)
@@ -185,8 +185,49 @@ def test_f_test(self):
185185

186186
self.assertRaises(Exception, model.f_test, '1*A=0')
187187

188-
class TestPanelOLS(BaseTest):
188+
class TestOLSMisc(unittest.TestCase):
189+
'''
190+
For test coverage with faux data
191+
'''
192+
193+
def test_r2_no_intercept(self):
194+
y = tm.makeTimeSeries()
195+
x = tm.makeTimeDataFrame()
189196

197+
model1 = ols(y=y, x=x)
198+
199+
x_with = x.copy()
200+
x_with['intercept'] = 1.
201+
202+
model2 = ols(y=y, x=x_with, intercept=False)
203+
assert_series_equal(model1.beta, model2.beta)
204+
205+
# TODO: can we infer whether the intercept is there...
206+
self.assert_(model1.r2 != model2.r2)
207+
208+
def test_summary_many_terms(self):
209+
x = DataFrame(np.random.randn(100, 20))
210+
y = np.random.randn(100)
211+
model = ols(y=y, x=x)
212+
model.summary
213+
214+
def test_y_predict(self):
215+
y = tm.makeTimeSeries()
216+
x = tm.makeTimeDataFrame()
217+
model1 = ols(y=y, x=x)
218+
assert_series_equal(model1.y_predict, model1.y_fitted)
219+
220+
def test_longpanel_series_combo(self):
221+
wp = tm.makeWidePanel()
222+
lp = wp.to_long()
223+
224+
y = lp.pop('ItemA')
225+
model = ols(y=y, x=lp, entity_effects=True, window=20)
226+
self.assert_(notnull(model.beta.values).all())
227+
self.assert_(isinstance(model, PanelOLS))
228+
model.summary
229+
230+
class TestPanelOLS(BaseTest):
190231

191232
FIELDS = ['beta', 'df', 'df_model', 'df_resid', 'f_stat',
192233
'p_value', 'r2', 'r2_adj', 'rmse', 'std_err',
@@ -501,7 +542,7 @@ def compare(self, static, moving, event_index=None,
501542
assert_almost_equal(ref, res)
502543

503544
def test_auto_rolling_window_type(self):
504-
data = testing.makeTimeDataFrame()
545+
data = tm.makeTimeDataFrame()
505546
y = data.pop('A')
506547

507548
window_model = ols(y=y, x=data, window=20, min_periods=10)

0 commit comments

Comments
 (0)