Skip to content

Commit dd5205f

Browse files
Chang She authored and wesm committed
ENH: added OLS.predict method; pass through call to statsmodels ols predict method
1 parent 8df966c commit dd5205f

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

pandas/stats/ols.py

+37
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,43 @@ def y_predict(self):
381381
For in-sample, this is same as y_fitted."""
382382
return self.y_fitted
383383

384+
def predict(self, new_y_values, fill_value=None, fill_method=None,
            axis=0):
    """
    Predict the response for new data using the fitted model.

    Parameters
    ----------
    new_y_values : Series or DataFrame
        Data to predict from. Despite the name, these values are used
        as the regressors: they are aligned to the columns of the
        model's x data before being handed to the underlying
        statsmodels predict call. A Series is treated as a single
        column named 'x'.
    fill_value : scalar or dict, default None
    fill_method : {'backfill', 'bfill', 'pad', 'ffill', None}, default None
    axis : {0, 1}, default 0
        fill_value, fill_method and axis are forwarded to fillna;
        see DataFrame.fillna for more details

    Notes
    -----
    1. If both fill_value and fill_method are None then rows containing
       NaNs are dropped before predicting (this is the default
       behavior). The result is reindexed to the index of the input,
       so dropped rows show up as missing values in the output.
    2. An intercept will be automatically added to the new_y_values if
       the model was fitted using an intercept

    Returns
    -------
    Series of predicted values, indexed like new_y_values
    """
    # Remember the caller's index so the result can be aligned to it.
    requested_index = new_y_values.index

    no_fill_requested = fill_value is None and fill_method is None
    if no_fill_requested:
        regressors = new_y_values.dropna(how='any')
    else:
        regressors = new_y_values.fillna(value=fill_value,
                                         method=fill_method, axis=axis)

    # A lone Series becomes a one-column frame keyed 'x'.
    if isinstance(regressors, Series):
        regressors = DataFrame({'x': regressors})

    if self._intercept:
        regressors['intercept'] = 1.

    # Put the columns in the same order the model was fitted with.
    aligned = regressors.reindex(columns=self._x.columns)
    predicted = self.sm_ols.model.predict(aligned.values)
    return Series(predicted, index=aligned.index).reindex(requested_index)
420+
384421
# NOTE(review): presumably the names of the result attributes reported for
# a fitted model -- confirm against where RESULT_FIELDS is consumed.
RESULT_FIELDS = [
    'r2', 'r2_adj', 'df', 'df_model', 'df_resid', 'rmse',
    'f_stat', 'beta', 'std_err', 't_stat', 'p_value', 'nobs',
]
386423

pandas/stats/tests/test_ols.py

+15
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,21 @@ def test_y_predict(self):
287287
assert_series_equal(model1.y_predict, model1.y_fitted)
288288
assert_almost_equal(model1._y_predict_raw, model1._y_fitted_raw)
289289

290+
def test_predict(self):
    """Check predict() against y_predict in-sample and against X.dot(beta)
    on shifted data, regardless of the column order of the input."""
    endog = tm.makeTimeSeries()
    exog = tm.makeTimeDataFrame()
    model1 = ols(y=endog, x=exog)

    # In-sample prediction must equal y_predict whatever the column order.
    exog_reversed = exog.reindex(columns=exog.columns[::-1])
    for frame in (exog, exog_reversed):
        assert_series_equal(model1.predict(frame), model1.y_predict)

    # Out-of-sample: compare with a hand-computed X.dot(beta).
    shifted = exog_reversed + 10
    pred3 = model1.predict(shifted)
    shifted['intercept'] = 1.
    design = shifted.reindex(columns=model1.beta.index)
    by_hand = design.values.dot(model1.beta.values)
    expected = Series(by_hand, design.index)
    assert_series_equal(expected, pred3)
304+
290305
def test_longpanel_series_combo(self):
291306
wp = tm.makePanel()
292307
lp = wp.to_frame()

0 commit comments

Comments
 (0)