Skip to content

Commit 765f001

Browse files
committed
Lower-memory impl for Categorical.from_dummies
1 parent 7cb3f1b commit 765f001

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

pandas/core/arrays/categorical.py

+27-12
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ def from_dummies(
380380
ordered: Optional[bool] = None,
381381
prefix=None,
382382
prefix_sep="_",
383+
fillna=None,
383384
) -> "Categorical":
384385
"""Create a `Categorical` using a ``DataFrame`` of dummy variables.
385386
@@ -405,6 +406,9 @@ def from_dummies(
405406
prefix_sep : str, default "_"
406407
If ``prefix`` is not ``None``, use as the separator
407408
between the prefix and the final name of the category.
409+
fillna : bool, optional, default None
410+
How to handle NA values. If ``True`` or ``False``, NA is filled with that value.
411+
If ``None``, raise a ValueError if there are any NA values.
408412
409413
Raises
410414
------
@@ -444,37 +448,48 @@ def from_dummies(
444448
...
445449
ValueError: 1 record(s) belongs to multiple categories: [0]
446450
"""
451+
from pandas import Series
452+
453+
copied = False
447454
to_drop = dummies.columns[isna(dummies.columns.values)]
448455
if len(to_drop):
449456
dummies = dummies.drop(columns=to_drop)
457+
copied = True
450458

451-
if prefix is not None:
459+
if prefix is None:
460+
cats = dummies.columns
461+
else:
452462
pref = prefix + (prefix_sep or "")
453-
name_map = dict()
463+
cats = []
454464
to_keep = []
455465
for c in dummies.columns:
456466
if isinstance(c, str) and c.startswith(pref):
457467
to_keep.append(c)
458-
name_map[c] = c[len(pref) :]
459-
dummies = dummies[to_keep].rename(columns=name_map)
468+
cats.append(c[len(pref) :])
469+
dummies = dummies[to_keep]
460470

461471
df = dummies.astype("boolean")
472+
if fillna is not None:
473+
df = df.fillna(fillna, inplace=copied)
462474

463-
multicat_rows = df.sum(axis=1, skipna=False) > 1
475+
row_totals = df.sum(axis=1, skipna=False)
476+
if row_totals.isna().any():
477+
raise ValueError("Unhandled NA values in dummy array")
478+
479+
multicat_rows = row_totals > 1
464480
if multicat_rows.any():
465481
raise ValueError(
466482
"{} record(s) belongs to multiple categories: {}".format(
467483
multicat_rows.sum(), list(df.index[multicat_rows]),
468484
)
469485
)
470486

471-
mult_by = np.arange(df.shape[1]) + 1
472-
# 000 000 0 -1
473-
# 010 020 2 1
474-
# 001 * 1,2,3 => 003 -> 3 -> 2 = correct codes
475-
# 100 100 1 0
476-
codes = ((df * mult_by).sum(axis=1, skipna=False) - 1).astype("Int64")
477-
return cls.from_codes(codes.fillna(-1), df.columns.values, ordered=ordered)
487+
codes = Series(np.full(len(row_totals), np.nan), index=df.index, dtype="Int64")
488+
codes[row_totals == 0] = -1
489+
row_idx, code = np.nonzero(df)
490+
codes[row_idx] = code
491+
492+
return cls.from_codes(codes.fillna(-1), cats, ordered=ordered)
478493

479494
def get_dummies(
480495
self,

0 commit comments

Comments
 (0)