Skip to content

Commit e86cbb0

Browse files
committed
batching helpers
1 parent 22fcce6 commit e86cbb0

File tree

2 files changed

+54
-9
lines changed

2 files changed

+54
-9
lines changed

mnist_train.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
import torch
66
import torch.optim as optim
77

8+
from reptile_gen.batching import batched_grad
89
from reptile_gen.mnist import iterate_mini_datasets
910
from reptile_gen.model import MNISTModel
1011
from reptile_gen.reptile import reptile_grad
1112

1213
OUT_PATH = 'model.pt'
13-
AVG_SIZE = 1000
14+
AVG_SIZE = 20
1415
META_BATCH = 50
1516

1617

@@ -23,17 +24,19 @@ def main():
2324
mini_batches = iterate_mini_datasets()
2425
last_n = []
2526
for i in itertools.count():
26-
inputs, outputs = next(mini_batches)
27-
losses = reptile_grad(model, inputs, outputs, opt)
27+
outer_opt.zero_grad()
28+
29+
def grad_fn(model, x, y):
30+
return reptile_grad(model, x, y, opt)
31+
32+
batch = [next(mini_batches) for _ in range(META_BATCH)]
33+
losses = batched_grad(model, grad_fn, batch)
2834
loss = np.mean(losses)
2935
last_n.append(loss)
3036
last_n = last_n[-AVG_SIZE:]
31-
if i % META_BATCH == 0:
32-
outer_opt.step()
33-
outer_opt.zero_grad()
34-
torch.save(model.state_dict(), OUT_PATH)
35-
print('step %d: loss=%f last_%d=%f' %
36-
(i//META_BATCH, np.mean(losses), AVG_SIZE, np.mean(last_n)))
37+
outer_opt.step()
38+
torch.save(model.state_dict(), OUT_PATH)
39+
print('step %d: loss=%f last_%d=%f' % (i, np.mean(losses), AVG_SIZE, np.mean(last_n)))
3740

3841

3942
if __name__ == '__main__':

reptile_gen/batching.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from multiprocessing import Pool, set_start_method
2+
3+
import cloudpickle
4+
import torch
5+
6+
7+
def batched_grad(model, grad_fn, batch, threads=1, device='cpu'):
    """Accumulate gradients for a batch of tasks into ``model``'s parameters.

    For every item in ``batch`` a fresh ``model.__class__()`` is built in a
    worker process, loaded with a copy of ``model``'s current state dict, and
    passed to ``grad_fn``. The ``.grad`` of each worker-side parameter is then
    summed into the ``.grad`` of the matching parameter of ``model``.

    Args:
        model: module whose class can be constructed with no arguments and
            which exposes ``state_dict()`` / ``parameters()``.
        grad_fn: callable ``grad_fn(model, inputs, outputs)`` that is expected
            to leave gradients on the model's parameters; its return value is
            passed back to the caller unchanged.
        batch: iterable of ``(inputs, outputs)`` pairs; both elements must
            support ``.to(device)``.
        threads: maximum number of worker processes.
        device: device string on which each worker runs ``grad_fn``.

    Returns:
        A tuple with one ``grad_fn`` result per batch item, in batch order.
        An empty batch returns ``()`` and leaves ``model`` untouched.
    """
    batch = list(batch)
    if not batch:
        # Nothing to do; Pool(0) would raise ValueError.
        return ()

    set_start_method('spawn', force=True)

    model_class = model.__class__
    # Ship weights as numpy arrays so the pickled closure carries plain data.
    model_dict = {x: y.cpu().numpy() for x, y in model.state_dict().items()}

    def run_grad_fn(inputs, outputs):
        model = model_class()
        state = {x: torch.from_numpy(y) for x, y in model_dict.items()}
        model.load_state_dict(state)
        d = torch.device(device)
        if device != 'cpu':
            model.to(d)
        res = grad_fn(model, inputs.to(d), outputs.to(d))
        # Move gradients to the CPU before returning them: tensors living on
        # another device must not outlive the worker process that owns them.
        grads = [None if p.grad is None else p.grad.cpu()
                 for p in model.parameters()]
        return grads, res

    pickled_fn = cloudpickle.dumps(run_grad_fn)
    # The context manager tears the pool down even when a worker raises.
    with Pool(min(len(batch), threads)) as pool:
        raw_results = pool.map(call_pickled_fn,
                               [(pickled_fn, x) for x in batch])
    grads, results = zip(*raw_results)

    for grad in grads:
        for p, g in zip(model.parameters(), grad):
            if g is None:
                # Parameter received no gradient in this task.
                continue
            g = g.to(p.device)
            if p.grad is None:
                p.grad = g
            else:
                p.grad.add_(g)

    return results
37+
38+
39+
def call_pickled_fn(data_args):
    """Unpickle a callable with cloudpickle and invoke it.

    ``data_args`` is a ``(payload, call_args)`` pair: ``payload`` is the
    cloudpickle-serialized callable and ``call_args`` is the iterable of
    positional arguments it is called with. Returns the callable's result.
    """
    payload, call_args = data_args
    fn = cloudpickle.loads(payload)
    return fn(*call_args)

0 commit comments

Comments
 (0)