4
4
from torch .utils import data
5
5
6
6
7
- def train (EPOCH , dataloader , test_dataloader , optimizer , loss_function , network , model_name , negative_loss , device , should_view , f ):
7
+ def train (EPOCH , dataloader , optimizer , loss_function , network , model_name , negative_loss , device , should_view , f ):
8
8
for epoch in range (EPOCH ):
9
+ loss = 0
9
10
for step , (b_x , _ ) in enumerate (dataloader ):
10
11
b_x = b_x .to (device )
11
12
if should_view :
@@ -27,22 +28,7 @@ def train(EPOCH, dataloader, test_dataloader, optimizer, loss_function, network,
27
28
print ('Saving models...' )
28
29
torch .save (network , model_name )
29
30
print ('Saving logs...' )
30
- loss_sum = 0
31
- test_count = 0
32
- for _ , (b_x , _ ) in enumerate (test_dataloader ):
33
- b_x = b_x .to (device )
34
- if should_view :
35
- formatted_b_x = b_x .view (b_x .shape [0 ], - 1 )
36
- else :
37
- formatted_b_x = b_x
38
- output , _ = network (formatted_b_x )
39
- output = output .view (b_x .shape )
40
- loss = loss_function (output , b_x )
41
- if negative_loss :
42
- loss = - loss
43
- loss_sum += loss
44
- test_count += 1
45
- f .write ("{} {}\n " .format (epoch , loss_sum / test_count ))
31
+ f .write ("{} {}\n " .format (epoch , loss .cpu ().data .numpy ()))
46
32
47
33
48
34
parser = argparse .ArgumentParser (
@@ -84,24 +70,12 @@ def train(EPOCH, dataloader, test_dataloader, optimizer, loss_function, network,
84
70
transform = torchvision .transforms .ToTensor (),
85
71
download = True ,
86
72
)
87
- test_data = torchvision .datasets .CIFAR10 (
88
- root = './cifar10/' ,
89
- transform = torchvision .transforms .ToTensor (),
90
- download = True ,
91
- train = False
92
- )
93
73
else :
94
74
train_data = torchvision .datasets .MNIST (
95
75
root = './mnist/' ,
96
76
transform = torchvision .transforms .ToTensor (),
97
77
download = True ,
98
78
)
99
- test_data = torchvision .datasets .MNIST (
100
- root = './mnist/' ,
101
- transform = torchvision .transforms .ToTensor (),
102
- download = True ,
103
- train = False
104
- )
105
79
if network == "mlp" :
106
80
import mlp_network
107
81
should_view = True
@@ -131,5 +105,5 @@ def train(EPOCH, dataloader, test_dataloader, optimizer, loss_function, network,
131
105
loss_func = psnr .PSNR ()
132
106
train_loader = data .DataLoader (dataset = train_data , batch_size = BATCH_SIZE , shuffle = True )
133
107
optimizer = torch .optim .Adam (network .parameters (), lr = LR )
134
- train (EPOCH , train_loader , test_data , optimizer , loss_func , network , model_name , negative_loss , args .device , should_view , f )
108
+ train (EPOCH , train_loader , optimizer , loss_func , network , model_name , negative_loss , args .device , should_view , f )
135
109
f .close ()
0 commit comments