Skip to content
This repository was archived by the owner on Jan 8, 2021. It is now read-only.

Commit 1e2bbc1

Browse files
committed
Add files
1 parent 936dec3 commit 1e2bbc1

File tree

9 files changed

+357
-0
lines changed

9 files changed

+357
-0
lines changed

Test.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import argparse
2+
import torchvision
3+
import torch
4+
from PIL import Image
5+
6+
7+
def test(data, network):
8+
output, _ = network(data)
9+
return output
10+
11+
12+
# Command-line entry point: load a saved autoencoder, run it on one image and
# write the reconstruction to disk.
parser = argparse.ArgumentParser(
    description='Testing utility for paper "Investigation on different loss function in image autoencoder"')
parser.add_argument('target', type=str, help="Input image")
parser.add_argument('output', type=str, help="Output image")
parser.add_argument('--model_name', type=str, dest="model_name",
                    help="Model name for saving & restoring model",
                    default="saved.pkl")
parser.add_argument('--disable-cuda', action='store_true',
                    help='Disable CUDA')
args = parser.parse_args()
filename = args.target
output = args.output
model_name = args.model_name

# Choose the device before loading so the checkpoint can be remapped onto it;
# otherwise a model saved on a GPU cannot be restored on a CPU-only run.
if not args.disable_cuda and torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code. Only load model files from a trusted source.
network = torch.load(model_name, map_location=args.device)
network.to(args.device)
network.eval()

# NOTE(review): the image is fed as (1, C, H, W) without flattening; a model
# built from mlp_network presumably needs a flattened input - confirm.
data = Image.open(filename)
data = torchvision.transforms.ToTensor()(data).unsqueeze_(0).to(args.device)
result = test(data, network)
# Clamp to [0, 1]: ToPILImage scales floats by 255 and casts to bytes, so
# out-of-range values would wrap around instead of saturating.
result = result.detach().squeeze(0).clamp(0, 1).to("cpu")
result = torchvision.transforms.ToPILImage()(result)
result.save(output)

Train.py

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import argparse
2+
import torchvision
3+
import torch
4+
from torch.utils import data
5+
6+
7+
def train(EPOCH, dataloader, optimizer, loss_function, network, model_name, negative_loss, device, should_view, f):
    """Train ``network`` to reconstruct its input, checkpointing and logging as it goes.

    Args:
        EPOCH: Number of passes over ``dataloader``.
        dataloader: Iterable yielding ``(batch, _)`` pairs; the second item is ignored.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        loss_function: Callable ``(output, target)`` returning a tensor loss.
        network: Callable returning a pair whose first element is the reconstruction.
        model_name: Path the whole network is saved to with ``torch.save``.
        negative_loss: If true, the loss is negated before back-propagation.
        device: Device each batch is moved to.
        should_view: If true, each batch is flattened to ``(batch, -1)`` before
            being fed to the network.
        f: Writable text stream; one ``"<epoch> <loss>"`` line is written per epoch.

    The checkpoint interval is read from the module-level ``SAVE_STEP``; a value
    of zero or less disables intermediate checkpoints.
    """
    for epoch in range(EPOCH):
        # Plain float so logging still works for an empty dataloader.
        loss_value = 0.0
        for step, (b_x, _) in enumerate(dataloader):
            b_x = b_x.to(device)
            formatted_b_x = b_x.view(b_x.shape[0], -1) if should_view else b_x
            output, _ = network(formatted_b_x)
            # Restore the batch's shape so the loss compares like with like.
            output = output.view(b_x.shape)
            loss = loss_function(output, b_x)
            if negative_loss:
                loss = -loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() works for CPU and CUDA tensors alike; .numpy() raises on CUDA.
            loss_value = loss.item()
            print('Epoch: ', epoch + 1, '| Step: ', step + 1, '| Train loss: %.4f' % loss_value)
            # Checkpoint after every SAVE_STEP completed steps.
            if SAVE_STEP > 0 and (step + 1) % SAVE_STEP == 0:
                print('Saving models...')
                torch.save(network, model_name)
        print('Saving models...')
        torch.save(network, model_name)
        print('Saving logs...')
        # Logs the loss of the last batch of the epoch.
        f.write("{} {}\n".format(epoch, loss_value))
32+
33+
34+
# Command-line entry point: build the dataset, network and loss selected on
# the command line, then run the training loop.
import os

parser = argparse.ArgumentParser(
    description='Training utility for paper "Investigation on different loss function in image autoencoder"')
parser.add_argument('--save_at', type=int, dest="SAVE_STEP",
                    help="Save network at how many steps (default: 10) "
                         ",", default=10)
parser.add_argument('--epoch', '-e', type=int, dest="EPOCH", help="Epoch for training", default=10)
parser.add_argument('--batch_size', '-b', type=int, dest="BATCH_SIZE", help="Batch size for dataloader",
                    default=8192)
parser.add_argument('--learning_rate', '-l', type=float, dest="LR", help="Learning rate", default=0.001)
parser.add_argument('--dataset', type=str, dest="dataset", default="cifar10", choices=['cifar10', 'mnist'])
parser.add_argument('--network', type=str, dest="network", default="mlp", choices=['mlp', 'conv'])
parser.add_argument('origional_size', type=int,
                    help="Size of origional image. Please note that for ConvNet, it is the number of images' channel.")
parser.add_argument('bottleneck', type=int,
                    help="Size of bottle neck. Please note that for ConvNet, its \"bottleneck\" is input * origional size in one channel / 4")
# 'l1s' is listed so the SmoothL1Loss branch below is actually reachable.
parser.add_argument('--loss_func', type=str, dest="loss_function", default="mse",
                    choices=['mse', 'l1', 'l1s', 'ssim', 'psnr'])
parser.add_argument('--model_name', type=str, dest="model_name", help="Model name for saving & restoring model",
                    default="saved.pkl")
parser.add_argument('--disable-cuda', action='store_true',
                    help='Disable CUDA')
parser.add_argument('--log', type=str, default="train.log", dest="log_file", help="Plase to store logs")
args = parser.parse_args()

# SAVE_STEP is read as a module-level global by train().
SAVE_STEP = args.SAVE_STEP
EPOCH = args.EPOCH
BATCH_SIZE = args.BATCH_SIZE
LR = args.LR
dataset = args.dataset
network = args.network
orgsize = args.origional_size
bottleneck = args.bottleneck
loss_function = args.loss_function
model_name = args.model_name
log_file = args.log_file

if dataset == "cifar10":
    train_data = torchvision.datasets.CIFAR10(
        root='./cifar10/',
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
else:
    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )

# The MLP consumes flattened images; the conv net consumes them as-is.
if network == "mlp":
    import mlp_network
    should_view = True
    network = mlp_network.autoencoder(orgsize, bottleneck)
else:
    import conv_network
    should_view = False
    network = conv_network.autoencoder(orgsize, bottleneck)

if not args.disable_cuda and torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')
network.to(args.device)

# Similarity metrics (higher is better) are negated so that minimising the
# loss maximises the metric.
negative_loss = False
if loss_function == "mse":
    loss_func = torch.nn.MSELoss()
elif loss_function == "l1":
    loss_func = torch.nn.L1Loss()
elif loss_function == "l1s":
    loss_func = torch.nn.SmoothL1Loss()
elif loss_function == "ssim":
    import pytorch_ssim
    loss_func = pytorch_ssim.SSIM()
    negative_loss = True
else:
    import psnr
    loss_func = psnr.PSNR()
    negative_loss = True

train_loader = data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
optimizer = torch.optim.Adam(network.parameters(), lr=LR)

# Write the column header only for a new or empty log, so repeated runs that
# append to the same file do not stack up duplicate header lines.
needs_header = not os.path.exists(log_file) or os.path.getsize(log_file) == 0
with open(log_file, "a") as f:
    if needs_header:
        f.write("x y\n")
    train(EPOCH, train_loader, optimizer, loss_func, network, model_name, negative_loss, args.device,
          should_view, f)

conv_network.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from torch import nn
2+
3+
4+
class encoder(nn.Module):
    """Convolutional encoder: two 3x3 convolutions, a 2x2 max-pool, then a final 3x3 convolution.

    Args:
        in_features: Number of input channels.
        out_features: Number of output (bottleneck) channels.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_features,
            out_channels=32,
            kernel_size=3,
            padding=1,
        )
        self.conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            padding=1,
        )
        # MaxPool2d's second positional argument is the stride, not the
        # padding. Pass padding by keyword and let the stride default to the
        # kernel size so the pool halves each spatial dimension.
        self.pool = nn.MaxPool2d(kernel_size=2, padding=0)
        self.conv3 = nn.Conv2d(
            in_channels=64,
            out_channels=out_features,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        """Apply conv1 -> conv2 -> pool -> conv3 to ``x`` and return the result."""
        features = self.conv2(self.conv1(x))
        return self.conv3(self.pool(features))
29+
30+
class decoder(nn.Module):
    """Convolutional decoder: a 3x3 convolution, 2x upsampling, then two 3x3 convolutions.

    Args:
        in_features: Number of input (bottleneck) channels.
        out_features: Number of output channels.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_features,
            out_channels=64,
            kernel_size=3,
            padding=1,
        )
        # Named "pool" in the original; it is an upsampling layer that doubles
        # each spatial dimension.
        self.pool = nn.Upsample(scale_factor=2)
        self.conv2 = nn.Conv2d(
            in_channels=64,
            out_channels=32,
            kernel_size=3,
            padding=1,
        )
        self.conv3 = nn.Conv2d(
            in_channels=32,
            out_channels=out_features,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        """Apply conv1 -> upsample -> conv2 -> conv3 to ``x`` and return the result."""
        upsampled = self.pool(self.conv1(x))
        return self.conv3(self.conv2(upsampled))
55+
56+
57+
class autoencoder(nn.Module):
    """Convolutional autoencoder pairing ``encoder`` with a mirrored ``decoder``.

    Args:
        in_feature: Channel count of the input (and of the reconstruction).
        bottleneck: Channel count of the encoded representation.
    """

    def __init__(self, in_feature, bottleneck):
        super().__init__()
        self.encoder = encoder(in_feature, bottleneck)
        self.decoder = decoder(bottleneck, in_feature)

    def forward(self, x):
        """Return ``(reconstruction, encoded)`` for input ``x``."""
        encoded = self.encoder(x)
        reconstruction = self.decoder(encoded)
        return reconstruction, encoded

mlp_network.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from torch import nn
2+
3+
4+
class coder(nn.Module):
    """Single fully connected layer, used as both encoder and decoder of the MLP autoencoder.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear1 = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return the linear projection of ``x``."""
        return self.linear1(x)
11+
12+
13+
class autoencoder(nn.Module):
    """MLP autoencoder: one ``coder`` maps to the bottleneck, another maps back.

    Args:
        in_feature: Size of each (flattened) input sample.
        bottleneck: Size of the encoded representation.
    """

    def __init__(self, in_feature, bottleneck):
        super().__init__()
        self.encoder = coder(in_feature, bottleneck)
        self.decoder = coder(bottleneck, in_feature)

    def forward(self, x):
        """Return ``(reconstruction, encoded)`` for input ``x``."""
        encoded = self.encoder(x)
        reconstruction = self.decoder(encoded)
        return reconstruction, encoded

psnr.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
from math import log10
3+
4+
5+
class PSNR(torch.nn.Module):
    """Peak signal-to-noise ratio, ``10 * log10(1 / MSE)``, as a differentiable module.

    The formula uses a peak value of 1.
    """

    def __init__(self):
        super().__init__()
        self.criterion = torch.nn.MSELoss()

    def forward(self, prediction, target):
        """Return the PSNR between ``prediction`` and ``target`` as a 0-d tensor.

        The value is computed with tensor operations only, so it stays attached
        to the autograd graph and can be back-propagated. When the two inputs
        are identical (MSE of zero) the result is ``inf`` rather than an error.
        """
        mse = self.criterion(prediction, target)
        return 10 * torch.log10(1 / mse)

pytorch_ssim/__init__.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from torch.autograd import Variable
4+
import numpy as np
5+
from math import exp
6+
7+
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to 1.

    Args:
        window_size: Number of taps; the kernel is centred on ``window_size // 2``.
        sigma: Standard deviation of the Gaussian.
    """
    center = window_size // 2
    denominator = float(2 * sigma ** 2)
    weights = torch.tensor([exp(-(x - center) ** 2 / denominator) for x in range(window_size)])
    return weights / weights.sum()
10+
11+
def create_window(window_size, channel):
    """Build a 2-D Gaussian window of shape ``(channel, 1, window_size, window_size)``.

    The window is the outer product of a 1-D Gaussian (sigma 1.5) with itself,
    repeated once per channel.
    """
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    # expand() only creates a view; contiguous() materialises one copy per channel.
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
16+
17+
def _ssim(img1, img2, window, window_size, channel, size_average = True):
18+
mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19+
mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20+
21+
mu1_sq = mu1.pow(2)
22+
mu2_sq = mu2.pow(2)
23+
mu1_mu2 = mu1*mu2
24+
25+
sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26+
sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27+
sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28+
29+
C1 = 0.01**2
30+
C2 = 0.03**2
31+
32+
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33+
34+
if size_average:
35+
return ssim_map.mean()
36+
else:
37+
return ssim_map.mean(1).mean(1).mean(1)
38+
39+
class SSIM(torch.nn.Module):
    """SSIM as a module, caching the Gaussian window between calls.

    Args:
        window_size: Side length of the Gaussian window.
        size_average: Forwarded to ``_ssim``.
    """

    def __init__(self, window_size = 11, size_average = True):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM between ``img1`` and ``img2``.

        The cached window is reused when the channel count, device and dtype
        of ``img1`` all match it; otherwise a new window is built for ``img1``
        and cached.
        """
        (_, channel, _, _) = img1.size()

        cached = self.window
        # Comparing device and dtype directly also tells apart different CUDA
        # device indices, which a type-string comparison does not.
        reusable = (
            channel == self.channel
            and cached.device == img1.device
            and cached.dtype == img1.dtype
        )
        if reusable:
            window = cached
        else:
            window = create_window(self.window_size, channel)
            window = window.to(device=img1.device, dtype=img1.dtype)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
64+
65+
def ssim(img1, img2, window_size = 11, size_average = True):
    """Functional form of SSIM: build a window for ``img1`` and delegate to ``_ssim``.

    Args:
        img1, img2: Image batches; ``img1`` must be 4-dimensional, its second
            dimension is taken as the channel count.
        window_size: Side length of the Gaussian window.
        size_average: Forwarded to ``_ssim``.
    """
    (_, channel, _, _) = img1.size()
    # Put the window on the same device and dtype as the input.
    window = create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)
    return _ssim(img1, img2, window, window_size, channel, size_average)

requirements.txt

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
torchvision>=0.6.0
2+
torch>=1.5.0
3+
Pillow>=7.1.2
4+
setuptools>=46.4.0
5+
numpy>=1.18.4

run.sh

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
# Helper functions wrapping Test.py and Train.py. Source this file, then call
# a function.
#
# The train_* functions take: $1 = bottleneck size, $2 = loss function,
# $3 = model file name.

# test_img MODEL INPUT_IMAGE OUTPUT_IMAGE
test_img() {
    python Test.py --model_name "$1" "$2" "$3"
}

# MNIST images are 1x28x28, i.e. 784 values when flattened for the MLP.
train_mlp_mnist() {
    python Train.py --epoch 20 --dataset mnist --network mlp --loss_func "$2" --model_name "$3" 784 "$1"
}

# Train.py accepts "mlp" or "conv" for --network; MNIST has 1 channel.
train_cnn_mnist() {
    python Train.py --epoch 20 --dataset mnist --network conv --loss_func "$2" --model_name "$3" 1 "$1"
}

# CIFAR-10 images are 3x32x32, i.e. 3072 values when flattened for the MLP.
train_mlp_cifar() {
    python Train.py --epoch 20 --dataset cifar10 --network mlp --loss_func "$2" --model_name "$3" 3072 "$1"
}

# CIFAR-10 images have 3 channels.
train_cnn_cifar() {
    python Train.py --epoch 20 --dataset cifar10 --network conv --loss_func "$2" --model_name "$3" 3 "$1"
}

train.log

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
x y
2+
x y
3+
x y
4+
x y
5+
x y
6+
x y
7+
x y
8+
x y
9+
x y
10+
x y
11+
x y
12+
x y
13+
x y
14+
x y
15+
x y
16+
x y
17+
x y
18+
x y
19+
x y
20+
x y
21+
x y
22+
0 0.14592687785625458

0 commit comments

Comments
 (0)