Skip to content

Commit 1b4edd9

Browse files
DataFrameMapper.inverse_transform() for simple transformations
1 parent d2cd6bb commit 1b4edd9

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

sklearn_pandas/dataframe_mapper.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(self, features, default=False, sparse=False, df_out=False,
110110
self.df_out = df_out
111111
self.input_df = input_df
112112
self.transformed_names_ = []
113+
self.transformed_cols_ = []
113114

114115
if (df_out and (sparse or default)):
115116
raise ValueError("Can not use df_out with sparse or default")
@@ -268,6 +269,7 @@ def transform(self, X):
268269
"""
269270
extracted = []
270271
self.transformed_names_ = []
272+
self.transformed_cols_ = []
271273
for columns, transformers, options in self.built_features:
272274
input_df = options.get('input_df', self.input_df)
273275
# columns could be a string or list of
@@ -282,6 +284,10 @@ def transform(self, X):
282284
alias = options.get('alias')
283285
self.transformed_names_ += self.get_names(
284286
columns, transformers, Xt, alias)
287+
288+
self.transformed_cols_ += [
289+
(columns, transformers,
290+
self.get_names(columns, transformers, Xt, alias)) ]
285291

286292
# handle features not explicitly selected
287293
if self.built_default is not False:
@@ -328,3 +334,34 @@ def transform(self, X):
328334
index=index)
329335
else:
330336
return stacked
337+
338+
339+
def inverse_transform(self, X):
340+
"""
341+
Inverse transform the given data. Assumes that fit has already been called.
342+
343+
X the data to inverse transform
344+
"""
345+
346+
X_inv = pd.DataFrame()
347+
# We will populate the inverse transformed dataframe column by column
348+
349+
# Let's keep track of the column we've processed
350+
prev_col = 0
351+
for columns, transformers, transformed_cols in self.transformed_cols_:
352+
# Determine the column number of the last column in X corresponding to
353+
# the original column we're computing
354+
last_col = prev_col + len(transformed_cols)
355+
356+
# Inverse transform the columns in X for the current transformer
357+
col_inv = pd.DataFrame(transformers.inverse_transform(X[:, prev_col:last_col]),
358+
columns = [columns])
359+
360+
# Append the inverse transformed column to the output data frame
361+
X_inv = pd.concat([X_inv, col_inv], axis = 1)
362+
363+
# For the next iteration, update the last column processed
364+
prev_col = last_col
365+
366+
367+
return X_inv

tests/test_dataframe_mapper.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from sklearn.feature_extraction.text import CountVectorizer
2020
from sklearn.feature_extraction import DictVectorizer
2121
from sklearn.preprocessing import (
22-
Imputer, StandardScaler, OneHotEncoder, LabelBinarizer)
22+
Imputer, StandardScaler, OneHotEncoder, LabelBinarizer, LabelEncoder)
2323
from sklearn.feature_selection import SelectKBest, chi2
2424
from sklearn.base import BaseEstimator, TransformerMixin
2525
import sklearn.decomposition
@@ -829,3 +829,32 @@ def test_direct_cross_validation(iris_dataframe):
829829
scores = sklearn_cv_score(pipeline, data, labels)
830830
assert scores.mean() > 0.96
831831
assert (scores.std() * 2) < 0.04
832+
833+
834+
def test_inverse_transform_simple():
835+
df = pd.DataFrame({'colA': list('ynyyn'), 'colB': list('abcab')})
836+
mapper = DataFrameMapper([
837+
('colA', LabelEncoder()),
838+
('colB', LabelEncoder()),
839+
])
840+
841+
transformed = mapper.fit_transform(df)
842+
restored = mapper.inverse_transform(transformed)
843+
844+
assert isinstance(restored, pd.DataFrame)
845+
assert restored.equals(df)
846+
847+
848+
def test_inverse_transform_multicolumn():
849+
df = pd.DataFrame({'colA': list('ynyyn'), 'colB': list('abcab'), 'colC': list('sttts')})
850+
mapper = DataFrameMapper([
851+
('colA', LabelEncoder()),
852+
('colB', LabelBinarizer()),
853+
('colC', LabelEncoder()),
854+
])
855+
856+
transformed = mapper.fit_transform(df)
857+
restored = mapper.inverse_transform(transformed)
858+
859+
assert isinstance(restored, pd.DataFrame)
860+
assert restored.equals(df)

0 commit comments

Comments
 (0)