4
4
from torch .utils import data
5
5
6
6
7
- def train (EPOCH , dataloader , optimizer , loss_function , network , model_name , negative_loss , device , should_view , f ):
7
+ def train (EPOCH , dataloader , test_dataloader , optimizer , loss_function , network , model_name , negative_loss , device , should_view , f ):
8
8
for epoch in range (EPOCH ):
9
- loss = 0
10
9
for step , (b_x , _ ) in enumerate (dataloader ):
11
10
b_x = b_x .to (device )
12
11
if should_view :
@@ -28,7 +27,22 @@ def train(EPOCH, dataloader, optimizer, loss_function, network, model_name, nega
28
27
print ('Saving models...' )
29
28
torch .save (network , model_name )
30
29
print ('Saving logs...' )
31
- f .write ("{} {}\n " .format (epoch , loss .cpu ().data .numpy ()))
30
+ loss_sum = 0
31
+ test_count = 0
32
+ for _ , (b_x , _ ) in enumerate (test_dataloader ):
33
+ b_x = b_x .to (device )
34
+ if should_view :
35
+ formatted_b_x = b_x .view (b_x .shape [0 ], - 1 )
36
+ else :
37
+ formatted_b_x = b_x
38
+ output , _ = network (formatted_b_x )
39
+ output = output .view (b_x .shape )
40
+ loss = loss_function (output , b_x )
41
+ if negative_loss :
42
+ loss = - loss
43
+ loss_sum += loss
44
+ test_count += 1
45
+ f .write ("{} {}\n " .format (epoch , loss_sum / test_count ))
32
46
33
47
34
48
parser = argparse .ArgumentParser (
@@ -70,12 +84,24 @@ def train(EPOCH, dataloader, optimizer, loss_function, network, model_name, nega
70
84
transform = torchvision .transforms .ToTensor (),
71
85
download = True ,
72
86
)
87
+ test_data = torchvision .datasets .CIFAR10 (
88
+ root = './cifar10/' ,
89
+ transform = torchvision .transforms .ToTensor (),
90
+ download = True ,
91
+ train = False
92
+ )
73
93
else :
74
94
train_data = torchvision .datasets .MNIST (
75
95
root = './mnist/' ,
76
96
transform = torchvision .transforms .ToTensor (),
77
97
download = True ,
78
98
)
99
+ test_data = torchvision .datasets .MNIST (
100
+ root = './mnist/' ,
101
+ transform = torchvision .transforms .ToTensor (),
102
+ download = True ,
103
+ train = False
104
+ )
79
105
if network == "mlp" :
80
106
import mlp_network
81
107
should_view = True
@@ -105,5 +131,5 @@ def train(EPOCH, dataloader, optimizer, loss_function, network, model_name, nega
105
131
loss_func = psnr .PSNR ()
106
132
train_loader = data .DataLoader (dataset = train_data , batch_size = BATCH_SIZE , shuffle = True )
107
133
optimizer = torch .optim .Adam (network .parameters (), lr = LR )
108
- train (EPOCH , train_loader , optimizer , loss_func , network , model_name , negative_loss , args .device , should_view , f )
134
+ train (EPOCH , train_loader , test_data , optimizer , loss_func , network , model_name , negative_loss , args .device , should_view , f )
109
135
f .close ()
0 commit comments