Skip to content

Commit fc3bb18

Browse files
authored
Merge pull request #102 from janmotl/issue_96_and_97
Issue 96 and 97
2 parents 1eeb18e + 2219ed8 commit fc3bb18

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

category_encoders/basen.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -168,14 +168,19 @@ def transform(self, X, override_return_df=False):
168168
if not self.cols:
169169
return X
170170

171+
original_cols = set(X.columns)
171172
X = self.ordinal_encoder.transform(X)
172173
X = self.basen_encode(X, cols=self.cols)
173174

174175
if self.drop_invariant:
175176
for col in self.drop_cols:
176177
X.drop(col, 1, inplace=True)
177178

178-
X.fillna(0.0, inplace=True)
179+
# impute missing values only in the generated columns
180+
current_cols = set(X.columns)
181+
fillna_cols = list(current_cols - (original_cols - set(self.cols)))
182+
X[fillna_cols] = X[fillna_cols].fillna(value=0.0)
183+
179184
if self.return_df or override_return_df:
180185
return X
181186
else:
@@ -299,13 +304,13 @@ def basen_to_interger(self, X, cols, base):
299304
out_cols = X.columns.values
300305

301306
for col in cols:
302-
col_list = [col0 for col0 in out_cols if col0.startswith(col)]
307+
col_list = [col0 for col0 in out_cols if str(col0).startswith(col)]
303308
for col0 in col_list:
304309
if any(X[col0].isnull()):
305310
raise ValueError("inverse_transform is not supported because transform impute"
306311
"the unknown category -1 when encode %s" % (col,))
307312
if base == 1:
308-
value_array = np.array([int(col0.split('_')[1]) for col0 in col_list])
313+
value_array = np.array([int(col0.split('_')[-1]) for col0 in col_list])
309314
else:
310315
len0 = len(col_list)
311316
value_array = np.array([base ** (len0 - 1 - i) for i in range(len0)])

category_encoders/one_hot.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,14 +297,15 @@ def reverse_dummies(self, X, cols):
297297
out_cols = X.columns.values
298298

299299
for col in cols:
300-
col_list = [col0 for col0 in out_cols if col0.startswith(col)]
300+
col_list = [col0 for col0 in out_cols if str(col0).startswith(col)]
301+
prefix_length = len(col)+1 # original column name plus underscore
301302
if self.use_cat_names:
302303
X[col] = 0
303304
for tran_col in col_list:
304-
val = tran_col.split('_')[1]
305+
val = tran_col[prefix_length:]
305306
X.loc[X[tran_col] == 1, col] = val
306307
else:
307-
value_array = np.array([int(col0.split('_')[1]) for col0 in col_list])
308+
value_array = np.array([int(col0[prefix_length:]) for col0 in col_list])
308309
X[col] = np.dot(X[col_list].values, value_array.T)
309310
out_cols = [col0 for col0 in out_cols if col0 not in col_list]
310311

0 commit comments

Comments
 (0)