Skip to content

ENH: Add dtype argument to StringMethods get_dummies() #59577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e6f9527
Add prefix, prefix_sep, dummy_na, and dtype args to StringMethods get…
aaronchucarroll Aug 21, 2024
dafb61d
Fix import issue
aaronchucarroll Aug 21, 2024
bb79ef2
Fix typing of dtype
aaronchucarroll Aug 21, 2024
24be84f
Fix NaN type issue
aaronchucarroll Aug 21, 2024
09b2fad
Support categorical string backend
aaronchucarroll Aug 21, 2024
50ed90c
Fix dtype type hints
aaronchucarroll Aug 21, 2024
9e95485
Add dtype to get_dummies docstring
aaronchucarroll Aug 21, 2024
9a47768
Fix get_dummies dtype docstring
aaronchucarroll Aug 21, 2024
0c94bff
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Aug 22, 2024
9702bf7
remove changes for unnecessary args
aaronchucarroll Sep 3, 2024
8793516
Merge branch 'stringmethods-get-dummies' of https://github.com/aaronc…
aaronchucarroll Sep 3, 2024
bad1038
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Sep 3, 2024
163fe09
parametrize dtype tests
aaronchucarroll Sep 5, 2024
3d75fdc
Merge branch 'stringmethods-get-dummies' of https://github.com/aaronc…
aaronchucarroll Sep 5, 2024
d68bece
support pyarrow and nullable dtypes
aaronchucarroll Sep 5, 2024
c2aa7d5
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Sep 5, 2024
0fd2401
fix pyarrow import error
aaronchucarroll Sep 5, 2024
920c865
skip pyarrow tests when not present
aaronchucarroll Sep 5, 2024
800f787
split pyarrow tests
aaronchucarroll Sep 5, 2024
d8149e6
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Sep 5, 2024
6cbc3e8
parametrize pyarrow tests
aaronchucarroll Sep 7, 2024
532e139
change var name to dummies_dtype
aaronchucarroll Sep 7, 2024
cd5c2ab
fix string issue
aaronchucarroll Sep 7, 2024
822b3f4
consolidate conditionals
aaronchucarroll Sep 7, 2024
ba05a8d
add tests for str and pyarrow strings
aaronchucarroll Sep 7, 2024
37dddb8
skip pyarrow string tests if not present
aaronchucarroll Sep 7, 2024
6fbe183
add info to whatsnew doc
aaronchucarroll Sep 9, 2024
87a1ee8
change func to meth in doc info
aaronchucarroll Sep 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2540,7 +2540,9 @@ def _str_findall(self, pat: str, flags: int = 0) -> Self:
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
if dtype is None:
dtype = np.bool_
split = pc.split_pattern(self._pa_array, sep)
flattened_values = pc.list_flatten(split)
uniques = flattened_values.unique()
Expand All @@ -2550,7 +2552,7 @@ def _str_get_dummies(self, sep: str = "|"):
n_cols = len(uniques)
indices = pc.index_in(flattened_values, uniques_sorted).to_numpy()
indices = indices + np.arange(n_rows).repeat(lengths) * n_cols
dummies = np.zeros(n_rows * n_cols, dtype=np.bool_)
dummies = np.zeros(n_rows * n_cols, dtype=dtype)
dummies[indices] = True
dummies = dummies.reshape((n_rows, n_cols))
result = type(self)(pa.array(list(dummies)))
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2681,11 +2681,11 @@ def _str_map(
result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype)
return take_nd(result, codes, fill_value=na_value)

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
# sep may not be in categories. Just bail on this.
from pandas.core.arrays import NumpyExtensionArray

return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep)
return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep, dtype)

# ------------------------------------------------------------------------
# GroupBy Methods
Expand Down
13 changes: 9 additions & 4 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
ArrayLike,
AxisInt,
Dtype,
NpDtype,
Scalar,
Self,
npt,
Expand Down Expand Up @@ -461,12 +462,16 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
return super()._str_find(sub, start, end)
return self._convert_int_result(result)

def _str_get_dummies(self, sep: str = "|"):
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
if dtype is None:
dtype = np.int64
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(
sep, dtype
)
if len(labels) == 0:
return np.empty(shape=(0, 0), dtype=np.int64), labels
return np.empty(shape=(0, 0), dtype=dtype), labels
dummies = np.vstack(dummies_pa.to_numpy())
return dummies.astype(np.int64, copy=False), labels
return dummies.astype(dtype, copy=False), labels

def _convert_int_result(self, result):
if self.dtype.na_value is np.nan:
Expand Down
18 changes: 16 additions & 2 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
Iterator,
)

from pandas._typing import NpDtype

from pandas import (
DataFrame,
Index,
Expand Down Expand Up @@ -2431,7 +2433,11 @@ def wrap(
return self._wrap_result(result)

@forbid_nonstring_types(["bytes"])
def get_dummies(self, sep: str = "|"):
def get_dummies(
self,
sep: str = "|",
dtype: NpDtype | None = None,
):
"""
Return DataFrame of dummy/indicator variables for Series.

Expand All @@ -2442,6 +2448,8 @@ def get_dummies(self, sep: str = "|"):
----------
sep : str, default "|"
String to split on.
dtype : dtype, default np.int64
Data type for new columns. Only a single dtype is allowed.

Returns
-------
Expand All @@ -2466,10 +2474,16 @@ def get_dummies(self, sep: str = "|"):
0 1 1 0
1 0 0 0
2 1 0 1

>>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(dtype=bool)
a b c
0 True True False
1 False False False
2 True False True
"""
# we need to cast to Series of strings as only that has all
# methods available for making the dummies...
result, name = self._data.array._str_get_dummies(sep)
result, name = self._data.array._str_get_dummies(sep, dtype)
return self._wrap_result(
result,
name=name,
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/strings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re

from pandas._typing import (
NpDtype,
Scalar,
Self,
)
Expand Down Expand Up @@ -163,7 +164,7 @@ def _str_wrap(self, width: int, **kwargs):
pass

@abc.abstractmethod
def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
pass

@abc.abstractmethod
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/strings/object_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,11 @@ def _str_wrap(self, width: int, **kwargs):
tw = textwrap.TextWrapper(**kwargs)
return self._str_map(lambda s: "\n".join(tw.wrap(s)))

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
from pandas import Series

if dtype is None:
dtype = np.int64
arr = Series(self).fillna("")
try:
arr = sep + arr + sep
Expand All @@ -412,7 +414,7 @@ def _str_get_dummies(self, sep: str = "|"):
tags.update(ts)
tags2 = sorted(tags - {""})

dummies = np.empty((len(arr), len(tags2)), dtype=np.int64)
dummies = np.empty((len(arr), len(tags2)), dtype=dtype)

def _isin(test_elements: str, element: str) -> bool:
return element in test_elements
Expand Down
100 changes: 86 additions & 14 deletions pandas/tests/strings/test_get_dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,94 @@ def test_get_dummies_index():
tm.assert_index_equal(result, expected)


def test_get_dummies_with_name_dummy(any_string_dtype):
# GH 12180
# Dummies named 'name' should work as expected
s = Series(["a", "b,name", "b"], dtype=any_string_dtype)
result = s.str.get_dummies(",")
expected = DataFrame([[1, 0, 0], [0, 1, 1], [0, 1, 0]], columns=["a", "b", "name"])
def test_get_dummies_int8_dtype():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you parametrize these tests with dtype.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition, can you add dtype=str, PyArrow, and nullable dtypes (e.g. "Int64"). Specifying PyArrow and nullable dtypes currently fails:

ser = pd.Series(["a|b", "a", "a|c"], dtype="string[pyarrow]")
ser.str.get_dummies(dtype=pd.ArrowDtype(pa.int64()))

but is successful with pd.get_dummies

pd.get_dummies(ser, dtype=pd.ArrowDtype(pa.int64()))

I think this will need to be fixed. You may find it necessary to have multiple tests - perhaps one for NumPy (which are already present), one for str, one for PyArrow etc. But just try to consolidate with pytest.parametrize as much as is reasonable.

s = Series(["1|2", "1|3", np.nan], dtype="string")
result = s.str.get_dummies("|", dtype=np.int8)
expected = DataFrame(
[[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("123"), dtype=np.int8
)
tm.assert_frame_equal(result, expected)
assert (result.dtypes == np.int8).all()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to assert this, it is covered by assert_frame_equal.



def test_get_dummies_with_name_dummy_index():
# GH 12180
# Dummies named 'name' should work as expected
idx = Index(["a|b", "name|c", "b|name"])
result = idx.str.get_dummies("|")
def test_get_dummies_uint8_dtype():
s = Series(["a|b", "a|c", np.nan], dtype="string")
result = s.str.get_dummies("|", dtype=np.uint8)
expected = DataFrame(
[[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc"), dtype=np.uint8
)
tm.assert_frame_equal(result, expected)
assert (result.dtypes == np.uint8).all()

expected = MultiIndex.from_tuples(
[(1, 1, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1)], names=("a", "b", "c", "name")

def test_get_dummies_int16_dtype():
s = Series(["a|b", "a|c", np.nan], dtype="string")
result = s.str.get_dummies("|", dtype=np.int16)
expected = DataFrame(
[[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc"), dtype=np.int16
)
tm.assert_index_equal(result, expected)
tm.assert_frame_equal(result, expected)
assert (result.dtypes == np.int16).all()


def test_get_dummies_uint16_dtype():
s = Series(["a|b", "a|c", np.nan], dtype="string")
result = s.str.get_dummies("|", dtype=np.uint16)
expected = DataFrame(
[[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc"), dtype=np.uint16
)
tm.assert_frame_equal(result, expected)
assert (result.dtypes == np.uint16).all()


def test_get_dummies_int32_dtype():
s = Series(["x|y", "x|z", np.nan], dtype="string")
result = s.str.get_dummies("|", dtype=np.int32)
expected = DataFrame(
[[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("xyz"), dtype=np.int32
)
tm.assert_frame_equal(result, expected)
assert (result.dtypes == np.int32).all()


def test_get_dummies_uint32_dtype():
s = Series(["x|y", "x|z", np.nan], dtype="string")
result = s.str.get_dummies("|", dtype=np.uint32)
expected = DataFrame(
[[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("xyz"), dtype=np.uint32
)
tm.assert_frame_equal(result, expected)
assert (result.dtypes == np.uint32).all()


def test_get_dummies_int64_dtype():
s = Series(["foo|bar", "foo|baz", np.nan], dtype="string")
result = s.str.get_dummies("|", dtype=np.int64)
expected = DataFrame(
[[1, 0, 1], [0, 1, 1], [0, 0, 0]], columns=["bar", "baz", "foo"], dtype=np.int64
)
tm.assert_frame_equal(result, expected)
assert (result.dtypes == np.int64).all()


def test_get_dummies_uint64_dtype():
s = Series(["foo|bar", "foo|baz", np.nan], dtype="string")
result = s.str.get_dummies("|", dtype=np.uint64)
expected = DataFrame(
[[1, 0, 1], [0, 1, 1], [0, 0, 0]],
columns=["bar", "baz", "foo"],
dtype=np.uint64,
)
tm.assert_frame_equal(result, expected)
assert (result.dtypes == np.uint64).all()


def test_get_dummies_bool_dtype():
s = Series(["a|b", "a|c", np.nan], dtype="string")
result = s.str.get_dummies("|", dtype=bool)
expected = DataFrame(
[[True, True, False], [True, False, True], [False, False, False]],
columns=["a", "b", "c"],
)
tm.assert_frame_equal(result, expected)
assert (result.dtypes == bool).all()