Skip to content

Commit 5aa3313

Browse files
committed
use prefix in from_dummies
1 parent 6a408ea commit 5aa3313

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

pandas/core/arrays/categorical.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -375,10 +375,17 @@ def __init__(
375375

376376
@classmethod
377377
def from_dummies(
378-
cls, dummies: "DataFrame", ordered: Optional[bool] = None
378+
cls,
379+
dummies: "DataFrame",
380+
ordered: Optional[bool] = None,
381+
prefix=None,
382+
prefix_sep="_",
379383
) -> "Categorical":
380384
"""Create a `Categorical` using a ``DataFrame`` of dummy variables.
381385
386+
Can use a subset of columns based on the ``prefix``
387+
and ``prefix_sep`` parameters.
388+
382389
The ``DataFrame`` must have no more than one truthy value per row.
383390
The columns of the ``DataFrame`` become the categories of the `Categorical`.
384391
A column whose header is NA will be dropped:
@@ -391,6 +398,13 @@ def from_dummies(
391398
Sparse dataframes are not supported.
392399
ordered : bool
393400
Whether or not this Categorical is ordered.
401+
prefix : optional str
402+
Only take columns whose names are strings starting
403+
with this prefix and ``prefix_sep``,
404+
stripping those elements from the resulting category names.
405+
prefix_sep : str, default "_"
406+
If ``prefix`` is not ``None``, use as the separator
407+
between the prefix and the final name of the category.
394408
395409
Raises
396410
------
@@ -433,6 +447,17 @@ def from_dummies(
433447
to_drop = dummies.columns[isna(dummies.columns.values)]
434448
if len(to_drop):
435449
dummies = dummies.drop(columns=to_drop)
450+
451+
if prefix is not None:
452+
pref = prefix + (prefix_sep or "")
453+
name_map = dict()
454+
to_keep = []
455+
for c in dummies.columns:
456+
if isinstance(c, str) and c.startswith(pref):
457+
to_keep.append(c)
458+
name_map[c] = c[len(pref) :]
459+
dummies = dummies[to_keep].rename(columns=name_map)
460+
436461
df = dummies.astype("boolean")
437462

438463
multicat_rows = df.sum(axis=1, skipna=False) > 1

0 commit comments

Comments
 (0)