Skip to content

Commit a5c7a8a

Browse files
authored
[batch] Fix async exit stacks (hail-is#13969)
I couldn't find the best issue to link for this. It should fix hail-is#13908; I believe there is also another issue about reducing noisy Grafana alerts, which this PR addresses as well.
1 parent d231b40 commit a5c7a8a

File tree

9 files changed

+80
-59
lines changed

9 files changed

+80
-59
lines changed

batch/batch/cloud/azure/driver/driver.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ async def create(
2828
machine_name_prefix: str,
2929
namespace: str,
3030
inst_coll_configs: InstanceCollectionConfigs,
31-
task_manager: aiotools.BackgroundTaskManager, # BORROWED
3231
) -> 'AzureDriver':
3332
azure_config = get_azure_config()
3433
subscription_id = azure_config.subscription_id
@@ -68,6 +67,8 @@ async def create(
6867
app, subscription_id, resource_group, ssh_public_key, arm_client, compute_client, billing_manager
6968
)
7069

70+
task_manager = aiotools.BackgroundTaskManager()
71+
7172
create_pools_coros = [
7273
Pool.create(
7374
app,
@@ -110,6 +111,7 @@ async def create(
110111
inst_coll_manager,
111112
jpim,
112113
billing_manager,
114+
task_manager,
113115
)
114116

115117
task_manager.ensure_future(periodically_call(60, driver.delete_orphaned_nics))
@@ -135,6 +137,7 @@ def __init__(
135137
inst_coll_manager: InstanceCollectionManager,
136138
job_private_inst_manager: JobPrivateInstanceManager,
137139
billing_manager: AzureBillingManager,
140+
task_manager: aiotools.BackgroundTaskManager,
138141
):
139142
self.db = db
140143
self.machine_name_prefix = machine_name_prefix
@@ -150,6 +153,7 @@ def __init__(
150153
self.job_private_inst_manager = job_private_inst_manager
151154
self._billing_manager = billing_manager
152155
self._inst_coll_manager = inst_coll_manager
156+
self._task_manager = task_manager
153157

154158
@property
155159
def billing_manager(self) -> AzureBillingManager:
@@ -161,18 +165,21 @@ def inst_coll_manager(self) -> InstanceCollectionManager:
161165

162166
async def shutdown(self) -> None:
163167
try:
164-
await self.arm_client.close()
168+
await self._task_manager.shutdown_and_wait()
165169
finally:
166170
try:
167-
await self.compute_client.close()
171+
await self.arm_client.close()
168172
finally:
169173
try:
170-
await self.resources_client.close()
174+
await self.compute_client.close()
171175
finally:
172176
try:
173-
await self.network_client.close()
177+
await self.resources_client.close()
174178
finally:
175-
await self.pricing_client.close()
179+
try:
180+
await self.network_client.close()
181+
finally:
182+
await self.pricing_client.close()
176183

177184
def _resource_is_orphaned(self, resource_name: str) -> bool:
178185
instance_name = resource_name.rsplit('-', maxsplit=1)[0]

batch/batch/cloud/driver.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from gear import Database
22
from gear.cloud_config import get_global_config
3-
from hailtop import aiotools
43

54
from ..driver.driver import CloudDriver
65
from ..inst_coll_config import InstanceCollectionConfigs
@@ -14,12 +13,11 @@ async def get_cloud_driver(
1413
machine_name_prefix: str,
1514
namespace: str,
1615
inst_coll_configs: InstanceCollectionConfigs,
17-
task_manager: aiotools.BackgroundTaskManager,
1816
) -> CloudDriver:
1917
cloud = get_global_config()['cloud']
2018

2119
if cloud == 'azure':
22-
return await AzureDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs, task_manager)
20+
return await AzureDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs)
2321

2422
assert cloud == 'gcp', cloud
25-
return await GCPDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs, task_manager)
23+
return await GCPDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs)

batch/batch/cloud/gcp/driver/driver.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ async def create(
2525
machine_name_prefix: str,
2626
namespace: str,
2727
inst_coll_configs: InstanceCollectionConfigs,
28-
task_manager: aiotools.BackgroundTaskManager, # BORROWED
2928
) -> 'GCPDriver':
3029
gcp_config = get_gcp_config()
3130
project = gcp_config.project
@@ -67,6 +66,8 @@ async def create(
6766
inst_coll_manager = InstanceCollectionManager(db, machine_name_prefix, zone_monitor, region, regions)
6867
resource_manager = GCPResourceManager(project, compute_client, billing_manager)
6968

69+
task_manager = aiotools.BackgroundTaskManager()
70+
7071
create_pools_coros = [
7172
Pool.create(
7273
app,
@@ -105,6 +106,7 @@ async def create(
105106
inst_coll_manager,
106107
jpim,
107108
billing_manager,
109+
task_manager,
108110
)
109111

110112
task_manager.ensure_future(periodically_call(15, driver.process_activity_logs))
@@ -126,6 +128,7 @@ def __init__(
126128
inst_coll_manager: InstanceCollectionManager,
127129
job_private_inst_manager: JobPrivateInstanceManager,
128130
billing_manager: GCPBillingManager,
131+
task_manager: aiotools.BackgroundTaskManager,
129132
):
130133
self.db = db
131134
self.machine_name_prefix = machine_name_prefix
@@ -137,6 +140,7 @@ def __init__(
137140
self.job_private_inst_manager = job_private_inst_manager
138141
self._billing_manager = billing_manager
139142
self._inst_coll_manager = inst_coll_manager
143+
self._task_manager = task_manager
140144

141145
@property
142146
def billing_manager(self) -> GCPBillingManager:
@@ -148,12 +152,15 @@ def inst_coll_manager(self) -> InstanceCollectionManager:
148152

149153
async def shutdown(self) -> None:
150154
try:
151-
await self.compute_client.close()
155+
await self._task_manager.shutdown_and_wait()
152156
finally:
153157
try:
154-
await self.activity_logs_client.close()
158+
await self.compute_client.close()
155159
finally:
156-
await self._billing_manager.close()
160+
try:
161+
await self.activity_logs_client.close()
162+
finally:
163+
await self._billing_manager.close()
157164

158165
async def process_activity_logs(self) -> None:
159166
async def _process_activity_log_events_since(mark):

batch/batch/driver/canceller.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ def __init__(self, app):
6767

6868
self.task_manager = aiotools.BackgroundTaskManager()
6969

70-
def shutdown(self):
70+
async def shutdown_and_wait(self):
7171
try:
72-
self.task_manager.shutdown()
72+
await self.task_manager.shutdown_and_wait()
7373
finally:
74-
self.async_worker_pool.shutdown()
74+
await self.async_worker_pool.shutdown_and_wait()
7575

7676
async def cancel_cancelled_ready_jobs_loop_body(self):
7777
records = self.db.select_and_fetchall(

batch/batch/driver/main.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1558,18 +1558,25 @@ def log(self, request, response, time):
15581558

15591559

15601560
async def on_startup(app):
1561-
task_manager = aiotools.BackgroundTaskManager()
1562-
app['task_manager'] = task_manager
1563-
1564-
app['client_session'] = httpx.client_session()
1561+
exit_stack = AsyncExitStack()
1562+
app['exit_stack'] = exit_stack
15651563

15661564
kubernetes_asyncio.config.load_incluster_config()
15671565
app['k8s_client'] = kubernetes_asyncio.client.CoreV1Api()
15681566
app['k8s_cache'] = K8sCache(app['k8s_client'])
15691567

1568+
async def close_and_wait():
1569+
# - Following warning mitigation described here: https://github.com/aio-libs/aiohttp/pull/2045
1570+
# - Fixed in aiohttp 4.0.0: https://github.com/aio-libs/aiohttp/issues/1925
1571+
await app['k8s_client'].api_client.close()
1572+
await asyncio.sleep(0.250)
1573+
1574+
exit_stack.push_async_callback(close_and_wait)
1575+
15701576
db = Database()
15711577
await db.async_init(maxsize=50)
15721578
app['db'] = db
1579+
exit_stack.push_async_callback(app['db'].async_close)
15731580

15741581
row = await db.select_and_fetchone(
15751582
'''
@@ -1590,18 +1597,28 @@ async def on_startup(app):
15901597
app['cancel_ready_state_changed'] = asyncio.Event()
15911598
app['cancel_creating_state_changed'] = asyncio.Event()
15921599
app['cancel_running_state_changed'] = asyncio.Event()
1600+
15931601
app['async_worker_pool'] = AsyncWorkerPool(100, queue_size=100)
1602+
exit_stack.push_async_callback(app['async_worker_pool'].shutdown_and_wait)
15941603

15951604
fs = get_cloud_async_fs()
15961605
app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id)
1606+
exit_stack.push_async_callback(app['file_store'].close)
15971607

15981608
inst_coll_configs = await InstanceCollectionConfigs.create(db)
15991609

1600-
app['driver'] = await get_cloud_driver(
1601-
app, db, MACHINE_NAME_PREFIX, DEFAULT_NAMESPACE, inst_coll_configs, task_manager
1602-
)
1610+
app['client_session'] = httpx.client_session()
1611+
exit_stack.push_async_callback(app['client_session'].close)
1612+
1613+
app['driver'] = await get_cloud_driver(app, db, MACHINE_NAME_PREFIX, DEFAULT_NAMESPACE, inst_coll_configs)
1614+
exit_stack.push_async_callback(app['driver'].shutdown)
16031615

16041616
app['canceller'] = await Canceller.create(app)
1617+
exit_stack.push_async_callback(app['canceller'].shutdown_and_wait)
1618+
1619+
task_manager = aiotools.BackgroundTaskManager()
1620+
app['task_manager'] = task_manager
1621+
exit_stack.push_async_callback(app['task_manager'].shutdown_and_wait)
16051622

16061623
task_manager.ensure_future(periodically_call(10, monitor_billing_limits, app))
16071624
task_manager.ensure_future(periodically_call(10, cancel_fast_failing_batches, app))
@@ -1614,24 +1631,7 @@ async def on_startup(app):
16141631

16151632
async def on_cleanup(app):
16161633
try:
1617-
async with AsyncExitStack() as cleanup:
1618-
cleanup.callback(app['canceller'].shutdown)
1619-
cleanup.callback(app['task_manager'].shutdown)
1620-
cleanup.push_async_callback(app['driver'].shutdown)
1621-
cleanup.push_async_callback(app['file_store'].shutdown)
1622-
cleanup.push_async_callback(app['client_session'].close)
1623-
cleanup.callback(app['async_worker_pool'].shutdown)
1624-
cleanup.push_async_callback(app['db'].async_close)
1625-
1626-
k8s: kubernetes_asyncio.client.CoreV1Api = app['k8s_client']
1627-
1628-
async def close_and_wait():
1629-
# - Following warning mitigation described here: https://github.com/aio-libs/aiohttp/pull/2045
1630-
# - Fixed in aiohttp 4.0.0: https://github.com/aio-libs/aiohttp/issues/1925
1631-
await k8s.api_client.close()
1632-
await asyncio.sleep(0.250)
1633-
1634-
cleanup.push_async_callback(close_and_wait)
1634+
await app['exit_stack'].aclose()
16351635
finally:
16361636
await asyncio.gather(*(t for t in asyncio.all_tasks() if t is not asyncio.current_task()))
16371637

batch/batch/front_end/front_end.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2903,12 +2903,16 @@ def log(self, request, response, time):
29032903

29042904

29052905
async def on_startup(app):
2906-
app['task_manager'] = aiotools.BackgroundTaskManager()
2906+
exit_stack = AsyncExitStack()
2907+
app['exit_stack'] = exit_stack
2908+
29072909
app['client_session'] = httpx.client_session()
2910+
exit_stack.push_async_callback(app['client_session'].close)
29082911

29092912
db = Database()
29102913
await db.async_init()
29112914
app['db'] = db
2915+
exit_stack.push_async_callback(app['db'].async_close)
29122916

29132917
row = await db.select_and_fetchone(
29142918
'''
@@ -2923,6 +2927,7 @@ async def on_startup(app):
29232927
app['instance_id'] = instance_id
29242928

29252929
app['hail_credentials'] = hail_credentials()
2930+
exit_stack.push_async_callback(app['hail_credentials'].close)
29262931

29272932
app['frozen'] = row['frozen']
29282933

@@ -2937,8 +2942,13 @@ async def on_startup(app):
29372942

29382943
fs = get_cloud_async_fs()
29392944
app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id)
2945+
exit_stack.push_async_callback(app['file_store'].close)
2946+
2947+
app['task_manager'] = aiotools.BackgroundTaskManager()
2948+
exit_stack.callback(app['task_manager'].shutdown)
29402949

29412950
app['inst_coll_configs'] = await InstanceCollectionConfigs.create(db)
2951+
exit_stack.push_async_callback(app['file_store'].close)
29422952

29432953
cancel_batch_state_changed = asyncio.Event()
29442954
app['cancel_batch_state_changed'] = cancel_batch_state_changed
@@ -2958,12 +2968,7 @@ async def on_startup(app):
29582968

29592969

29602970
async def on_cleanup(app):
2961-
async with AsyncExitStack() as stack:
2962-
stack.callback(app['task_manager'].shutdown)
2963-
stack.push_async_callback(app['hail_credentials'].close)
2964-
stack.push_async_callback(app['client_session'].close)
2965-
stack.push_async_callback(app['file_store'].close)
2966-
stack.push_async_callback(app['db'].async_close)
2971+
await app['exit_stack'].aclose()
29672972

29682973

29692974
def run():

batch/batch/worker/worker.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3039,15 +3039,15 @@ async def shutdown(self):
30393039
log.info('Worker.shutdown')
30403040
self._jvm_initializer_task.cancel()
30413041
async with AsyncExitStack() as cleanup:
3042+
cleanup.push_async_callback(self.client_session.close)
3043+
if self.fs:
3044+
cleanup.push_async_callback(self.fs.close)
3045+
if self.file_store:
3046+
cleanup.push_async_callback(self.file_store.close)
30423047
for jvmqueue in self._jvmpools_by_cores.values():
30433048
while not jvmqueue.queue.empty():
30443049
cleanup.push_async_callback(jvmqueue.queue.get_nowait().kill)
30453050
cleanup.push_async_callback(self.task_manager.shutdown_and_wait)
3046-
if self.file_store:
3047-
cleanup.push_async_callback(self.file_store.close)
3048-
if self.fs:
3049-
cleanup.push_async_callback(self.fs.close)
3050-
cleanup.push_async_callback(self.client_session.close)
30513051

30523052
async def run_job(self, job):
30533053
try:
@@ -3475,11 +3475,10 @@ async def async_main():
34753475
with aiomonitor.start_monitor(asyncio.get_event_loop(), locals=locals()):
34763476
try:
34773477
async with AsyncExitStack() as cleanup:
3478-
cleanup.push_async_callback(worker.shutdown)
3479-
cleanup.push_async_callback(CLOUD_WORKER_API.close)
3480-
cleanup.push_async_callback(network_allocator_task_manager.shutdown_and_wait)
34813478
cleanup.push_async_callback(docker.close)
3482-
3479+
cleanup.push_async_callback(network_allocator_task_manager.shutdown_and_wait)
3480+
cleanup.push_async_callback(CLOUD_WORKER_API.close)
3481+
cleanup.push_async_callback(worker.shutdown)
34833482
await worker.run()
34843483
finally:
34853484
asyncio.get_event_loop().set_debug(True)

hail/python/hailtop/aiotools/tasks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,5 @@ def shutdown(self):
4545

4646
async def shutdown_and_wait(self):
4747
self.shutdown()
48-
await asyncio.wait(self.tasks, return_when=asyncio.ALL_COMPLETED)
48+
if self.tasks:
49+
await asyncio.wait(self.tasks, return_when=asyncio.ALL_COMPLETED)

hail/python/hailtop/utils/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,10 @@ def shutdown(self):
224224
except Exception:
225225
pass
226226

227+
async def shutdown_and_wait(self):
228+
self.shutdown()
229+
await asyncio.gather(*self.workers, return_exceptions=True)
230+
227231

228232
class WaitableSharedPool:
229233
def __init__(self, worker_pool: AsyncWorkerPool):

0 commit comments

Comments
 (0)