Skip to content

Commit 015b31e

Browse files
Chang She authored and wesm committed
ENH: make method signature more consistent with new statsmodels behavior. Uses dot product directly so pandas users aren't affected by statsmodels API change
1 parent dd5205f commit 015b31e

File tree

2 files changed

+46
-19
lines changed

2 files changed

+46
-19
lines changed

pandas/stats/ols.py

+31 -16
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99

1010
import numpy as np
1111

12-
from pandas.core.api import DataFrame, Series
12+
from pandas.core.api import DataFrame, Series, isnull
1313
from pandas.core.common import _ensure_float64
1414
from pandas.core.index import MultiIndex
1515
from pandas.core.panel import Panel
@@ -381,12 +381,13 @@ def y_predict(self):
381381
For in-sample, this is same as y_fitted."""
382382
return self.y_fitted
383383

384-
def predict(self, new_y_values, fill_value=None, fill_method=None,
385-
axis=0):
384+
def predict(self, beta=None, x=None, fill_value=None,
385+
fill_method=None, axis=0):
386386
"""
387387
Parameters
388388
----------
389-
new_y_values : Series or DataFrame
389+
beta : Series
390+
x : Series or DataFrame
390391
fill_value : scalar or dict, default None
391392
fill_method : {'backfill', 'bfill', 'pad', 'ffill', None}, default None
392393
axis : {0, 1}, default 0
@@ -403,20 +404,34 @@ def predict(self, new_y_values, fill_value=None, fill_method=None,
403404
-------
404405
Series of predicted values
405406
"""
406-
orig_y = new_y_values
407-
if fill_value is None and fill_method is None:
408-
new_y_values = new_y_values.dropna(how='any')
407+
if beta is None and x is None:
408+
return self.y_predict
409+
410+
if beta is None:
411+
beta = self.beta
409412
else:
410-
new_y_values = new_y_values.fillna(value=fill_value,
411-
method=fill_method, axis=axis)
412-
if isinstance(new_y_values, Series):
413-
new_y_values = DataFrame({'x' : new_y_values})
414-
if self._intercept:
415-
new_y_values['intercept'] = 1.
413+
beta = beta.reindex(self.beta.index)
414+
if isnull(beta).any():
415+
raise ValueError('Must supply betas for same variables')
416+
417+
if x is None:
418+
x = self._x
419+
orig_x = x
420+
else:
421+
orig_x = x
422+
if fill_value is None and fill_method is None:
423+
x = x.dropna(how='any')
424+
else:
425+
x = x.fillna(value=fill_value, method=fill_method, axis=axis)
426+
if isinstance(x, Series):
427+
x = DataFrame({'x' : x})
428+
if self._intercept:
429+
x['intercept'] = 1.
430+
431+
x = x.reindex(columns=self._x.columns)
416432

417-
new_y_values = new_y_values.reindex(columns=self._x.columns)
418-
rs = self.sm_ols.model.predict(new_y_values.values)
419-
return Series(rs, new_y_values.index).reindex(orig_y.index)
433+
rs = x.values.dot(beta.values)
434+
return Series(rs, x.index).reindex(orig_x.index)
420435

421436
RESULT_FIELDS = ['r2', 'r2_adj', 'df', 'df_model', 'df_resid', 'rmse',
422437
'f_stat', 'beta', 'std_err', 't_stat', 'p_value', 'nobs']

pandas/stats/tests/test_ols.py

+15 -3
Original file line number | Diff line number | Diff line change
@@ -291,17 +291,29 @@ def test_predict(self):
291291
y = tm.makeTimeSeries()
292292
x = tm.makeTimeDataFrame()
293293
model1 = ols(y=y, x=x)
294-
assert_series_equal(model1.predict(x), model1.y_predict)
294+
assert_series_equal(model1.predict(), model1.y_predict)
295+
assert_series_equal(model1.predict(x=x), model1.y_predict)
296+
assert_series_equal(model1.predict(beta=model1.beta), model1.y_predict)
297+
298+
exog = x.copy()
299+
exog['intercept'] = 1.
300+
rs = Series(exog.values.dot(model1.beta.values), x.index)
301+
assert_series_equal(model1.y_predict, rs)
302+
295303
x2 = x.reindex(columns=x.columns[::-1])
296-
assert_series_equal(model1.predict(x2), model1.y_predict)
304+
assert_series_equal(model1.predict(x=x2), model1.y_predict)
297305

298306
x3 = x2 + 10
299-
pred3 = model1.predict(x3)
307+
pred3 = model1.predict(x=x3)
300308
x3['intercept'] = 1.
301309
x3 = x3.reindex(columns = model1.beta.index)
302310
expected = Series(x3.values.dot(model1.beta.values), x3.index)
303311
assert_series_equal(expected, pred3)
304312

313+
beta = Series(0., model1.beta.index)
314+
pred4 = model1.predict(beta=beta)
315+
assert_series_equal(Series(0., pred4.index), pred4)
316+
305317
def test_longpanel_series_combo(self):
306318
wp = tm.makePanel()
307319
lp = wp.to_frame()

0 commit comments

Comments
 (0)