Skip to content

Commit c50565c

Browse files
authored
Merge pull request #85 from paulgb/transformers-input-df
Add input_df init argument to pass df/series to transformers
2 parents b40328c + b51958a commit c50565c

File tree

3 files changed

+140
-4
lines changed

3 files changed

+140
-4
lines changed

README.rst

+34-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ For these examples, we'll also use pandas, numpy, and sklearn::
5050
Load some Data
5151
**************
5252

53-
Normally you'll read the data from a file, but for demonstration purposes I'll create a data frame from a Python dict::
53+
Normally you'll read the data from a file, but for demonstration purposes we'll create a data frame from a Python dict::
5454

5555
>>> data = pd.DataFrame({'pet': ['cat', 'dog', 'dog', 'fish', 'cat', 'dog', 'cat', 'fish'],
5656
... 'children': [4., 6, 3, 3, 2, 3, 5, 4],
@@ -116,6 +116,37 @@ the dataframe mapper. We can do so by inspecting the automatically generated
116116
['pet_cat', 'pet_dog', 'pet_fish', 'children']
117117

118118

119+
Passing Series/DataFrames to the transformers
120+
*********************************************
121+
122+
By default the transformers are passed a numpy array of the selected columns
123+
as input. This is because ``sklearn`` transformers are historically designed to
124+
work with numpy arrays, not with pandas dataframes, even though their basic
125+
indexing interfaces are similar.
126+
127+
However, we can pass a dataframe/series to the transformers to handle custom
128+
cases by initializing the dataframe mapper with ``input_df=True``::
129+
130+
>>> from sklearn.base import TransformerMixin
131+
>>> class DateEncoder(TransformerMixin):
132+
... def fit(self, X, y=None):
133+
... return self
134+
...
135+
... def transform(self, X):
136+
... dt = X.dt
137+
... return pd.concat([dt.year, dt.month, dt.day], axis=1)
138+
>>> dates_df = pd.DataFrame(
139+
... {'dates': pd.date_range('2015-10-30', '2015-11-02')})
140+
>>> mapper_dates = DataFrameMapper([
141+
... ('dates', DateEncoder())
142+
... ], input_df=True)
143+
>>> mapper_dates.fit_transform(dates_df)
144+
array([[2015, 10, 30],
145+
[2015, 10, 31],
146+
[2015, 11, 1],
147+
[2015, 11, 2]])
148+
149+
119150
Outputting a dataframe
120151
**********************
121152

@@ -289,6 +320,8 @@ Development
289320
* Capture output columns generated names in ``transformed_names_`` attribute (#78).
290321
* Add ``CategoricalImputer`` that replaces null-like values with the mode
291322
for string-like columns.
323+
* Add ``input_df`` init argument to allow inputting a dataframe/series to the
324+
transformers instead of a numpy array (#60).
292325

293326

294327
1.3.0 (2017-01-21)

sklearn_pandas/dataframe_mapper.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ class DataFrameMapper(BaseEstimator, TransformerMixin):
3333
sklearn transformation.
3434
"""
3535

36-
def __init__(self, features, default=False, sparse=False, df_out=False):
36+
def __init__(self, features, default=False, sparse=False, df_out=False,
37+
input_df=False):
3738
"""
3839
Params:
3940
@@ -57,6 +58,10 @@ def __init__(self, features, default=False, sparse=False, df_out=False):
5758
if there's multiple inputs, and the name concatenated with
5859
'_1', '_2' etc if there's multiple outputs. NB: does not
5960
work if *default* or *sparse* are true
61+
62+
input_df If ``True`` pass the selected columns to the transformers
63+
as a pandas DataFrame or Series. Otherwise pass them as a
64+
numpy array. Defaults to ``False``.
6065
"""
6166
if isinstance(features, list):
6267
features = [(columns, _build_transformer(transformers))
@@ -65,6 +70,7 @@ def __init__(self, features, default=False, sparse=False, df_out=False):
6570
self.default = _build_transformer(default)
6671
self.sparse = sparse
6772
self.df_out = df_out
73+
self.input_df = input_df
6874
self.transformed_names_ = []
6975

7076
if (df_out and (sparse or default)):
@@ -108,6 +114,8 @@ def __setstate__(self, state):
108114
self.default = state.get('default', False)
109115
self.df_out = state.get('df_out', False)
110116

117+
self.input_df = state.get('input_df', False)
118+
111119
def _get_col_subset(self, X, cols):
112120
"""
113121
Get a subset of columns from the given table X.
@@ -132,10 +140,15 @@ def _get_col_subset(self, X, cols):
132140
X = X.df
133141

134142
if return_vector:
135-
t = X[cols[0]].values
143+
t = X[cols[0]]
136144
else:
137-
t = X[cols].values
145+
t = X[cols]
138146

147+
# return either a DataFrame/Series or a numpy array
148+
if self.input_df:
149+
return t
150+
else:
151+
return t.values
139152
return t
140153

141154
def fit(self, X, y=None):

tests/test_dataframe_mapper.py

+90
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ def predict(self, X):
5656
return True
5757

5858

59+
class DateEncoder():
60+
def fit(self, X, y=None):
61+
return self
62+
63+
def transform(self, X):
64+
dt = X.dt
65+
return pd.concat([dt.year, dt.month, dt.day], axis=1)
66+
67+
5968
class ToSparseTransformer(BaseEstimator, TransformerMixin):
6069
"""
6170
Transforms numpy matrix to sparse format.
@@ -225,6 +234,87 @@ def test_pca(complex_dataframe):
225234
assert cols[1] == 'feat1_feat2_1'
226235

227236

237+
def test_input_df_true_first_transformer(simple_dataframe, monkeypatch):
238+
"""
239+
If input_df is True, the first transformer is passed
240+
a pd.Series instead of an np.array
241+
"""
242+
df = simple_dataframe
243+
monkeypatch.setattr(MockXTransformer, 'fit', Mock())
244+
monkeypatch.setattr(MockXTransformer, 'transform',
245+
Mock(return_value=np.array([1, 2, 3])))
246+
mapper = DataFrameMapper([
247+
('a', MockXTransformer())
248+
], input_df=True)
249+
out = mapper.fit_transform(df)
250+
251+
args, _ = MockXTransformer().fit.call_args
252+
assert isinstance(args[0], pd.Series)
253+
254+
args, _ = MockXTransformer().transform.call_args
255+
assert isinstance(args[0], pd.Series)
256+
257+
assert_array_equal(out, np.array([1, 2, 3]).reshape(-1, 1))
258+
259+
260+
def test_input_df_true_next_transformers(simple_dataframe, monkeypatch):
261+
"""
262+
If input_df is True, the subsequent transformers get passed pandas
263+
objects instead of numpy arrays (given the previous transformers
264+
output pandas objects as well)
265+
"""
266+
df = simple_dataframe
267+
monkeypatch.setattr(MockTClassifier, 'fit', Mock())
268+
monkeypatch.setattr(MockTClassifier, 'transform',
269+
Mock(return_value=pd.Series([1, 2, 3])))
270+
mapper = DataFrameMapper([
271+
('a', [MockXTransformer(), MockTClassifier()])
272+
], input_df=True)
273+
out = mapper.fit_transform(df)
274+
275+
args, _ = MockTClassifier().fit.call_args
276+
assert isinstance(args[0], pd.Series)
277+
278+
assert_array_equal(out, np.array([1, 2, 3]).reshape(-1, 1))
279+
280+
281+
def test_input_df_true_multiple_cols(complex_dataframe):
282+
"""
283+
When input_df is True, applying transformers to multiple columns
284+
works as expected
285+
"""
286+
df = complex_dataframe
287+
288+
mapper = DataFrameMapper([
289+
('target', MockXTransformer()),
290+
('feat1', MockXTransformer()),
291+
], input_df=True)
292+
out = mapper.fit_transform(df)
293+
294+
assert_array_equal(out[:, 0], df['target'].values)
295+
assert_array_equal(out[:, 1], df['feat1'].values)
296+
297+
298+
def test_input_df_date_encoder():
299+
"""
300+
When input_df is True we can apply a transformer that only works
301+
with pandas dataframes like a DateEncoder
302+
"""
303+
df = pd.DataFrame(
304+
{'dates': pd.date_range('2015-10-30', '2015-11-02')})
305+
mapper = DataFrameMapper([
306+
('dates', DateEncoder())
307+
], input_df=True)
308+
out = mapper.fit_transform(df)
309+
expected = np.array([
310+
[2015, 10, 30],
311+
[2015, 10, 31],
312+
[2015, 11, 1],
313+
[2015, 11, 2]
314+
])
315+
assert_array_equal(out, expected)
316+
317+
228318
def test_nonexistent_columns_explicit_fail(simple_dataframe):
229319
"""
230320
If a nonexistent column is selected, KeyError is raised.

0 commit comments

Comments
 (0)