Skip to content

ENH: Add dtype argument to StringMethods get_dummies() #59577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e6f9527
Add prefix, prefix_sep, dummy_na, and dtype args to StringMethods get…
aaronchucarroll Aug 21, 2024
dafb61d
Fix import issue
aaronchucarroll Aug 21, 2024
bb79ef2
Fix typing of dtype
aaronchucarroll Aug 21, 2024
24be84f
Fix NaN type issue
aaronchucarroll Aug 21, 2024
09b2fad
Support categorical string backend
aaronchucarroll Aug 21, 2024
50ed90c
Fix dtype type hints
aaronchucarroll Aug 21, 2024
9e95485
Add dtype to get_dummies docstring
aaronchucarroll Aug 21, 2024
9a47768
Fix get_dummies dtype docstring
aaronchucarroll Aug 21, 2024
0c94bff
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Aug 22, 2024
9702bf7
remove changes for unnecessary args
aaronchucarroll Sep 3, 2024
8793516
Merge branch 'stringmethods-get-dummies' of https://github.com/aaronc…
aaronchucarroll Sep 3, 2024
bad1038
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Sep 3, 2024
163fe09
parametrize dtype tests
aaronchucarroll Sep 5, 2024
3d75fdc
Merge branch 'stringmethods-get-dummies' of https://github.com/aaronc…
aaronchucarroll Sep 5, 2024
d68bece
support pyarrow and nullable dtypes
aaronchucarroll Sep 5, 2024
c2aa7d5
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Sep 5, 2024
0fd2401
fix pyarrow import error
aaronchucarroll Sep 5, 2024
920c865
skip pyarrow tests when not present
aaronchucarroll Sep 5, 2024
800f787
split pyarrow tests
aaronchucarroll Sep 5, 2024
d8149e6
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Sep 5, 2024
6cbc3e8
parametrize pyarrow tests
aaronchucarroll Sep 7, 2024
532e139
change var name to dummies_dtype
aaronchucarroll Sep 7, 2024
cd5c2ab
fix string issue
aaronchucarroll Sep 7, 2024
822b3f4
consolidate conditionals
aaronchucarroll Sep 7, 2024
ba05a8d
add tests for str and pyarrow strings
aaronchucarroll Sep 7, 2024
37dddb8
skip pyarrow string tests if not present
aaronchucarroll Sep 7, 2024
6fbe183
add info to whatsnew doc
aaronchucarroll Sep 9, 2024
87a1ee8
change func to meth in doc info
aaronchucarroll Sep 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
is_list_like,
is_numeric_dtype,
is_scalar,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.missing import isna
Expand Down Expand Up @@ -2513,7 +2514,9 @@ def _str_findall(self, pat: str, flags: int = 0) -> Self:
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
if dtype is None:
dtype = np.bool_
split = pc.split_pattern(self._pa_array, sep)
flattened_values = pc.list_flatten(split)
uniques = flattened_values.unique()
Expand All @@ -2523,7 +2526,15 @@ def _str_get_dummies(self, sep: str = "|"):
n_cols = len(uniques)
indices = pc.index_in(flattened_values, uniques_sorted).to_numpy()
indices = indices + np.arange(n_rows).repeat(lengths) * n_cols
dummies = np.zeros(n_rows * n_cols, dtype=np.bool_)
_dtype = pandas_dtype(dtype)
dummies_dtype: NpDtype
if isinstance(_dtype, np.dtype):
dummies_dtype = _dtype
else:
dummies_dtype = np.bool_
dummies = np.zeros(n_rows * n_cols, dtype=dummies_dtype)
if dtype == str:
dummies[:] = False
Comment on lines +2536 to +2537
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you just put this logic before the dummies creation i.e.

if dtype == str:
    dummies_dtype = np.bool_
dummies = ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The string types do not need to use a dummy dtype, they can handle boolean values. The issue is with the str type interaction with np.zeroes(), where it considers the zero value to be the empty string.

dummies[indices] = True
dummies = dummies.reshape((n_rows, n_cols))
result = type(self)(pa.array(list(dummies)))
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2681,11 +2681,11 @@ def _str_map(
result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype)
return take_nd(result, codes, fill_value=na_value)

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
# sep may not be in categories. Just bail on this.
from pandas.core.arrays import NumpyExtensionArray

return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep)
return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep, dtype)

# ------------------------------------------------------------------------
# GroupBy Methods
Expand Down
19 changes: 15 additions & 4 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
ArrayLike,
AxisInt,
Dtype,
NpDtype,
Scalar,
Self,
npt,
Expand Down Expand Up @@ -435,12 +436,22 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
return super()._str_find(sub, start, end)
return self._convert_int_result(result)

def _str_get_dummies(self, sep: str = "|"):
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
if dtype is None:
dtype = np.int64
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(
sep, dtype
)
if len(labels) == 0:
return np.empty(shape=(0, 0), dtype=np.int64), labels
return np.empty(shape=(0, 0), dtype=dtype), labels
dummies = np.vstack(dummies_pa.to_numpy())
return dummies.astype(np.int64, copy=False), labels
_dtype = pandas_dtype(dtype)
dummies_dtype: NpDtype
if isinstance(_dtype, np.dtype):
dummies_dtype = _dtype
else:
dummies_dtype = np.bool_
return dummies.astype(dummies_dtype, copy=False), labels

def _convert_int_result(self, result):
if self.dtype.na_value is np.nan:
Expand Down
27 changes: 25 additions & 2 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pandas.core.dtypes.common import (
ensure_object,
is_bool_dtype,
is_extension_array_dtype,
is_integer,
is_list_like,
is_object_dtype,
Expand Down Expand Up @@ -54,6 +55,8 @@
Iterator,
)

from pandas._typing import NpDtype

from pandas import (
DataFrame,
Index,
Expand Down Expand Up @@ -2431,7 +2434,11 @@ def wrap(
return self._wrap_result(result)

@forbid_nonstring_types(["bytes"])
def get_dummies(self, sep: str = "|"):
def get_dummies(
self,
sep: str = "|",
dtype: NpDtype | None = None,
):
"""
Return DataFrame of dummy/indicator variables for Series.

Expand All @@ -2442,6 +2449,8 @@ def get_dummies(self, sep: str = "|"):
----------
sep : str, default "|"
String to split on.
dtype : dtype, default np.int64
Data type for new columns. Only a single dtype is allowed.

Returns
-------
Expand All @@ -2466,10 +2475,24 @@ def get_dummies(self, sep: str = "|"):
0 1 1 0
1 0 0 0
2 1 0 1

>>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(dtype=bool)
a b c
0 True True False
1 False False False
2 True False True
"""
from pandas.core.frame import DataFrame

# we need to cast to Series of strings as only that has all
# methods available for making the dummies...
result, name = self._data.array._str_get_dummies(sep)
result, name = self._data.array._str_get_dummies(sep, dtype)
if is_extension_array_dtype(dtype) or isinstance(dtype, ArrowDtype):
return self._wrap_result(
DataFrame(result, columns=name, dtype=dtype),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can use _wrap_result(result, name=name, dtype=dtype, expand=True) here instead to avoid the DataFrame import

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making this change causes failures because the numpy.ndarray does not take non-numpy dtypes. It doesn't seem like _wrap_result handles this case.

name=name,
returns_string=False,
)
return self._wrap_result(
result,
name=name,
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/strings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re

from pandas._typing import (
NpDtype,
Scalar,
Self,
)
Expand Down Expand Up @@ -163,7 +164,7 @@ def _str_wrap(self, width: int, **kwargs):
pass

@abc.abstractmethod
def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
pass

@abc.abstractmethod
Expand Down
13 changes: 11 additions & 2 deletions pandas/core/strings/object_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pandas._libs.ops as libops
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.common import pandas_dtype
from pandas.core.dtypes.missing import isna

from pandas.core.strings.base import BaseStringArrayMethods
Expand Down Expand Up @@ -398,9 +399,11 @@ def _str_wrap(self, width: int, **kwargs):
tw = textwrap.TextWrapper(**kwargs)
return self._str_map(lambda s: "\n".join(tw.wrap(s)))

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
from pandas import Series

if dtype is None:
dtype = np.int64
arr = Series(self).fillna("")
try:
arr = sep + arr + sep
Expand All @@ -412,7 +415,13 @@ def _str_get_dummies(self, sep: str = "|"):
tags.update(ts)
tags2 = sorted(tags - {""})

dummies = np.empty((len(arr), len(tags2)), dtype=np.int64)
_dtype = pandas_dtype(dtype)
dummies_dtype: NpDtype
if isinstance(_dtype, np.dtype):
dummies_dtype = _dtype
else:
dummies_dtype = np.bool_
dummies = np.empty((len(arr), len(tags2)), dtype=dummies_dtype)

def _isin(test_elements: str, element: str) -> bool:
return element in test_elements
Expand Down
99 changes: 85 additions & 14 deletions pandas/tests/strings/test_get_dummies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

from pandas import (
DataFrame,
Expand All @@ -8,6 +11,11 @@
_testing as tm,
)

try:
import pyarrow as pa
except ImportError:
pa = None


def test_get_dummies(any_string_dtype):
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
Expand All @@ -32,22 +40,85 @@ def test_get_dummies_index():
tm.assert_index_equal(result, expected)


def test_get_dummies_with_name_dummy(any_string_dtype):
# GH 12180
# Dummies named 'name' should work as expected
s = Series(["a", "b,name", "b"], dtype=any_string_dtype)
result = s.str.get_dummies(",")
expected = DataFrame([[1, 0, 0], [0, 1, 1], [0, 1, 0]], columns=["a", "b", "name"])
# GH#47872
@pytest.mark.parametrize(
"dtype",
[
np.uint8,
np.int16,
np.uint16,
np.int32,
np.uint32,
np.int64,
np.uint64,
bool,
"Int8",
"Int16",
"Int32",
"Int64",
"boolean",
],
)
def test_get_dummies_with_dtype(any_string_dtype, dtype):
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
result = s.str.get_dummies("|", dtype=dtype)
expected = DataFrame(
[[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc"), dtype=dtype
)
tm.assert_frame_equal(result, expected)


def test_get_dummies_with_name_dummy_index():
# GH 12180
# Dummies named 'name' should work as expected
idx = Index(["a|b", "name|c", "b|name"])
result = idx.str.get_dummies("|")
# GH#47872
@td.skip_if_no("pyarrow")
@pytest.mark.parametrize(
"dtype",
[
"int8[pyarrow]",
"uint8[pyarrow]",
"int16[pyarrow]",
"uint16[pyarrow]",
"int32[pyarrow]",
"uint32[pyarrow]",
"int64[pyarrow]",
"uint64[pyarrow]",
"bool[pyarrow]",
],
)
def test_get_dummies_with_pyarrow_dtype(any_string_dtype, dtype):
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
result = s.str.get_dummies("|", dtype=dtype)
expected = DataFrame(
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
columns=list("abc"),
dtype=dtype,
)
tm.assert_frame_equal(result, expected)

expected = MultiIndex.from_tuples(
[(1, 1, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1)], names=("a", "b", "c", "name")

# GH#47872
def test_get_dummies_with_str_dtype(any_string_dtype):
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
result = s.str.get_dummies("|", dtype=str)
expected = DataFrame(
[["T", "T", "F"], ["T", "F", "T"], ["F", "F", "F"]],
columns=list("abc"),
dtype=str,
)
tm.assert_index_equal(result, expected)
tm.assert_frame_equal(result, expected)


# GH#47872
@td.skip_if_no("pyarrow")
def test_get_dummies_with_pa_str_dtype(any_string_dtype):
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
result = s.str.get_dummies("|", dtype="str[pyarrow]")
expected = DataFrame(
[
["true", "true", "false"],
["true", "false", "true"],
["false", "false", "false"],
],
columns=list("abc"),
dtype="str[pyarrow]",
)
tm.assert_frame_equal(result, expected)