Skip to content
This repository was archived by the owner on Jan 8, 2021. It is now read-only.

Commit 74d991a

Browse files
committed
Revert "Added test performance in training"
This reverts commit 7eb6412.
1 parent 7eb6412 commit 74d991a

File tree

3 files changed

+4
-30
lines changed

3 files changed

+4
-30
lines changed

Train.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from torch.utils import data
55

66

7-
def train(EPOCH, dataloader, test_dataloader, optimizer, loss_function, network, model_name, negative_loss, device, should_view, f):
7+
def train(EPOCH, dataloader, optimizer, loss_function, network, model_name, negative_loss, device, should_view, f):
88
for epoch in range(EPOCH):
9+
loss = 0
910
for step, (b_x, _) in enumerate(dataloader):
1011
b_x = b_x.to(device)
1112
if should_view:
@@ -27,22 +28,7 @@ def train(EPOCH, dataloader, test_dataloader, optimizer, loss_function, network,
2728
print('Saving models...')
2829
torch.save(network, model_name)
2930
print('Saving logs...')
30-
loss_sum = 0
31-
test_count = 0
32-
for _, (b_x, _) in enumerate(test_dataloader):
33-
b_x = b_x.to(device)
34-
if should_view:
35-
formatted_b_x = b_x.view(b_x.shape[0], -1)
36-
else:
37-
formatted_b_x = b_x
38-
output, _ = network(formatted_b_x)
39-
output = output.view(b_x.shape)
40-
loss = loss_function(output, b_x)
41-
if negative_loss:
42-
loss = - loss
43-
loss_sum += loss
44-
test_count += 1
45-
f.write("{} {}\n".format(epoch, loss_sum / test_count))
31+
f.write("{} {}\n".format(epoch, loss.cpu().data.numpy()))
4632

4733

4834
parser = argparse.ArgumentParser(
@@ -84,24 +70,12 @@ def train(EPOCH, dataloader, test_dataloader, optimizer, loss_function, network,
8470
transform = torchvision.transforms.ToTensor(),
8571
download = True,
8672
)
87-
test_data = torchvision.datasets.CIFAR10(
88-
root = './cifar10/',
89-
transform = torchvision.transforms.ToTensor(),
90-
download = True,
91-
train = False
92-
)
9373
else:
9474
train_data = torchvision.datasets.MNIST(
9575
root = './mnist/',
9676
transform = torchvision.transforms.ToTensor(),
9777
download = True,
9878
)
99-
test_data = torchvision.datasets.MNIST(
100-
root = './mnist/',
101-
transform = torchvision.transforms.ToTensor(),
102-
download = True,
103-
train = False
104-
)
10579
if network == "mlp":
10680
import mlp_network
10781
should_view = True
@@ -131,5 +105,5 @@ def train(EPOCH, dataloader, test_dataloader, optimizer, loss_function, network,
131105
loss_func = psnr.PSNR()
132106
train_loader = data.DataLoader(dataset = train_data, batch_size = BATCH_SIZE, shuffle = True)
133107
optimizer = torch.optim.Adam(network.parameters(), lr = LR)
134-
train(EPOCH, train_loader, test_data, optimizer, loss_func, network, model_name, negative_loss, args.device, should_view, f)
108+
train(EPOCH, train_loader, optimizer, loss_func, network, model_name, negative_loss, args.device, should_view, f)
135109
f.close()

mse_mlp_mnist.pkl

-2.35 MB
Binary file not shown.

train.log

Whitespace-only changes.

0 commit comments

Comments
 (0)