Skip to content
This repository was archived by the owner on Jan 8, 2021. It is now read-only.

Commit 7aa6a62

Browse files
committed
Fixed cifar10
1 parent 18d8398 commit 7aa6a62

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

Train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def train(EPOCH, dataloader, optimizer, loss_function, network, model_name, nega
2424
print('Epoch: ', epoch + 1, '| Step: ', step + 1, '| Train loss: %.4f' % loss.cpu().data.numpy())
2525
if step % SAVE_STEP == 1:
2626
print('Saving models...')
27-
torch.save(network, model_name)
27+
torch.save(network, "{}-{}-step{}.pkl".format(model_name, epoch, step))
2828
print('Saving models...')
29-
torch.save(network, model_name)
29+
torch.save(network, "{}-{}.pkl".format(model_name, epoch))
3030
print('Saving logs...')
3131
f.write("{} {}\n".format(epoch, loss.cpu().data.numpy()))
3232

run.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ train_cnn_mnist() {
99
python Train.py --epoch 20 --dataset mnist --network cnn --loss_func "$2" --model_name "$3" 1 "$1"
1010
}
1111
train_mlp_cifar() {
12-
python Train.py --epoch 20 --dataset cifar10 --network mlp --loss_func "$2" --model_name "$3" 784 "$1"
12+
python Train.py --epoch 20 --dataset cifar10 --network mlp --loss_func "$2" --model_name "$3" 3072 "$1"
1313
}
1414
train_cnn_cifar() {
15-
python Train.py --epoch 20 --dataset cifar10 --network cnn --loss_func "$2" --model_name "$3" 1 "$1"
15+
python Train.py --epoch 20 --dataset cifar10 --network cnn --loss_func "$2" --model_name "$3" 3 "$1"
1616
}

0 commit comments

Comments
 (0)