Skip to content

Commit ced8628

Browse files
committed
TST: add tests for call_once
1 parent dc81075 commit ced8628

File tree

2 files changed

+86
-3
lines changed

2 files changed

+86
-3
lines changed

pandas/tests/util/test_call_once.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import time
2+
from unittest.mock import MagicMock
3+
4+
from pandas.util._call_once import call_once
5+
6+
7+
def test_basic():
8+
callback1, callback2 = MagicMock(), MagicMock()
9+
with call_once(callback1):
10+
with call_once(callback1):
11+
with call_once(callback2):
12+
with call_once(callback1):
13+
with call_once(callback2):
14+
pass
15+
16+
assert callback1.call_count == 1
17+
assert callback2.call_count == 1
18+
19+
20+
def test_with_key():
21+
cb1, cb2 = MagicMock(), MagicMock()
22+
with call_once(cb1, key="callback"):
23+
with call_once(cb2, key="callback"):
24+
pass
25+
26+
assert cb1.call_count == 1
27+
assert cb2.call_count == 0
28+
29+
30+
def test_across_stack_frames():
31+
callback = MagicMock()
32+
33+
def f():
34+
with call_once(callback):
35+
pass
36+
37+
def g():
38+
with call_once(callback):
39+
f()
40+
41+
f()
42+
assert callback.call_count == 1
43+
callback.reset_mock()
44+
45+
g()
46+
assert callback.call_count == 1
47+
48+
49+
def test_concurrent_threading():
50+
import threading
51+
52+
sleep_time = 0.01
53+
callback = MagicMock()
54+
55+
def run(initial_sleep=0):
56+
time.sleep(initial_sleep)
57+
with call_once(callback):
58+
with call_once(callback):
59+
time.sleep(2 * sleep_time)
60+
61+
thread1 = threading.Thread(target=run)
62+
thread2 = threading.Thread(target=run, kwargs={"initial_sleep": sleep_time})
63+
thread2.start()
64+
thread1.start()
65+
thread1.join()
66+
thread2.join()
67+
assert callback.call_count == 2
68+
69+
70+
def test_concurrent_asyncio():
71+
import asyncio
72+
73+
callback = MagicMock()
74+
75+
async def task():
76+
with call_once(callback):
77+
with call_once(callback):
78+
await asyncio.sleep(0.01)
79+
80+
async def main():
81+
await asyncio.gather(task(), task())
82+
83+
asyncio.run(main())
84+
assert callback.call_count == 2

pandas/util/_call_once.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,16 @@ class _ContextMapping(MutableMapping):
2121
_NO_DEFAULT_VALUE = object()
2222

2323
def __init__(self, default_value=_NO_DEFAULT_VALUE):
24-
d: dict = (
24+
initial_dict: dict = (
2525
defaultdict(lambda: default_value)
2626
if default_value is not self._NO_DEFAULT_VALUE
2727
else dict()
2828
)
2929
# yes, we're creating a contextvar inside a closure, but it doesn't matter
3030
# because objects of this class will only be created at module level
3131
self._dict_var: contextvars.ContextVar[dict] = contextvars.ContextVar(
32-
"_ContextMapping<{}>._dict_var".format(id(self))
32+
"_ContextMapping<{}>._dict_var".format(id(self)), default=initial_dict
3333
)
34-
self._dict_var.set(d)
3534

3635
def __setitem__(self, k, v) -> None:
3736
d = self._dict_var.get().copy()

0 commit comments

Comments
 (0)