Skip to content

Commit a7f7e1d

Browse files
sstanovnikjreback
authored andcommitted
BUG: Fix slicing subclasses of SparseDataFrames.
Use proper subclassing behaviour so subclasses work properly: this fixes an issue where a multi-element slice of a subclass of SparseDataFrame returned the SparseDataFrame type instead of the subclass type. closes pandas-dev#13787
1 parent 1f55e91 commit a7f7e1d

File tree

7 files changed

+151
-21
lines changed

7 files changed

+151
-21
lines changed

doc/source/whatsnew/v0.19.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,8 @@ API changes
380380
- ``pd.Timedelta(None)`` is now accepted and will return ``NaT``, mirroring ``pd.Timestamp`` (:issue:`13687`)
381381
- ``Timestamp``, ``Period``, ``DatetimeIndex``, ``PeriodIndex`` and ``.dt`` accessor have gained a ``.is_leap_year`` property to check whether the date belongs to a leap year. (:issue:`13727`)
382382
- ``pd.read_hdf`` will now raise a ``ValueError`` instead of ``KeyError``, if a mode other than ``r``, ``r+`` and ``a`` is supplied. (:issue:`13623`)
383+
- Subclassed ``SparseDataFrame`` and ``SparseSeries`` now preserve class types when slicing or transposing. (:issue:`13787`)
384+
383385

384386

385387
.. _whatsnew_0190.api.tolist:

pandas/io/tests/test_pickle.py

+8
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,14 @@ def compare(self, vf, version):
8686
comparator(result, expected, typ, version)
8787
return data
8888

89+
def compare_sp_series_ts(self, res, exp, typ, version):
90+
# SparseTimeSeries integrated into SparseSeries in 0.12.0
91+
# and deprecated in 0.17.0
92+
if version and LooseVersion(version) <= "0.12.0":
93+
tm.assert_sp_series_equal(res, exp, check_series_type=False)
94+
else:
95+
tm.assert_sp_series_equal(res, exp)
96+
8997
def compare_series_ts(self, result, expected, typ, version):
9098
# GH 7748
9199
tm.assert_series_equal(result, expected)

pandas/sparse/frame.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _init_matrix(self, data, index, columns, dtype=None):
188188
return self._init_dict(data, index, columns, dtype)
189189

190190
def __array_wrap__(self, result):
191-
return SparseDataFrame(
191+
return self._constructor(
192192
result, index=self.index, columns=self.columns,
193193
default_kind=self._default_kind,
194194
default_fill_value=self._default_fill_value).__finalize__(self)
@@ -407,7 +407,7 @@ def _combine_frame(self, other, func, fill_value=None, level=None):
407407
raise NotImplementedError("'level' argument is not supported")
408408

409409
if self.empty and other.empty:
410-
return SparseDataFrame(index=new_index).__finalize__(self)
410+
return self._constructor(index=new_index).__finalize__(self)
411411

412412
new_data = {}
413413
new_fill_value = None
@@ -519,7 +519,8 @@ def _reindex_index(self, index, method, copy, level, fill_value=np.nan,
519519
return self
520520

521521
if len(self.index) == 0:
522-
return SparseDataFrame(index=index, columns=self.columns)
522+
return self._constructor(
523+
index=index, columns=self.columns).__finalize__(self)
523524

524525
indexer = self.index.get_indexer(index, method, limit=limit)
525526
indexer = _ensure_platform_int(indexer)
@@ -540,8 +541,9 @@ def _reindex_index(self, index, method, copy, level, fill_value=np.nan,
540541

541542
new_series[col] = new
542543

543-
return SparseDataFrame(new_series, index=index, columns=self.columns,
544-
default_fill_value=self._default_fill_value)
544+
return self._constructor(
545+
new_series, index=index, columns=self.columns,
546+
default_fill_value=self._default_fill_value).__finalize__(self)
545547

546548
def _reindex_columns(self, columns, copy, level, fill_value, limit=None,
547549
takeable=False):
@@ -556,8 +558,9 @@ def _reindex_columns(self, columns, copy, level, fill_value, limit=None,
556558

557559
# TODO: fill value handling
558560
sdict = dict((k, v) for k, v in compat.iteritems(self) if k in columns)
559-
return SparseDataFrame(sdict, index=self.index, columns=columns,
560-
default_fill_value=self._default_fill_value)
561+
return self._constructor(
562+
sdict, index=self.index, columns=columns,
563+
default_fill_value=self._default_fill_value).__finalize__(self)
561564

562565
def _reindex_with_indexers(self, reindexers, method=None, fill_value=None,
563566
limit=None, copy=False, allow_dups=False):
@@ -586,8 +589,8 @@ def _reindex_with_indexers(self, reindexers, method=None, fill_value=None,
586589
else:
587590
new_arrays[col] = self[col]
588591

589-
return SparseDataFrame(new_arrays, index=index,
590-
columns=columns).__finalize__(self)
592+
return self._constructor(new_arrays, index=index,
593+
columns=columns).__finalize__(self)
591594

592595
def _join_compat(self, other, on=None, how='left', lsuffix='', rsuffix='',
593596
sort=False):
@@ -644,7 +647,7 @@ def transpose(self, *args, **kwargs):
644647
Returns a DataFrame with the rows/columns switched.
645648
"""
646649
nv.validate_transpose(args, kwargs)
647-
return SparseDataFrame(
650+
return self._constructor(
648651
self.values.T, index=self.columns, columns=self.index,
649652
default_fill_value=self._default_fill_value,
650653
default_kind=self._default_kind).__finalize__(self)

pandas/sparse/series.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@ def wrapper(self, other):
6363
new_fill_value = op(np.float64(self.fill_value),
6464
np.float64(other))
6565

66-
return SparseSeries(op(self.sp_values, other),
67-
index=self.index,
68-
sparse_index=self.sp_index,
69-
fill_value=new_fill_value,
70-
name=self.name)
66+
return self._constructor(op(self.sp_values, other),
67+
index=self.index,
68+
sparse_index=self.sp_index,
69+
fill_value=new_fill_value,
70+
name=self.name)
7171
else: # pragma: no cover
7272
raise TypeError('operation with %s not supported' % type(other))
7373

@@ -85,7 +85,7 @@ def _sparse_series_op(left, right, op, name):
8585
new_name = _maybe_match_name(left, right)
8686

8787
result = _sparse_array_op(left, right, op, name)
88-
return SparseSeries(result, index=new_index, name=new_name)
88+
return left._constructor(result, index=new_index, name=new_name)
8989

9090

9191
class SparseSeries(Series):

pandas/tests/frame/test_subclass.py

+30
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,33 @@ def test_subclass_align_combinations(self):
210210
tm.assert_series_equal(res1, exp2)
211211
tm.assertIsInstance(res2, tm.SubclassedDataFrame)
212212
tm.assert_frame_equal(res2, exp1)
213+
214+
def test_subclass_sparse_slice(self):
215+
rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
216+
ssdf = tm.SubclassedSparseDataFrame(rows)
217+
ssdf.testattr = "testattr"
218+
219+
tm.assert_sp_frame_equal(ssdf.loc[:2],
220+
tm.SubclassedSparseDataFrame(rows[:3]))
221+
tm.assert_sp_frame_equal(ssdf.iloc[:2],
222+
tm.SubclassedSparseDataFrame(rows[:2]))
223+
tm.assert_sp_frame_equal(ssdf[:2],
224+
tm.SubclassedSparseDataFrame(rows[:2]))
225+
tm.assert_equal(ssdf.loc[:2].testattr, "testattr")
226+
tm.assert_equal(ssdf.iloc[:2].testattr, "testattr")
227+
tm.assert_equal(ssdf[:2].testattr, "testattr")
228+
229+
tm.assert_sp_series_equal(ssdf.loc[1],
230+
tm.SubclassedSparseSeries(rows[1]),
231+
check_names=False)
232+
tm.assert_sp_series_equal(ssdf.iloc[1],
233+
tm.SubclassedSparseSeries(rows[1]),
234+
check_names=False)
235+
236+
def test_subclass_sparse_transpose(self):
237+
ossdf = tm.SubclassedSparseDataFrame([[1, 2, 3],
238+
[4, 5, 6]])
239+
essdf = tm.SubclassedSparseDataFrame([[1, 4],
240+
[2, 5],
241+
[3, 6]])
242+
tm.assert_sp_frame_equal(ossdf.T, essdf)

pandas/tests/series/test_subclass.py

+24
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,27 @@ def test_to_frame(self):
3131
exp = tm.SubclassedDataFrame({'xxx': [1, 2, 3, 4]}, index=list('abcd'))
3232
tm.assert_frame_equal(res, exp)
3333
tm.assertIsInstance(res, tm.SubclassedDataFrame)
34+
35+
def test_subclass_sparse_slice(self):
36+
s = tm.SubclassedSparseSeries([1, 2, 3, 4, 5])
37+
tm.assert_sp_series_equal(s.loc[1:3],
38+
tm.SubclassedSparseSeries([2.0, 3.0, 4.0],
39+
index=[1, 2, 3]))
40+
tm.assert_sp_series_equal(s.iloc[1:3],
41+
tm.SubclassedSparseSeries([2.0, 3.0],
42+
index=[1, 2]))
43+
tm.assert_sp_series_equal(s[1:3],
44+
tm.SubclassedSparseSeries([2.0, 3.0],
45+
index=[1, 2]))
46+
47+
def test_subclass_sparse_addition(self):
48+
s1 = tm.SubclassedSparseSeries([1, 3, 5])
49+
s2 = tm.SubclassedSparseSeries([-2, 5, 12])
50+
tm.assert_sp_series_equal(s1 + s2,
51+
tm.SubclassedSparseSeries([-1.0, 8.0, 17.0]))
52+
53+
def test_subclass_sparse_to_frame(self):
54+
s = tm.SubclassedSparseSeries([1, 2], index=list('abcd'), name='xxx')
55+
res = s.to_frame()
56+
exp = tm.SubclassedSparseDataFrame({'xxx': [1, 2]}, index=list('abcd'))
57+
tm.assert_sp_frame_equal(res, exp)

pandas/util/testing.py

+68-5
Original file line numberDiff line numberDiff line change
@@ -1322,7 +1322,8 @@ def assert_panelnd_equal(left, right,
13221322
check_less_precise=False,
13231323
assert_func=assert_frame_equal,
13241324
check_names=False,
1325-
by_blocks=False):
1325+
by_blocks=False,
1326+
obj='Panel'):
13261327
"""Check that left and right Panels are equal.
13271328
13281329
Parameters
@@ -1343,6 +1344,9 @@ def assert_panelnd_equal(left, right,
13431344
by_blocks : bool, default False
13441345
Specify how to compare internal data. If False, compare by columns.
13451346
If True, compare by blocks.
1347+
obj : str, default 'Panel'
1348+
Specify the object name being compared, internally used to show
1349+
the appropriate assertion message.
13461350
"""
13471351

13481352
if check_panel_type:
@@ -1404,10 +1408,30 @@ def assert_sp_array_equal(left, right):
14041408

14051409

14061410
def assert_sp_series_equal(left, right, exact_indices=True,
1407-
check_names=True, obj='SparseSeries'):
1411+
check_series_type=True,
1412+
check_names=True,
1413+
obj='SparseSeries'):
1414+
"""Check that the left and right SparseSeries are equal.
1415+
1416+
Parameters
1417+
----------
1418+
left : SparseSeries
1419+
right : SparseSeries
1420+
exact_indices : bool, default True
1421+
check_series_type : bool, default True
1422+
Whether to check the SparseSeries class is identical.
1423+
check_names : bool, default True
1424+
Whether to check the SparseSeries name attribute.
1425+
obj : str, default 'SparseSeries'
1426+
Specify the object name being compared, internally used to show
1427+
the appropriate assertion message.
1428+
"""
14081429
assertIsInstance(left, pd.SparseSeries, '[SparseSeries]')
14091430
assertIsInstance(right, pd.SparseSeries, '[SparseSeries]')
14101431

1432+
if check_series_type:
1433+
assert_class_equal(left, right, obj=obj)
1434+
14111435
assert_index_equal(left.index, right.index,
14121436
obj='{0}.index'.format(obj))
14131437

@@ -1421,14 +1445,29 @@ def assert_sp_series_equal(left, right, exact_indices=True,
14211445

14221446

14231447
def assert_sp_frame_equal(left, right, exact_indices=True,
1448+
check_frame_type=True,
14241449
obj='SparseDataFrame'):
1425-
"""
1426-
exact: Series SparseIndex objects must be exactly the same, otherwise just
1427-
compare dense representations
1450+
"""Check that the left and right SparseDataFrame are equal.
1451+
1452+
Parameters
1453+
----------
1454+
left : SparseDataFrame
1455+
right : SparseDataFrame
1456+
exact_indices : bool, default True
1457+
SparseSeries SparseIndex objects must be exactly the same,
1458+
otherwise just compare dense representations.
1459+
check_frame_type : bool, default True
1460+
Whether to check the SparseDataFrame class is identical.
1461+
obj : str, default 'SparseDataFrame'
1462+
Specify the object name being compared, internally used to show
1463+
the appropriate assertion message.
14281464
"""
14291465
assertIsInstance(left, pd.SparseDataFrame, '[SparseDataFrame]')
14301466
assertIsInstance(right, pd.SparseDataFrame, '[SparseDataFrame]')
14311467

1468+
if check_frame_type:
1469+
assert_class_equal(left, right, obj=obj)
1470+
14321471
assert_index_equal(left.index, right.index,
14331472
obj='{0}.index'.format(obj))
14341473
assert_index_equal(left.columns, right.columns,
@@ -2607,6 +2646,30 @@ def _constructor_sliced(self):
26072646
return SubclassedSeries
26082647

26092648

2649+
class SubclassedSparseSeries(pd.SparseSeries):
2650+
_metadata = ['testattr']
2651+
2652+
@property
2653+
def _constructor(self):
2654+
return SubclassedSparseSeries
2655+
2656+
@property
2657+
def _constructor_expanddim(self):
2658+
return SubclassedSparseDataFrame
2659+
2660+
2661+
class SubclassedSparseDataFrame(pd.SparseDataFrame):
2662+
_metadata = ['testattr']
2663+
2664+
@property
2665+
def _constructor(self):
2666+
return SubclassedSparseDataFrame
2667+
2668+
@property
2669+
def _constructor_sliced(self):
2670+
return SubclassedSparseSeries
2671+
2672+
26102673
@contextmanager
26112674
def patch(ob, attr, value):
26122675
"""Temporarily patch an attribute of an object.

0 commit comments

Comments
 (0)