Skip to content

Commit da1c091

Browse files
committed
prevent semaphore leak
Thanks to pytorch/pytorch#3492 (comment)
1 parent 31269de commit da1c091

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

reptile_gen/batching.py

Lines changed: 4 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -1,12 +1,10 @@
1-
from multiprocessing import Pool, set_start_method
1+
import multiprocessing
22

33
import cloudpickle
44
import torch
55

66

77
def batched_grad(model, grad_fn, batch, threads=1, device='cpu'):
8-
set_start_method('spawn', force=True)
9-
108
model_class = model.__class__
119
model_dict = {x: y.cpu().numpy() for x, y in model.state_dict().items()}
1210

@@ -20,11 +18,11 @@ def run_grad_fn(inputs, outputs):
2018
res = grad_fn(model, inputs.to(d), outputs.to(d))
2119
return [p.grad.cpu() for p in model.parameters()], res
2220

23-
pool = Pool(min(len(batch), threads))
2421
pickled_fn = cloudpickle.dumps(run_grad_fn)
25-
raw_results = pool.map(call_pickled_fn, [(pickled_fn, x) for x in batch])
22+
ctx = multiprocessing.get_context('spawn')
23+
with ctx.Pool(min(len(batch), threads)) as pool:
24+
raw_results = pool.map(call_pickled_fn, [(pickled_fn, x) for x in batch])
2625
grads, results = list(zip(*raw_results))
27-
pool.close()
2826

2927
for grad in grads:
3028
for p, g in zip(model.parameters(), grad):

0 commit comments

Comments (0)