Skip to content

Commit 31269de

Browse files
committed
fix grad on cuda
1 parent e86cbb0 commit 31269de

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

reptile_gen/batching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def run_grad_fn(inputs, outputs):
1818
if device != 'cpu':
1919
model.to(d)
2020
res = grad_fn(model, inputs.to(d), outputs.to(d))
21-
return [p.grad for p in model.parameters()], res
21+
return [p.grad.cpu() for p in model.parameters()], res
2222

2323
pool = Pool(min(len(batch), threads))
2424
pickled_fn = cloudpickle.dumps(run_grad_fn)

0 commit comments

Comments
 (0)