Skip to content
This repository was archived by the owner on Jan 8, 2021. It is now read-only.

Commit 2c6f4df

Browse files
committed
Save log before training
1 parent b92594f commit 2c6f4df

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

Train.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,23 @@
55

66

77
def train(EPOCH, dataloader, test_dataloader, optimizer, loss_function, network, model_name, negative_loss, device, should_view, f):
8+
print('Saving logs...')
9+
loss_sum = 0
10+
test_count = 0
11+
for _, (b_x, _) in enumerate(test_dataloader):
12+
b_x = b_x.to(device)
13+
if should_view:
14+
formatted_b_x = b_x.view(b_x.shape[0], -1)
15+
else:
16+
formatted_b_x = b_x
17+
output, _ = network(formatted_b_x)
18+
output = output.view(b_x.shape)
19+
loss = loss_function(output, b_x)
20+
if negative_loss:
21+
loss = - loss
22+
loss_sum += loss
23+
test_count += 1
24+
f.write("{} {}\n".format(0, loss_sum / test_count))
825
for epoch in range(EPOCH):
926
for step, (b_x, _) in enumerate(dataloader):
1027
b_x = b_x.to(device)
@@ -42,7 +59,7 @@ def train(EPOCH, dataloader, test_dataloader, optimizer, loss_function, network,
4259
loss = - loss
4360
loss_sum += loss
4461
test_count += 1
45-
f.write("{} {}\n".format(epoch, loss_sum / test_count))
62+
f.write("{} {}\n".format(epoch + 1, loss_sum / test_count))
4663

4764

4865
parser = argparse.ArgumentParser(

0 commit comments

Comments
 (0)