Skip to content
This repository was archived by the owner on Jan 8, 2021. It is now read-only.

Commit 051f446

Browse files
committed
Added test performance in training
1 parent 74d991a commit 051f446

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

Train.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
from torch.utils import data
55

66

7-
def train(EPOCH, dataloader, optimizer, loss_function, network, model_name, negative_loss, device, should_view, f):
7+
def train(EPOCH, dataloader, test_dataloader, optimizer, loss_function, network, model_name, negative_loss, device, should_view, f):
88
for epoch in range(EPOCH):
9-
loss = 0
109
for step, (b_x, _) in enumerate(dataloader):
1110
b_x = b_x.to(device)
1211
if should_view:
@@ -28,7 +27,22 @@ def train(EPOCH, dataloader, optimizer, loss_function, network, model_name, nega
2827
print('Saving models...')
2928
torch.save(network, model_name)
3029
print('Saving logs...')
31-
f.write("{} {}\n".format(epoch, loss.cpu().data.numpy()))
30+
loss_sum = 0
31+
test_count = 0
32+
for _, (b_x, _) in enumerate(test_dataloader):
33+
b_x = b_x.to(device)
34+
if should_view:
35+
formatted_b_x = b_x.view(b_x.shape[0], -1)
36+
else:
37+
formatted_b_x = b_x
38+
output, _ = network(formatted_b_x)
39+
output = output.view(b_x.shape)
40+
loss = loss_function(output, b_x)
41+
if negative_loss:
42+
loss = - loss
43+
loss_sum += loss
44+
test_count += 1
45+
f.write("{} {}\n".format(epoch, loss_sum / test_count))
3246

3347

3448
parser = argparse.ArgumentParser(
@@ -70,12 +84,24 @@ def train(EPOCH, dataloader, optimizer, loss_function, network, model_name, nega
7084
transform = torchvision.transforms.ToTensor(),
7185
download = True,
7286
)
87+
test_data = torchvision.datasets.CIFAR10(
88+
root = './cifar10/',
89+
transform = torchvision.transforms.ToTensor(),
90+
download = True,
91+
train = False
92+
)
7393
else:
7494
train_data = torchvision.datasets.MNIST(
7595
root = './mnist/',
7696
transform = torchvision.transforms.ToTensor(),
7797
download = True,
7898
)
99+
test_data = torchvision.datasets.MNIST(
100+
root = './mnist/',
101+
transform = torchvision.transforms.ToTensor(),
102+
download = True,
103+
train = False
104+
)
79105
if network == "mlp":
80106
import mlp_network
81107
should_view = True
@@ -105,5 +131,5 @@ def train(EPOCH, dataloader, optimizer, loss_function, network, model_name, nega
105131
loss_func = psnr.PSNR()
106132
train_loader = data.DataLoader(dataset = train_data, batch_size = BATCH_SIZE, shuffle = True)
107133
optimizer = torch.optim.Adam(network.parameters(), lr = LR)
108-
train(EPOCH, train_loader, optimizer, loss_func, network, model_name, negative_loss, args.device, should_view, f)
134+
train(EPOCH, train_loader, test_data, optimizer, loss_func, network, model_name, negative_loss, args.device, should_view, f)
109135
f.close()

0 commit comments

Comments
 (0)