From 9ac286ac0cf893164b24e5dea432ed378651afce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?gyro=E6=B0=B8=E4=B8=8D=E6=8A=BD=E9=A3=8E?= <1247006353@qq.com> Date: Thu, 4 Aug 2022 15:01:55 +0800 Subject: [PATCH] ENH: Add `dtype` param for `pandas.Series.str.get_dummies` --- pandas/core/strings/accessor.py | 6 ++++-- pandas/core/strings/object_array.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index d50daad9a22b1..cb388b8b362d4 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -2136,7 +2136,7 @@ def wrap(self, width, **kwargs): return self._wrap_result(result) @forbid_nonstring_types(["bytes"]) - def get_dummies(self, sep="|"): + def get_dummies(self, sep="|", dtype=np.int64): """ Return DataFrame of dummy/indicator variables for Series. @@ -2147,6 +2147,8 @@ def get_dummies(self, sep="|"): ---------- sep : str, default "|" String to split on. + dtype : numpy.dtype, default np.int64 + The numpy dtype of the indicator columns in the returned DataFrame. Returns ------- @@ -2174,7 +2176,7 @@ def get_dummies(self, sep="|"): """ # we need to cast to Series of strings as only that has all # methods available for making the dummies... 
- result, name = self._data.array._str_get_dummies(sep) + result, name = self._data.array._str_get_dummies(sep, dtype=dtype) return self._wrap_result( result, name=name, diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index f884264e9ab75..a992c078200dd 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -354,7 +354,7 @@ def _str_wrap(self, width, **kwargs): tw = textwrap.TextWrapper(**kwargs) return self._str_map(lambda s: "\n".join(tw.wrap(s))) - def _str_get_dummies(self, sep="|"): + def _str_get_dummies(self, sep="|", dtype=np.int64): from pandas import Series arr = Series(self).fillna("") @@ -368,7 +368,7 @@ def _str_get_dummies(self, sep="|"): tags.update(ts) tags2 = sorted(tags - {""}) - dummies = np.empty((len(arr), len(tags2)), dtype=np.int64) + dummies = np.empty((len(arr), len(tags2)), dtype=dtype) for i, t in enumerate(tags2): pat = sep + t + sep