@@ -284,6 +284,83 @@ def is_dtype(cls, dtype):
284
284
return True
285
285
return isinstance (dtype , np .dtype ) or dtype == 'Sparse'
286
286
287
+ def update_dtype (self , dtype ):
288
+ """Convert the SparseDtype to a new dtype.
289
+
290
+ This takes care of converting the ``fill_value``.
291
+
292
+ Parameters
293
+ ----------
294
+ dtype : Union[str, numpy.dtype, SparseDtype]
295
+ The new dtype to use.
296
+
297
+ * For a SparseDtype, it is simply returned
298
+ * For a NumPy dtype (or str), the current fill value
299
+ is converted to the new dtype, and a SparseDtype
300
+ with `dtype` and the new fill value is returned.
301
+
302
+ Returns
303
+ -------
304
+ SparseDtype
305
+ A new SparseDtype with the correct `dtype` and fill value
306
+ for that `dtype`.
307
+
308
+ Raises
309
+ ------
310
+ ValueError
311
+ When the current fill value cannot be converted to the
312
+ new `dtype` (e.g. trying to convert ``np.nan`` to an
313
+ integer dtype).
314
+
315
+
316
+ Examples
317
+ --------
318
+ >>> SparseDtype(int, 0).update_dtype(float)
319
+ Sparse[float64, 0.0]
320
+
321
+ >>> SparseDtype(int, 1).update_dtype(SparseDtype(float, np.nan))
322
+ Sparse[float64, nan]
323
+ """
324
+ cls = type (self )
325
+ dtype = pandas_dtype (dtype )
326
+
327
+ if not isinstance (dtype , cls ):
328
+ fill_value = astype_nansafe (np .array (self .fill_value ),
329
+ dtype ).item ()
330
+ dtype = cls (dtype , fill_value = fill_value )
331
+
332
+ return dtype
333
+
334
+ @property
335
+ def _subtype_with_str (self ):
336
+ """
337
+ Whether the SparseDtype's subtype should be considered ``str``.
338
+
339
+ Typically, pandas will store string data in an object-dtype array.
340
+ When converting values to a dtype, e.g. in ``.astype``, we need to
341
+ be more specific: we need the actual underlying type.
342
+
343
+ Returns
344
+ -------
345
+
346
+ >>> SparseDtype(int, 1)._subtype_with_str
347
+ dtype('int64')
348
+
349
+ >>> SparseDtype(object, 1)._subtype_with_str
350
+ dtype('O')
351
+
352
+ >>> dtype = SparseDtype(str, '')
353
+ >>> dtype.subtype
354
+ dtype('O')
355
+
356
+ >>> dtype._subtype_with_str
357
+ str
358
+ """
359
+ if isinstance (self .fill_value , compat .string_types ):
360
+ return type (self .fill_value )
361
+ return self .subtype
362
+
363
+
287
364
# ----------------------------------------------------------------------------
288
365
# Array
289
366
@@ -614,7 +691,7 @@ def __array__(self, dtype=None, copy=True):
614
691
# Can't put pd.NaT in a datetime64[ns]
615
692
fill_value = np .datetime64 ('NaT' )
616
693
try :
617
- dtype = np .result_type (self .sp_values .dtype , fill_value )
694
+ dtype = np .result_type (self .sp_values .dtype , type ( fill_value ) )
618
695
except TypeError :
619
696
dtype = object
620
697
@@ -996,7 +1073,7 @@ def _take_with_fill(self, indices, fill_value=None):
996
1073
if len (self ) == 0 :
997
1074
# Empty... Allow taking only if all empty
998
1075
if (indices == - 1 ).all ():
999
- dtype = np .result_type (self .sp_values , fill_value )
1076
+ dtype = np .result_type (self .sp_values , type ( fill_value ) )
1000
1077
taken = np .empty_like (indices , dtype = dtype )
1001
1078
taken .fill (fill_value )
1002
1079
return taken
@@ -1009,7 +1086,7 @@ def _take_with_fill(self, indices, fill_value=None):
1009
1086
if self .sp_index .npoints == 0 :
1010
1087
# Avoid taking from the empty self.sp_values
1011
1088
taken = np .full (sp_indexer .shape , fill_value = fill_value ,
1012
- dtype = np .result_type (fill_value ))
1089
+ dtype = np .result_type (type ( fill_value ) ))
1013
1090
else :
1014
1091
taken = self .sp_values .take (sp_indexer )
1015
1092
@@ -1030,12 +1107,13 @@ def _take_with_fill(self, indices, fill_value=None):
1030
1107
result_type = taken .dtype
1031
1108
1032
1109
if m0 .any ():
1033
- result_type = np .result_type (result_type , self .fill_value )
1110
+ result_type = np .result_type (result_type ,
1111
+ type (self .fill_value ))
1034
1112
taken = taken .astype (result_type )
1035
1113
taken [old_fill_indices ] = self .fill_value
1036
1114
1037
1115
if m1 .any ():
1038
- result_type = np .result_type (result_type , fill_value )
1116
+ result_type = np .result_type (result_type , type ( fill_value ) )
1039
1117
taken = taken .astype (result_type )
1040
1118
taken [new_fill_indices ] = fill_value
1041
1119
@@ -1061,7 +1139,7 @@ def _take_without_fill(self, indices):
1061
1139
# edge case in take...
1062
1140
# I think just return
1063
1141
out = np .full (indices .shape , self .fill_value ,
1064
- dtype = np .result_type (self .fill_value ))
1142
+ dtype = np .result_type (type ( self .fill_value ) ))
1065
1143
arr , sp_index , fill_value = make_sparse (out ,
1066
1144
fill_value = self .fill_value )
1067
1145
return type (self )(arr , sparse_index = sp_index ,
@@ -1073,7 +1151,7 @@ def _take_without_fill(self, indices):
1073
1151
1074
1152
if fillable .any ():
1075
1153
# TODO: may need to coerce array to fill value
1076
- result_type = np .result_type (taken , self .fill_value )
1154
+ result_type = np .result_type (taken , type ( self .fill_value ) )
1077
1155
taken = taken .astype (result_type )
1078
1156
taken [fillable ] = self .fill_value
1079
1157
@@ -1093,7 +1171,9 @@ def _concat_same_type(cls, to_concat):
1093
1171
1094
1172
fill_value = fill_values [0 ]
1095
1173
1096
- if len (set (fill_values )) > 1 :
1174
+ # np.nan isn't a singleton, so we may end up with multiple
1175
+ # NaNs here, so we ignore the all-NA case too.
1176
+ if not (len (set (fill_values )) == 1 or isna (fill_values ).all ()):
1097
1177
warnings .warn ("Concatenating sparse arrays with multiple fill "
1098
1178
"values: '{}'. Picking the first and "
1099
1179
"converting the rest." .format (fill_values ),
@@ -1212,13 +1292,10 @@ def astype(self, dtype=None, copy=True):
1212
1292
IntIndex
1213
1293
Indices: array([2, 3], dtype=int32)
1214
1294
"""
1215
- dtype = pandas_dtype (dtype )
1216
-
1217
- if not isinstance (dtype , SparseDtype ):
1218
- dtype = SparseDtype (dtype , fill_value = self .fill_value )
1219
-
1295
+ dtype = self .dtype .update_dtype (dtype )
1296
+ subtype = dtype ._subtype_with_str
1220
1297
sp_values = astype_nansafe (self .sp_values ,
1221
- dtype . subtype ,
1298
+ subtype ,
1222
1299
copy = copy )
1223
1300
if sp_values is self .sp_values and copy :
1224
1301
sp_values = sp_values .copy ()
0 commit comments