Skip to content

Commit 3b87703

Browse files
TomAugspurgerJustinZhengBC
authored andcommitted
BUG: astype fill_value for SparseArray.astype (pandas-dev#23547)
1 parent d8826bf commit 3b87703

File tree

3 files changed

+139
-14
lines changed

3 files changed

+139
-14
lines changed

pandas/core/arrays/sparse.py

+91-14
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,83 @@ def is_dtype(cls, dtype):
284284
return True
285285
return isinstance(dtype, np.dtype) or dtype == 'Sparse'
286286

287+
def update_dtype(self, dtype):
288+
"""Convert the SparseDtype to a new dtype.
289+
290+
This takes care of converting the ``fill_value``.
291+
292+
Parameters
293+
----------
294+
dtype : Union[str, numpy.dtype, SparseDtype]
295+
The new dtype to use.
296+
297+
* For a SparseDtype, it is simply returned
298+
* For a NumPy dtype (or str), the current fill value
299+
is converted to the new dtype, and a SparseDtype
300+
with `dtype` and the new fill value is returned.
301+
302+
Returns
303+
-------
304+
SparseDtype
305+
A new SparseDtype with the corret `dtype` and fill value
306+
for that `dtype`.
307+
308+
Raises
309+
------
310+
ValueError
311+
When the current fill value cannot be converted to the
312+
new `dtype` (e.g. trying to convert ``np.nan`` to an
313+
integer dtype).
314+
315+
316+
Examples
317+
--------
318+
>>> SparseDtype(int, 0).update_dtype(float)
319+
Sparse[float64, 0.0]
320+
321+
>>> SparseDtype(int, 1).update_dtype(SparseDtype(float, np.nan))
322+
Sparse[float64, nan]
323+
"""
324+
cls = type(self)
325+
dtype = pandas_dtype(dtype)
326+
327+
if not isinstance(dtype, cls):
328+
fill_value = astype_nansafe(np.array(self.fill_value),
329+
dtype).item()
330+
dtype = cls(dtype, fill_value=fill_value)
331+
332+
return dtype
333+
334+
@property
335+
def _subtype_with_str(self):
336+
"""
337+
Whether the SparseDtype's subtype should be considered ``str``.
338+
339+
Typically, pandas will store string data in an object-dtype array.
340+
When converting values to a dtype, e.g. in ``.astype``, we need to
341+
be more specific, we need the actual underlying type.
342+
343+
Returns
344+
-------
345+
346+
>>> SparseDtype(int, 1)._subtype_with_str
347+
dtype('int64')
348+
349+
>>> SparseDtype(object, 1)._subtype_with_str
350+
dtype('O')
351+
352+
>>> dtype = SparseDtype(str, '')
353+
>>> dtype.subtype
354+
dtype('O')
355+
356+
>>> dtype._subtype_with_str
357+
str
358+
"""
359+
if isinstance(self.fill_value, compat.string_types):
360+
return type(self.fill_value)
361+
return self.subtype
362+
363+
287364
# ----------------------------------------------------------------------------
288365
# Array
289366

@@ -614,7 +691,7 @@ def __array__(self, dtype=None, copy=True):
614691
# Can't put pd.NaT in a datetime64[ns]
615692
fill_value = np.datetime64('NaT')
616693
try:
617-
dtype = np.result_type(self.sp_values.dtype, fill_value)
694+
dtype = np.result_type(self.sp_values.dtype, type(fill_value))
618695
except TypeError:
619696
dtype = object
620697

@@ -996,7 +1073,7 @@ def _take_with_fill(self, indices, fill_value=None):
9961073
if len(self) == 0:
9971074
# Empty... Allow taking only if all empty
9981075
if (indices == -1).all():
999-
dtype = np.result_type(self.sp_values, fill_value)
1076+
dtype = np.result_type(self.sp_values, type(fill_value))
10001077
taken = np.empty_like(indices, dtype=dtype)
10011078
taken.fill(fill_value)
10021079
return taken
@@ -1009,7 +1086,7 @@ def _take_with_fill(self, indices, fill_value=None):
10091086
if self.sp_index.npoints == 0:
10101087
# Avoid taking from the empty self.sp_values
10111088
taken = np.full(sp_indexer.shape, fill_value=fill_value,
1012-
dtype=np.result_type(fill_value))
1089+
dtype=np.result_type(type(fill_value)))
10131090
else:
10141091
taken = self.sp_values.take(sp_indexer)
10151092

@@ -1030,12 +1107,13 @@ def _take_with_fill(self, indices, fill_value=None):
10301107
result_type = taken.dtype
10311108

10321109
if m0.any():
1033-
result_type = np.result_type(result_type, self.fill_value)
1110+
result_type = np.result_type(result_type,
1111+
type(self.fill_value))
10341112
taken = taken.astype(result_type)
10351113
taken[old_fill_indices] = self.fill_value
10361114

10371115
if m1.any():
1038-
result_type = np.result_type(result_type, fill_value)
1116+
result_type = np.result_type(result_type, type(fill_value))
10391117
taken = taken.astype(result_type)
10401118
taken[new_fill_indices] = fill_value
10411119

@@ -1061,7 +1139,7 @@ def _take_without_fill(self, indices):
10611139
# edge case in take...
10621140
# I think just return
10631141
out = np.full(indices.shape, self.fill_value,
1064-
dtype=np.result_type(self.fill_value))
1142+
dtype=np.result_type(type(self.fill_value)))
10651143
arr, sp_index, fill_value = make_sparse(out,
10661144
fill_value=self.fill_value)
10671145
return type(self)(arr, sparse_index=sp_index,
@@ -1073,7 +1151,7 @@ def _take_without_fill(self, indices):
10731151

10741152
if fillable.any():
10751153
# TODO: may need to coerce array to fill value
1076-
result_type = np.result_type(taken, self.fill_value)
1154+
result_type = np.result_type(taken, type(self.fill_value))
10771155
taken = taken.astype(result_type)
10781156
taken[fillable] = self.fill_value
10791157

@@ -1093,7 +1171,9 @@ def _concat_same_type(cls, to_concat):
10931171

10941172
fill_value = fill_values[0]
10951173

1096-
if len(set(fill_values)) > 1:
1174+
# np.nan isn't a singleton, so we may end up with multiple
1175+
# NaNs here, so we ignore tha all NA case too.
1176+
if not (len(set(fill_values)) == 1 or isna(fill_values).all()):
10971177
warnings.warn("Concatenating sparse arrays with multiple fill "
10981178
"values: '{}'. Picking the first and "
10991179
"converting the rest.".format(fill_values),
@@ -1212,13 +1292,10 @@ def astype(self, dtype=None, copy=True):
12121292
IntIndex
12131293
Indices: array([2, 3], dtype=int32)
12141294
"""
1215-
dtype = pandas_dtype(dtype)
1216-
1217-
if not isinstance(dtype, SparseDtype):
1218-
dtype = SparseDtype(dtype, fill_value=self.fill_value)
1219-
1295+
dtype = self.dtype.update_dtype(dtype)
1296+
subtype = dtype._subtype_with_str
12201297
sp_values = astype_nansafe(self.sp_values,
1221-
dtype.subtype,
1298+
subtype,
12221299
copy=copy)
12231300
if sp_values is self.sp_values and copy:
12241301
sp_values = sp_values.copy()

pandas/tests/arrays/sparse/test_array.py

+28
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,34 @@ def test_astype_all(self, any_real_dtype):
477477
tm.assert_numpy_array_equal(np.asarray(res.values),
478478
vals.astype(typ))
479479

480+
@pytest.mark.parametrize('array, dtype, expected', [
481+
(SparseArray([0, 1]), 'float',
482+
SparseArray([0., 1.], dtype=SparseDtype(float, 0.0))),
483+
(SparseArray([0, 1]), bool, SparseArray([False, True])),
484+
(SparseArray([0, 1], fill_value=1), bool,
485+
SparseArray([False, True], dtype=SparseDtype(bool, True))),
486+
pytest.param(
487+
SparseArray([0, 1]), 'datetime64[ns]',
488+
SparseArray(np.array([0, 1], dtype='datetime64[ns]'),
489+
dtype=SparseDtype('datetime64[ns]',
490+
pd.Timestamp('1970'))),
491+
marks=[pytest.mark.xfail(reason="NumPy-7619", strict=True)],
492+
),
493+
(SparseArray([0, 1, 10]), str,
494+
SparseArray(['0', '1', '10'], dtype=SparseDtype(str, '0'))),
495+
(SparseArray(['10', '20']), float, SparseArray([10.0, 20.0])),
496+
(SparseArray([0, 1, 0]), object,
497+
SparseArray([0, 1, 0], dtype=SparseDtype(object, 0))),
498+
])
499+
def test_astype_more(self, array, dtype, expected):
500+
result = array.astype(dtype)
501+
tm.assert_sp_array_equal(result, expected)
502+
503+
def test_astype_nan_raises(self):
504+
arr = SparseArray([1.0, np.nan])
505+
with pytest.raises(ValueError, match='Cannot convert non-finite'):
506+
arr.astype(int)
507+
480508
def test_set_fill_value(self):
481509
arr = SparseArray([1., np.nan, 2.], fill_value=np.nan)
482510
arr.fill_value = 2

pandas/tests/arrays/sparse/test_dtype.py

+20
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,23 @@ def test_parse_subtype(string, expected):
139139
def test_construct_from_string_fill_value_raises(string):
140140
with pytest.raises(TypeError, match='fill_value in the string is not'):
141141
SparseDtype.construct_from_string(string)
142+
143+
144+
@pytest.mark.parametrize('original, dtype, expected', [
145+
(SparseDtype(int, 0), float, SparseDtype(float, 0.0)),
146+
(SparseDtype(int, 1), float, SparseDtype(float, 1.0)),
147+
(SparseDtype(int, 1), str, SparseDtype(object, '1')),
148+
(SparseDtype(float, 1.5), int, SparseDtype(int, 1)),
149+
])
150+
def test_update_dtype(original, dtype, expected):
151+
result = original.update_dtype(dtype)
152+
assert result == expected
153+
154+
155+
@pytest.mark.parametrize("original, dtype", [
156+
(SparseDtype(float, np.nan), int),
157+
(SparseDtype(str, 'abc'), int),
158+
])
159+
def test_update_dtype_raises(original, dtype):
160+
with pytest.raises(ValueError):
161+
original.update_dtype(dtype)

0 commit comments

Comments
 (0)