Skip to content

Commit a4c916b

Browse files
committed
Convert inverse_transform to emit warnings instead of raising ValueError when the bijection is broken, and add tests checking that nulls are returned in those cases
1 parent dc5ac1b commit a4c916b

File tree

4 files changed

+129
-30
lines changed

4 files changed

+129
-30
lines changed

category_encoders/basen.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sklearn.base import BaseEstimator, TransformerMixin
88
from category_encoders.ordinal import OrdinalEncoder
99
import category_encoders.utils as util
10+
import warnings
1011

1112
__author__ = 'willmcginnis'
1213

@@ -276,23 +277,17 @@ def inverse_transform(self, X_in):
276277
if not self.cols:
277278
return X if self.return_df else X.values
278279

279-
if self.handle_unknown == 'value':
280-
for col in self.cols:
281-
if any(X[col] == -1):
282-
raise ValueError("inverse_transform is not supported because transform impute "
283-
"the unknown category -1 when encode %s" % (col,))
284-
285-
if self.handle_unknown == 'return_nan':
286-
for col in self.cols:
287-
if X[col].isnull().any():
288-
raise ValueError("inverse_transform is not supported because transform impute "
289-
"the unknown category nan when encode %s" % (col,))
290-
291280
for switch in self.ordinal_encoder.mapping:
292281
column_mapping = switch.get('mapping')
293282
inverse = pd.Series(data=column_mapping.index, index=column_mapping.get_values())
294283
X[switch.get('col')] = X[switch.get('col')].map(inverse).astype(switch.get('data_type'))
295284

285+
if self.handle_unknown == 'return_nan' and self.handle_missing == 'return_nan':
286+
for col in self.cols:
287+
if X[switch.get('col')].isnull().any():
288+
warnings.warn("inverse_transform is not supported because transform impute "
289+
"the unknown category nan when encode %s" % (col,))
290+
296291
return X if self.return_df else X.values
297292

298293
def calc_required_digits(self, values):
@@ -356,10 +351,7 @@ def basen_to_integer(self, X, cols, base):
356351

357352
for col in cols:
358353
col_list = [col0 for col0 in out_cols if str(col0).startswith(str(col))]
359-
for col0 in col_list:
360-
if any(X[col0].isnull()):
361-
raise ValueError("inverse_transform is not supported because transform impute"
362-
"the unknown category -1 when encode %s" % (col,))
354+
363355
if base == 1:
364356
value_array = np.array([int(col0.split('_')[-1]) for col0 in col_list])
365357
else:

category_encoders/one_hot.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""One-hot or dummy coding"""
22
import numpy as np
33
import pandas as pd
4-
import copy
4+
import warnings
55
from sklearn.base import BaseEstimator, TransformerMixin
66
from category_encoders.ordinal import OrdinalEncoder
77
import category_encoders.utils as util
@@ -306,23 +306,17 @@ def inverse_transform(self, X_in):
306306
if not self.cols:
307307
return X if self.return_df else X.values
308308

309-
if self.handle_unknown == 'value':
310-
for col in self.cols:
311-
if any(X[col] == -1):
312-
raise ValueError("inverse_transform is not supported because transform impute "
313-
"the unknown category -1 when encode %s"%(col,))
314-
315-
if self.handle_unknown == 'return_nan':
316-
for col in self.cols:
317-
if X[col].isnull().any():
318-
raise ValueError("inverse_transform is not supported because transform impute "
319-
"the unknown category nan when encode %s" % (col,))
320-
321309
for switch in self.ordinal_encoder.mapping:
322310
column_mapping = switch.get('mapping')
323311
inverse = pd.Series(data=column_mapping.index, index=column_mapping.get_values())
324312
X[switch.get('col')] = X[switch.get('col')].map(inverse).astype(switch.get('data_type'))
325313

314+
if self.handle_unknown == 'return_nan' and self.handle_missing == 'return_nan':
315+
for col in self.cols:
316+
if X[switch.get('col')].isnull().any():
317+
warnings.warn("inverse_transform is not supported because transform impute "
318+
"the unknown category nan when encode %s" % (col,))
319+
326320
return X if self.return_df else X.values
327321

328322
def get_dummies(self, X_in):

category_encoders/tests/test_basen.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest2 import TestCase # or `from unittest import ...` if on Python 3.4+
33
import numpy as np
44
import category_encoders as encoders
5+
import warnings
56

67

78
class TestBaseNEncoder(TestCase):
@@ -84,3 +85,59 @@ def test_HandleUnknown_HaveOnlyKnown_ExpectSecondColumn(self):
8485
self.assertEqual(2, result.shape[0])
8586
self.assertListEqual([0, 0, 1], result.iloc[0, :].tolist())
8687
self.assertListEqual([0, 1, 0], result.iloc[1, :].tolist())
88+
89+
def test_inverse_transform_HaveNanInTrainAndHandleMissingValue_ExpectReturnedWithNan(self):
90+
train = pd.DataFrame({'city': ['chicago', np.nan]})
91+
92+
enc = encoders.BaseNEncoder(handle_missing='value', handle_unknown='value')
93+
result = enc.fit_transform(train)
94+
original = enc.inverse_transform(result)
95+
96+
pd.testing.assert_frame_equal(train, original)
97+
98+
def test_inverse_transform_HaveNanInTrainAndHandleMissingReturnNan_ExpectReturnedWithNan(self):
99+
train = pd.DataFrame({'city': ['chicago', np.nan]})
100+
101+
enc = encoders.BaseNEncoder(handle_missing='return_nan', handle_unknown='value')
102+
result = enc.fit_transform(train)
103+
original = enc.inverse_transform(result)
104+
105+
pd.testing.assert_frame_equal(train, original)
106+
107+
def test_inverse_transform_BothFieldsAreReturnNanWithNan_ExpectValueError(self):
108+
train = pd.DataFrame({'city': ['chicago', np.nan]})
109+
test = pd.DataFrame({'city': ['chicago', 'los angeles']})
110+
111+
enc = encoders.BaseNEncoder(handle_missing='return_nan', handle_unknown='return_nan')
112+
enc.fit(train)
113+
result = enc.transform(test)
114+
115+
with warnings.catch_warnings(record=True) as w:
116+
enc.inverse_transform(result)
117+
118+
self.assertEqual(2, len(w))
119+
self.assertEqual('inverse_transform is not supported because transform impute '
120+
'the unknown category nan when encode city', str(w[1].message))
121+
122+
def test_inverse_transform_HaveMissingAndNoUnknown_ExpectInversed(self):
123+
train = pd.DataFrame({'city': ['chicago', np.nan]})
124+
test = pd.DataFrame({'city': ['chicago', 'los angeles']})
125+
126+
enc = encoders.BaseNEncoder(handle_missing='value', handle_unknown='return_nan')
127+
enc.fit(train)
128+
result = enc.transform(test)
129+
original = enc.inverse_transform(result)
130+
131+
pd.testing.assert_frame_equal(train, original)
132+
133+
def test_inverse_transform_HaveHandleMissingValueAndHandleUnknownReturnNan_ExpectBestInverse(self):
134+
train = pd.DataFrame({'city': ['chicago', np.nan]})
135+
test = pd.DataFrame({'city': ['chicago', np.nan, 'los angeles']})
136+
expected = pd.DataFrame({'city': ['chicago', np.nan, np.nan]})
137+
138+
enc = encoders.BaseNEncoder(handle_missing='value', handle_unknown='return_nan')
139+
enc.fit(train)
140+
result = enc.transform(test)
141+
original = enc.inverse_transform(result)
142+
143+
pd.testing.assert_frame_equal(expected, original)

category_encoders/tests/test_one_hot.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pandas as pd
22
from unittest import TestCase # or `from unittest import ...` if on Python 3.4+
33
import numpy as np
4-
4+
import warnings
55
import category_encoders.tests.test_utils as tu
66

77
import category_encoders as encoders
@@ -201,3 +201,59 @@ def test_HandleUnknown_HaveOnlyKnown_ExpectSecondColumn(self):
201201
expected = [[1, 0, 0],
202202
[0, 1, 0]]
203203
self.assertEqual(result.values.tolist(), expected)
204+
205+
def test_inverse_transform_HaveNanInTrainAndHandleMissingValue_ExpectReturnedWithNan(self):
206+
train = pd.DataFrame({'city': ['chicago', np.nan]})
207+
208+
enc = encoders.OneHotEncoder(handle_missing='value', handle_unknown='value')
209+
result = enc.fit_transform(train)
210+
original = enc.inverse_transform(result)
211+
212+
pd.testing.assert_frame_equal(train, original)
213+
214+
def test_inverse_transform_HaveNanInTrainAndHandleMissingReturnNan_ExpectReturnedWithNan(self):
215+
train = pd.DataFrame({'city': ['chicago', np.nan]})
216+
217+
enc = encoders.OneHotEncoder(handle_missing='return_nan', handle_unknown='value')
218+
result = enc.fit_transform(train)
219+
original = enc.inverse_transform(result)
220+
221+
pd.testing.assert_frame_equal(train, original)
222+
223+
def test_inverse_transform_BothFieldsAreReturnNanWithNan_ExpectValueError(self):
224+
train = pd.DataFrame({'city': ['chicago', np.nan]})
225+
test = pd.DataFrame({'city': ['chicago', 'los angeles']})
226+
227+
enc = encoders.OneHotEncoder(handle_missing='return_nan', handle_unknown='return_nan')
228+
enc.fit(train)
229+
result = enc.transform(test)
230+
231+
with warnings.catch_warnings(record=True) as w:
232+
enc.inverse_transform(result)
233+
234+
self.assertEqual(1, len(w))
235+
self.assertEqual('inverse_transform is not supported because transform impute '
236+
'the unknown category nan when encode city', str(w[0].message))
237+
238+
def test_inverse_transform_HaveMissingAndNoUnknown_ExpectInversed(self):
239+
train = pd.DataFrame({'city': ['chicago', np.nan]})
240+
test = pd.DataFrame({'city': ['chicago', 'los angeles']})
241+
242+
enc = encoders.OneHotEncoder(handle_missing='value', handle_unknown='return_nan')
243+
enc.fit(train)
244+
result = enc.transform(test)
245+
original = enc.inverse_transform(result)
246+
247+
pd.testing.assert_frame_equal(train, original)
248+
249+
def test_inverse_transform_HaveHandleMissingValueAndHandleUnknownReturnNan_ExpectBestInverse(self):
250+
train = pd.DataFrame({'city': ['chicago', np.nan]})
251+
test = pd.DataFrame({'city': ['chicago', np.nan, 'los angeles']})
252+
expected = pd.DataFrame({'city': ['chicago', np.nan, np.nan]})
253+
254+
enc = encoders.OneHotEncoder(handle_missing='value', handle_unknown='return_nan')
255+
enc.fit(train)
256+
result = enc.transform(test)
257+
original = enc.inverse_transform(result)
258+
259+
pd.testing.assert_frame_equal(expected, original)

0 commit comments

Comments
 (0)