We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e86cbb0 commit 31269deCopy full SHA for 31269de
reptile_gen/batching.py
@@ -18,7 +18,7 @@ def run_grad_fn(inputs, outputs):
18
if device != 'cpu':
19
model.to(d)
20
res = grad_fn(model, inputs.to(d), outputs.to(d))
21
- return [p.grad for p in model.parameters()], res
+ return [p.grad.cpu() for p in model.parameters()], res
22
23
pool = Pool(min(len(batch), threads))
24
pickled_fn = cloudpickle.dumps(run_grad_fn)
0 commit comments