forked from pandas-dev/pandas
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest_register_accessor.py
89 lines (67 loc) · 2.21 KB
/
test_register_accessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import contextlib
import pytest
import pandas as pd
import pandas.util.testing as tm
@contextlib.contextmanager
def ensure_removed(obj, attr):
    """Undo the registration of ``attr`` on ``obj`` once the block exits.

    Cleanup runs whether the managed block finishes normally or raises:
    ``attr`` is deleted from ``obj`` (a missing attribute is ignored), and
    the name is then discarded from ``obj._accessors``.

    Parameters
    ----------
    obj : type
        Class the attribute may have been added to; it must expose an
        ``_accessors`` set.
    attr : str
        Name of the attribute/accessor to clean up.
    """
    def _cleanup():
        # A missing attribute just means nothing was registered (or it was
        # already removed), so AttributeError is deliberately ignored.
        try:
            delattr(obj, attr)
        except AttributeError:
            pass
        obj._accessors.discard(attr)

    try:
        yield
    finally:
        _cleanup()
class MyAccessor(object):
    """Minimal accessor used to exercise the registration machinery.

    Keeps a reference to the object it was built from and exposes the
    stored ``item`` string through both a property and a method.
    """

    def __init__(self, obj):
        self.item = 'item'
        self.obj = obj

    @property
    def prop(self):
        """Return the stored ``item`` string."""
        return self.item

    def method(self):
        """Return the stored ``item`` string."""
        return self.item
@pytest.mark.parametrize('obj, registrar', [
    (pd.Series, pd.api.extensions.register_series_accessor),
    (pd.DataFrame, pd.api.extensions.register_dataframe_accessor),
    (pd.Index, pd.api.extensions.register_index_accessor),
])
def test_register(obj, registrar):
    """Registering an accessor adds exactly that one name to ``dir(obj)``.

    The accessor must also be usable on an empty instance and be recorded
    in ``obj._accessors``.
    """
    with ensure_removed(obj, 'mine'):
        attrs_before = set(dir(obj))
        registrar('mine')(MyAccessor)
        empty_instance = obj([])
        assert empty_instance.mine.prop == 'item'
        attrs_after = set(dir(obj))
        changed = attrs_before.symmetric_difference(attrs_after)
        assert changed == {'mine'}
        assert 'mine' in obj._accessors
def test_accessor_works():
    """A registered Series accessor wraps the Series it is accessed from."""
    accessor_name = 'mine'
    with ensure_removed(pd.Series, accessor_name):
        pd.api.extensions.register_series_accessor(accessor_name)(MyAccessor)
        ser = pd.Series([1, 2])
        assert ser.mine.obj is ser
        assert ser.mine.prop == 'item'
        assert ser.mine.method() == 'item'
def test_overwrite_warns():
    """Registering an accessor over an existing attribute emits a UserWarning.

    ``Series.mean`` is temporarily shadowed by ``MyAccessor``. The warning
    message must mention the attribute name, the accessor class and the
    target class.
    """
    # Registration both replaces Series.mean and adds 'mean' to the
    # class-level Series._accessors set. Both must be undone so that later
    # tests see an unmodified Series.
    mean = pd.Series.mean
    mean_was_accessor = 'mean' in pd.Series._accessors
    try:
        with tm.assert_produces_warning(UserWarning) as w:
            pd.api.extensions.register_series_accessor('mean')(MyAccessor)
            s = pd.Series([1, 2])
            assert s.mean.prop == 'item'
        msg = str(w[0].message)
        assert 'mean' in msg
        assert 'MyAccessor' in msg
        assert 'Series' in msg
    finally:
        pd.Series.mean = mean
        # Only drop the name if this test added it.
        if not mean_was_accessor:
            pd.Series._accessors.discard('mean')
def test_raises_attribute_error():
    """An AttributeError raised while building an accessor propagates."""

    class BadAccessor(object):
        # Constructing this accessor always fails.
        def __init__(self, data):
            raise AttributeError("whoops")

    with ensure_removed(pd.Series, 'bad'):
        pd.api.extensions.register_series_accessor("bad")(BadAccessor)
        with tm.assert_raises_regex(AttributeError, "whoops"):
            pd.Series([]).bad