Skip to content

Commit 715585d

Browse files
ENH: Add dtype argument to StringMethods get_dummies() (#59577)
1 parent 83fd9ba commit 715585d

File tree

8 files changed

+154
-27
lines changed

8 files changed

+154
-27
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ Other enhancements
5555
- :meth:`Series.plot` now correctly handles the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
5656
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
5757
- :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`)
58+
- :meth:`Series.str.get_dummies` now accepts a ``dtype`` parameter to specify the dtype of the resulting DataFrame (:issue:`47872`)
5859
- Multiplying two :class:`DateOffset` objects will now raise a ``TypeError`` instead of a ``RecursionError`` (:issue:`59442`)
5960
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
6061
- Support passing an :class:`Iterable[Hashable]` input to :meth:`DataFrame.drop_duplicates` (:issue:`59237`)

pandas/core/arrays/arrow/array.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
is_list_like,
4242
is_numeric_dtype,
4343
is_scalar,
44+
pandas_dtype,
4445
)
4546
from pandas.core.dtypes.dtypes import DatetimeTZDtype
4647
from pandas.core.dtypes.missing import isna
@@ -2475,7 +2476,9 @@ def _str_findall(self, pat: str, flags: int = 0) -> Self:
24752476
result = self._apply_elementwise(predicate)
24762477
return type(self)(pa.chunked_array(result))
24772478

2478-
def _str_get_dummies(self, sep: str = "|"):
2479+
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
2480+
if dtype is None:
2481+
dtype = np.bool_
24792482
split = pc.split_pattern(self._pa_array, sep)
24802483
flattened_values = pc.list_flatten(split)
24812484
uniques = flattened_values.unique()
@@ -2485,7 +2488,15 @@ def _str_get_dummies(self, sep: str = "|"):
24852488
n_cols = len(uniques)
24862489
indices = pc.index_in(flattened_values, uniques_sorted).to_numpy()
24872490
indices = indices + np.arange(n_rows).repeat(lengths) * n_cols
2488-
dummies = np.zeros(n_rows * n_cols, dtype=np.bool_)
2491+
_dtype = pandas_dtype(dtype)
2492+
dummies_dtype: NpDtype
2493+
if isinstance(_dtype, np.dtype):
2494+
dummies_dtype = _dtype
2495+
else:
2496+
dummies_dtype = np.bool_
2497+
dummies = np.zeros(n_rows * n_cols, dtype=dummies_dtype)
2498+
if dtype == str:
2499+
dummies[:] = False
24892500
dummies[indices] = True
24902501
dummies = dummies.reshape((n_rows, n_cols))
24912502
result = type(self)(pa.array(list(dummies)))

pandas/core/arrays/categorical.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2681,11 +2681,11 @@ def _str_map(
26812681
result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype)
26822682
return take_nd(result, codes, fill_value=na_value)
26832683

2684-
def _str_get_dummies(self, sep: str = "|"):
2684+
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
26852685
# sep may not be in categories. Just bail on this.
26862686
from pandas.core.arrays import NumpyExtensionArray
26872687

2688-
return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep)
2688+
return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep, dtype)
26892689

26902690
# ------------------------------------------------------------------------
26912691
# GroupBy Methods

pandas/core/arrays/string_arrow.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
ArrayLike,
5757
AxisInt,
5858
Dtype,
59+
NpDtype,
5960
Scalar,
6061
Self,
6162
npt,
@@ -425,12 +426,22 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
425426
return super()._str_find(sub, start, end)
426427
return ArrowStringArrayMixin._str_find(self, sub, start, end)
427428

428-
def _str_get_dummies(self, sep: str = "|"):
429-
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
429+
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
430+
if dtype is None:
431+
dtype = np.int64
432+
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(
433+
sep, dtype
434+
)
430435
if len(labels) == 0:
431-
return np.empty(shape=(0, 0), dtype=np.int64), labels
436+
return np.empty(shape=(0, 0), dtype=dtype), labels
432437
dummies = np.vstack(dummies_pa.to_numpy())
433-
return dummies.astype(np.int64, copy=False), labels
438+
_dtype = pandas_dtype(dtype)
439+
dummies_dtype: NpDtype
440+
if isinstance(_dtype, np.dtype):
441+
dummies_dtype = _dtype
442+
else:
443+
dummies_dtype = np.bool_
444+
return dummies.astype(dummies_dtype, copy=False), labels
434445

435446
def _convert_int_result(self, result):
436447
if self.dtype.na_value is np.nan:

pandas/core/strings/accessor.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pandas.core.dtypes.common import (
2727
ensure_object,
2828
is_bool_dtype,
29+
is_extension_array_dtype,
2930
is_integer,
3031
is_list_like,
3132
is_object_dtype,
@@ -54,6 +55,8 @@
5455
Iterator,
5556
)
5657

58+
from pandas._typing import NpDtype
59+
5760
from pandas import (
5861
DataFrame,
5962
Index,
@@ -2431,7 +2434,11 @@ def wrap(
24312434
return self._wrap_result(result)
24322435

24332436
@forbid_nonstring_types(["bytes"])
2434-
def get_dummies(self, sep: str = "|"):
2437+
def get_dummies(
2438+
self,
2439+
sep: str = "|",
2440+
dtype: NpDtype | None = None,
2441+
):
24352442
"""
24362443
Return DataFrame of dummy/indicator variables for Series.
24372444
@@ -2442,6 +2449,8 @@ def get_dummies(self, sep: str = "|"):
24422449
----------
24432450
sep : str, default "|"
24442451
String to split on.
2452+
dtype : dtype, default np.int64
2453+
Data type for new columns. Only a single dtype is allowed.
24452454
24462455
Returns
24472456
-------
@@ -2466,10 +2475,24 @@ def get_dummies(self, sep: str = "|"):
24662475
0 1 1 0
24672476
1 0 0 0
24682477
2 1 0 1
2478+
2479+
>>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(dtype=bool)
2480+
a b c
2481+
0 True True False
2482+
1 False False False
2483+
2 True False True
24692484
"""
2485+
from pandas.core.frame import DataFrame
2486+
24702487
# we need to cast to Series of strings as only that has all
24712488
# methods available for making the dummies...
2472-
result, name = self._data.array._str_get_dummies(sep)
2489+
result, name = self._data.array._str_get_dummies(sep, dtype)
2490+
if is_extension_array_dtype(dtype) or isinstance(dtype, ArrowDtype):
2491+
return self._wrap_result(
2492+
DataFrame(result, columns=name, dtype=dtype),
2493+
name=name,
2494+
returns_string=False,
2495+
)
24732496
return self._wrap_result(
24742497
result,
24752498
name=name,

pandas/core/strings/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import re
1717

1818
from pandas._typing import (
19+
NpDtype,
1920
Scalar,
2021
Self,
2122
)
@@ -163,7 +164,7 @@ def _str_wrap(self, width: int, **kwargs):
163164
pass
164165

165166
@abc.abstractmethod
166-
def _str_get_dummies(self, sep: str = "|"):
167+
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
167168
pass
168169

169170
@abc.abstractmethod

pandas/core/strings/object_array.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pandas._libs.ops as libops
1919
from pandas.util._exceptions import find_stack_level
2020

21+
from pandas.core.dtypes.common import pandas_dtype
2122
from pandas.core.dtypes.missing import isna
2223

2324
from pandas.core.strings.base import BaseStringArrayMethods
@@ -398,9 +399,11 @@ def _str_wrap(self, width: int, **kwargs):
398399
tw = textwrap.TextWrapper(**kwargs)
399400
return self._str_map(lambda s: "\n".join(tw.wrap(s)))
400401

401-
def _str_get_dummies(self, sep: str = "|"):
402+
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
402403
from pandas import Series
403404

405+
if dtype is None:
406+
dtype = np.int64
404407
arr = Series(self).fillna("")
405408
try:
406409
arr = sep + arr + sep
@@ -412,7 +415,13 @@ def _str_get_dummies(self, sep: str = "|"):
412415
tags.update(ts)
413416
tags2 = sorted(tags - {""})
414417

415-
dummies = np.empty((len(arr), len(tags2)), dtype=np.int64)
418+
_dtype = pandas_dtype(dtype)
419+
dummies_dtype: NpDtype
420+
if isinstance(_dtype, np.dtype):
421+
dummies_dtype = _dtype
422+
else:
423+
dummies_dtype = np.bool_
424+
dummies = np.empty((len(arr), len(tags2)), dtype=dummies_dtype)
416425

417426
def _isin(test_elements: str, element: str) -> bool:
418427
return element in test_elements

pandas/tests/strings/test_get_dummies.py

+85-14
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import numpy as np
2+
import pytest
3+
4+
import pandas.util._test_decorators as td
25

36
from pandas import (
47
DataFrame,
@@ -8,6 +11,11 @@
811
_testing as tm,
912
)
1013

14+
try:
15+
import pyarrow as pa
16+
except ImportError:
17+
pa = None
18+
1119

1220
def test_get_dummies(any_string_dtype):
1321
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
@@ -32,22 +40,85 @@ def test_get_dummies_index():
3240
tm.assert_index_equal(result, expected)
3341

3442

35-
def test_get_dummies_with_name_dummy(any_string_dtype):
36-
# GH 12180
37-
# Dummies named 'name' should work as expected
38-
s = Series(["a", "b,name", "b"], dtype=any_string_dtype)
39-
result = s.str.get_dummies(",")
40-
expected = DataFrame([[1, 0, 0], [0, 1, 1], [0, 1, 0]], columns=["a", "b", "name"])
43+
# GH#47872
44+
@pytest.mark.parametrize(
45+
"dtype",
46+
[
47+
np.uint8,
48+
np.int16,
49+
np.uint16,
50+
np.int32,
51+
np.uint32,
52+
np.int64,
53+
np.uint64,
54+
bool,
55+
"Int8",
56+
"Int16",
57+
"Int32",
58+
"Int64",
59+
"boolean",
60+
],
61+
)
62+
def test_get_dummies_with_dtype(any_string_dtype, dtype):
63+
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
64+
result = s.str.get_dummies("|", dtype=dtype)
65+
expected = DataFrame(
66+
[[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc"), dtype=dtype
67+
)
4168
tm.assert_frame_equal(result, expected)
4269

4370

44-
def test_get_dummies_with_name_dummy_index():
45-
# GH 12180
46-
# Dummies named 'name' should work as expected
47-
idx = Index(["a|b", "name|c", "b|name"])
48-
result = idx.str.get_dummies("|")
71+
# GH#47872
72+
@td.skip_if_no("pyarrow")
73+
@pytest.mark.parametrize(
74+
"dtype",
75+
[
76+
"int8[pyarrow]",
77+
"uint8[pyarrow]",
78+
"int16[pyarrow]",
79+
"uint16[pyarrow]",
80+
"int32[pyarrow]",
81+
"uint32[pyarrow]",
82+
"int64[pyarrow]",
83+
"uint64[pyarrow]",
84+
"bool[pyarrow]",
85+
],
86+
)
87+
def test_get_dummies_with_pyarrow_dtype(any_string_dtype, dtype):
88+
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
89+
result = s.str.get_dummies("|", dtype=dtype)
90+
expected = DataFrame(
91+
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
92+
columns=list("abc"),
93+
dtype=dtype,
94+
)
95+
tm.assert_frame_equal(result, expected)
4996

50-
expected = MultiIndex.from_tuples(
51-
[(1, 1, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1)], names=("a", "b", "c", "name")
97+
98+
# GH#47872
99+
def test_get_dummies_with_str_dtype(any_string_dtype):
100+
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
101+
result = s.str.get_dummies("|", dtype=str)
102+
expected = DataFrame(
103+
[["T", "T", "F"], ["T", "F", "T"], ["F", "F", "F"]],
104+
columns=list("abc"),
105+
dtype=str,
52106
)
53-
tm.assert_index_equal(result, expected)
107+
tm.assert_frame_equal(result, expected)
108+
109+
110+
# GH#47872
111+
@td.skip_if_no("pyarrow")
112+
def test_get_dummies_with_pa_str_dtype(any_string_dtype):
113+
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
114+
result = s.str.get_dummies("|", dtype="str[pyarrow]")
115+
expected = DataFrame(
116+
[
117+
["true", "true", "false"],
118+
["true", "false", "true"],
119+
["false", "false", "false"],
120+
],
121+
columns=list("abc"),
122+
dtype="str[pyarrow]",
123+
)
124+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)