@@ -380,6 +380,7 @@ def from_dummies(
380
380
ordered : Optional [bool ] = None ,
381
381
prefix = None ,
382
382
prefix_sep = "_" ,
383
+ fillna = None ,
383
384
) -> "Categorical" :
384
385
"""Create a `Categorical` using a ``DataFrame`` of dummy variables.
385
386
@@ -405,6 +406,9 @@ def from_dummies(
405
406
prefix_sep : str, default "_"
406
407
If ``prefix`` is not ``None``, use as the separator
407
408
between the prefix and the final name of the category.
409
+ fillna : bool, optional, default None
410
+ How to handle NA values. If ``True`` or ``False``, NA is filled with that value.
411
+ If ``None`` (the default), a ``ValueError`` is raised when any NA values are present.
408
412
409
413
Raises
410
414
------
@@ -444,37 +448,48 @@ def from_dummies(
444
448
...
445
449
ValueError: 1 record(s) belongs to multiple categories: [0]
446
450
"""
451
+ from pandas import Series
452
+
453
+ copied = False
447
454
to_drop = dummies .columns [isna (dummies .columns .values )]
448
455
if len (to_drop ):
449
456
dummies = dummies .drop (columns = to_drop )
457
+ copied = True
450
458
451
- if prefix is not None :
459
+ if prefix is None :
460
+ cats = dummies .columns
461
+ else :
452
462
pref = prefix + (prefix_sep or "" )
453
- name_map = dict ()
463
+ cats = []
454
464
to_keep = []
455
465
for c in dummies .columns :
456
466
if isinstance (c , str ) and c .startswith (pref ):
457
467
to_keep .append (c )
458
- name_map [ c ] = c [len (pref ) :]
459
- dummies = dummies [to_keep ]. rename ( columns = name_map )
468
+ cats . append ( c [len (pref ) :])
469
+ dummies = dummies [to_keep ]
460
470
461
471
df = dummies .astype ("boolean" )
472
+ if fillna is not None :
473
+ df = df .fillna (fillna , inplace = copied )
462
474
463
- multicat_rows = df .sum (axis = 1 , skipna = False ) > 1
475
+ row_totals = df .sum (axis = 1 , skipna = False )
476
+ if row_totals .isna ().any ():
477
+ raise ValueError ("Unhandled NA values in dummy array" )
478
+
479
+ multicat_rows = row_totals > 1
464
480
if multicat_rows .any ():
465
481
raise ValueError (
466
482
"{} record(s) belongs to multiple categories: {}" .format (
467
483
multicat_rows .sum (), list (df .index [multicat_rows ]),
468
484
)
469
485
)
470
486
471
- mult_by = np .arange (df .shape [1 ]) + 1
472
- # 000 000 0 -1
473
- # 010 020 2 1
474
- # 001 * 1,2,3 => 003 -> 3 -> 2 = correct codes
475
- # 100 100 1 0
476
- codes = ((df * mult_by ).sum (axis = 1 , skipna = False ) - 1 ).astype ("Int64" )
477
- return cls .from_codes (codes .fillna (- 1 ), df .columns .values , ordered = ordered )
487
+ codes = Series (np .full (len (row_totals ), np .nan ), index = df .index , dtype = "Int64" )
488
+ codes [row_totals == 0 ] = - 1
489
+ row_idx , code = np .nonzero (df )
490
+ codes [row_idx ] = code
491
+
492
+ return cls .from_codes (codes .fillna (- 1 ), cats , ordered = ordered )
478
493
479
494
def get_dummies (
480
495
self ,
0 commit comments