Skip to content

Commit a0aaad9

Browse files
committed
TST: more comprehensive dtype testing for rolling
1 parent 3d70be7 commit a0aaad9

File tree

2 files changed

+90
-59
lines changed

2 files changed

+90
-59
lines changed

pandas/core/window.py

+29-11
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,17 @@ def _dir_additions(self):
124124
def _get_window(self, other=None):
125125
return self.window
126126

127+
@property
128+
def _window_type(self):
129+
return self.__class__.__name__
130+
127131
def __unicode__(self):
128132
""" provide a nice str repr of our rolling object """
129133

130134
attrs = ["{k}={v}".format(k=k, v=getattr(self, k))
131135
for k in self._attributes
132136
if getattr(self, k, None) is not None]
133-
return "{klass} [{attrs}]".format(klass=self.__class__.__name__,
137+
return "{klass} [{attrs}]".format(klass=self._window_type,
134138
attrs=','.join(attrs))
135139

136140
def _shallow_copy(self, obj=None, **kwargs):
@@ -155,8 +159,12 @@ def _prep_values(self, values=None, kill_inf=True, how=None):
155159
values = com._ensure_float64(values)
156160
elif com.is_integer_dtype(values.dtype):
157161
values = com._ensure_float64(values)
158-
elif com.is_timedelta64_dtype(values.dtype):
159-
values = com._ensure_float64(values.view('i8'))
162+
elif com.needs_i8_conversion(values.dtype):
163+
raise NotImplementedError("ops for {action} for this "
164+
"dtype {dtype} are not "
165+
"implemented".format(
166+
action=self._window_type,
167+
dtype=values.dtype))
160168
else:
161169
try:
162170
values = com._ensure_float64(values)
@@ -498,15 +506,25 @@ def count(self):
498506
window = self._get_window()
499507
window = min(window, len(obj)) if not self.center else window
500508

501-
try:
502-
converted = np.isfinite(obj).astype(float)
503-
except TypeError:
504-
converted = np.isfinite(obj.astype(float)).astype(float)
505-
result = self._constructor(converted, window=window, min_periods=0,
506-
center=self.center).sum()
509+
blocks, obj = self._create_blocks(how=None)
510+
results = []
511+
for b in blocks:
507512

508-
result[result.isnull()] = 0
509-
return result
513+
if com.needs_i8_conversion(b.values):
514+
result = b.notnull().astype(int)
515+
else:
516+
try:
517+
result = np.isfinite(b).astype(float)
518+
except TypeError:
519+
result = np.isfinite(b.astype(float)).astype(float)
520+
521+
result[pd.isnull(result)] = 0
522+
523+
result = self._constructor(result, window=window, min_periods=0,
524+
center=self.center).sum()
525+
results.append(result)
526+
527+
return self._wrap_results(results, blocks, obj)
510528

511529
_shared_docs['apply'] = dedent("""
512530
%(name)s function apply

pandas/tests/test_window.py

+61-48
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import warnings
55

6+
from nose.tools import assert_raises
67
from datetime import datetime
78
from numpy.random import randn
89
from numpy.testing.decorators import slow
@@ -98,19 +99,6 @@ def tests_skip_nuisance(self):
9899
result = r.sum()
99100
assert_frame_equal(result, expected)
100101

101-
def test_timedeltas(self):
102-
103-
df = DataFrame({'A': range(5),
104-
'B': pd.timedelta_range('1 day', periods=5)})
105-
r = df.rolling(window=3)
106-
result = r.sum()
107-
expected = DataFrame({'A': [np.nan, np.nan, 3, 6, 9],
108-
'B': pd.to_timedelta([pd.NaT, pd.NaT,
109-
'6 days', '9 days',
110-
'12 days'])},
111-
columns=list('AB'))
112-
assert_frame_equal(result, expected)
113-
114102
def test_agg(self):
115103
df = DataFrame({'A': range(5), 'B': range(0, 10, 2)})
116104

@@ -291,8 +279,13 @@ def test_deprecations(self):
291279

292280
# GH #12373 : rolling functions error on float32 data
293281
# make sure rolling functions works for different dtypes
294-
class TestDtype(Base):
295-
dtype = None
282+
#
283+
# NOTE that these are yielded tests and so _create_data is
284+
# explicity called, nor do these inherit from unittest.TestCase
285+
#
286+
# further note that we are only checking rolling for fully dtype
287+
# compliance (though both expanding and ewm inherit)
288+
class Dtype(object):
296289
window = 2
297290

298291
funcs = {
@@ -371,76 +364,84 @@ def _create_dtype_data(self, dtype):
371364
return data
372365

373366
def _create_data(self):
374-
super(TestDtype, self)._create_data()
375367
self.data = self._create_dtype_data(self.dtype)
376368
self.expects = self.get_expects()
377369

378-
def setUp(self):
379-
self._create_data()
380-
381370
def test_dtypes(self):
371+
self._create_data()
382372
for f_name, d_name in product(self.funcs.keys(), self.data.keys()):
383373
f = self.funcs[f_name]
384374
d = self.data[d_name]
385-
assert_equal = assert_series_equal if isinstance(
386-
d, Series) else assert_frame_equal
387375
exp = self.expects[d_name][f_name]
376+
yield self.check_dtypes, f, f_name, d, d_name, exp
388377

389-
roll = d.rolling(window=self.window)
390-
result = f(roll)
391-
392-
assert_equal(result, exp)
378+
def check_dtypes(self, f, f_name, d, d_name, exp):
379+
roll = d.rolling(window=self.window)
380+
result = f(roll)
381+
assert_almost_equal(result, exp)
393382

394383

395-
class TestDtype_object(TestDtype):
384+
class TestDtype_object(Dtype):
396385
dtype = object
397386

398387

399-
class TestDtype_int8(TestDtype):
388+
class Dtype_integer(Dtype):
389+
pass
390+
391+
392+
class TestDtype_int8(Dtype_integer):
400393
dtype = np.int8
401394

402395

403-
class TestDtype_int16(TestDtype):
396+
class TestDtype_int16(Dtype_integer):
404397
dtype = np.int16
405398

406399

407-
class TestDtype_int32(TestDtype):
400+
class TestDtype_int32(Dtype_integer):
408401
dtype = np.int32
409402

410403

411-
class TestDtype_int64(TestDtype):
404+
class TestDtype_int64(Dtype_integer):
412405
dtype = np.int64
413406

414407

415-
class TestDtype_uint8(TestDtype):
408+
class Dtype_uinteger(Dtype):
409+
pass
410+
411+
412+
class TestDtype_uint8(Dtype_uinteger):
416413
dtype = np.uint8
417414

418415

419-
class TestDtype_uint16(TestDtype):
416+
class TestDtype_uint16(Dtype_uinteger):
420417
dtype = np.uint16
421418

422419

423-
class TestDtype_uint32(TestDtype):
420+
class TestDtype_uint32(Dtype_uinteger):
424421
dtype = np.uint32
425422

426423

427-
class TestDtype_uint64(TestDtype):
424+
class TestDtype_uint64(Dtype_uinteger):
428425
dtype = np.uint64
429426

430427

431-
class TestDtype_float16(TestDtype):
428+
class Dtype_float(Dtype):
429+
pass
430+
431+
432+
class TestDtype_float16(Dtype_float):
432433
dtype = np.float16
433434

434435

435-
class TestDtype_float32(TestDtype):
436+
class TestDtype_float32(Dtype_float):
436437
dtype = np.float32
437438

438439

439-
class TestDtype_float64(TestDtype):
440+
class TestDtype_float64(Dtype_float):
440441
dtype = np.float64
441442

442443

443-
class TestDtype_category(TestDtype):
444+
class TestDtype_category(Dtype):
444445
dtype = 'category'
445446
include_df = False
446447

@@ -456,25 +457,37 @@ def _create_dtype_data(self, dtype):
456457
return data
457458

458459

459-
class TestDatetimeLikeDtype(TestDtype):
460-
dtype = np.dtype('M8[ns]')
460+
class DatetimeLike(Dtype):
461461

462-
# GH #12373: rolling functions raise ValueError on float32 data
463-
def setUp(self):
464-
raise nose.SkipTest("Skip rolling on DatetimeLike dtypes [{0}].".format(self.dtype))
462+
def check_dtypes(self, f, f_name, d, d_name, exp):
465463

466-
def test_dtypes(self):
467-
with tm.assertRaises(TypeError):
468-
super(TestDatetimeLikeDtype, self).test_dtypes()
464+
roll = d.rolling(window=self.window)
465+
466+
if f_name == 'count':
467+
result = f(roll)
468+
assert_almost_equal(result, exp)
469469

470+
else:
471+
472+
# other methods not Implemented ATM
473+
assert_raises(NotImplementedError, f, roll)
470474

471-
class TestDtype_timedelta(TestDatetimeLikeDtype):
475+
476+
class TestDtype_timedelta(DatetimeLike):
472477
dtype = np.dtype('m8[ns]')
473478

474479

475-
class TestDtype_datetime64UTC(TestDatetimeLikeDtype):
480+
class TestDtype_datetime(DatetimeLike):
481+
dtype = np.dtype('M8[ns]')
482+
483+
484+
class TestDtype_datetime64UTC(DatetimeLike):
476485
dtype = 'datetime64[ns, UTC]'
477486

487+
def _create_data(self):
488+
raise nose.SkipTest("direct creation of extension dtype "
489+
"datetime64[ns, UTC] is not supported ATM")
490+
478491

479492
class TestMoments(Base):
480493

0 commit comments

Comments
 (0)