Skip to content

Commit 13987cc

Browse files
authored
Create Generative_Adversarial_Network_MNIST.py
1 parent e9e7c96 commit 13987cc

File tree

1 file changed

+249
-0
lines changed

1 file changed

+249
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets
import torchvision.transforms as transforms

# Enable inline matplotlib rendering when running under IPython/Jupyter.
# The raw "%matplotlib inline" magic is not valid Python syntax, so the
# equivalent API call is used; outside IPython `get_ipython` does not exist
# and the step is skipped.
try:
    get_ipython().run_line_magic("matplotlib", "inline")
except NameError:
    pass

# number of subprocesses to use for data loading (0 = load in the main process)
num_workers = 0
# how many samples per batch to load
batch_size = 64

# convert each image to a torch.FloatTensor
transform = transforms.ToTensor()

# get the training dataset (downloaded into ./data on first use)
train_data = datasets.MNIST(root='data', train=True,
                            download=True, transform=transform)

# prepare data loader
# NOTE(review): no `shuffle` argument is passed, so batches arrive in dataset
# order every epoch -- confirm this is intended.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           num_workers=num_workers)
24+
25+
import torch.nn as nn
26+
import torch.nn.functional as F
27+
28+
# Creating Generator and Discriminator for GAN
29+
30+
class discriminator(nn.Module):
    """Fully connected discriminator producing raw logits for each input.

    The network has three hidden layers of widths ``hidden_dim*4``,
    ``hidden_dim*2`` and ``hidden_dim``, each followed by leaky ReLU
    (slope 0.2) and dropout (p=0.2), then a final linear layer.

    Args:
        input_size: Number of features per flattened input sample.
        output_size: Number of output logits per sample.
        hidden_dim: Width of the *last* hidden layer.
    """

    def __init__(self, input_size, output_size, hidden_dim):
        super().__init__()

        # hidden layers shrink from 4x to 1x hidden_dim
        self.fc1 = nn.Linear(input_size, hidden_dim * 4)
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim)
        # final fully connected layer (no activation: outputs are logits)
        self.fc4 = nn.Linear(hidden_dim, output_size)
        # dropout layer shared by all hidden layers
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Flatten ``x`` to ``(batch, input_size)`` and return the logits.

        Returns:
            Tensor of shape ``(batch, output_size)`` holding unnormalised
            logits (no sigmoid is applied here).
        """
        # Flatten using the size the first layer was built with rather than
        # a hard-coded 28*28; reshape also accepts non-contiguous inputs.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        x_out = self.fc4(x)

        return x_out
56+
57+
class generator(nn.Module):
    """Fully connected generator mapping a latent vector to a flat image.

    The network has three hidden layers of widths ``hidden_dim``,
    ``hidden_dim*2`` and ``hidden_dim*4``, each followed by leaky ReLU
    (slope 0.2) and dropout (p=0.2), then a final linear layer with tanh.

    Args:
        input_size: Length of the latent vector.
        output_size: Number of values produced per sample.
        hidden_dim: Width of the *first* hidden layer.
    """

    def __init__(self, input_size, output_size, hidden_dim):
        super().__init__()

        # hidden layers grow from 1x to 4x hidden_dim
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim * 4)
        # final layer
        self.fc4 = nn.Linear(hidden_dim * 4, output_size)
        # dropout layer shared by all hidden layers
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Return generated samples with values in ``[-1, 1]``.

        Args:
            x: Latent batch of shape ``(batch, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # torch.tanh replaces the deprecated F.tanh; the output range is the same
        x_out = torch.tanh(self.fc4(x))
        return x_out
83+
84+
# Calculate losses
85+
def real_loss(D_out, smooth=False):
    """Return the BCE-with-logits loss of ``D_out`` against "real" labels.

    Args:
        D_out: Discriminator logits, one per sample (e.g. shape
            ``(batch, 1)``).
        smooth: When True, use a target of 0.9 instead of 1.0
            (label smoothing).

    Returns:
        Scalar loss tensor on the same device as ``D_out``.
    """
    # One logit per sample. Unlike squeeze(), this stays 1-D for a batch of
    # one, so the logits always match the label shape.
    logits = D_out.reshape(D_out.size(0))
    # Build the labels next to the logits (same device and dtype) instead of
    # forcing CUDA, so the loss also works on CPU-only machines.
    labels = torch.full_like(logits, 0.9 if smooth else 1.0)
    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(logits, labels)
    return loss
97+
98+
def fake_loss(D_out):
    """Return the BCE-with-logits loss of ``D_out`` against "fake" labels.

    Args:
        D_out: Discriminator logits, one per sample (e.g. shape
            ``(batch, 1)``).

    Returns:
        Scalar loss tensor on the same device as ``D_out``.
    """
    # One logit per sample. Unlike squeeze(), this stays 1-D for a batch of
    # one, so the logits always match the label shape.
    logits = D_out.reshape(D_out.size(0))
    # All-zero targets created on the logits' own device rather than forcing
    # CUDA, so the loss also works on CPU-only machines.
    labels = torch.zeros_like(logits)
    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(logits, labels)
    return loss
105+
106+
# Discriminator hyperparams
# Size of input image to discriminator (28*28)
input_size = 784
# Size of discriminator output (a single real-vs-fake logit)
d_output_size = 1
# Size of *last* hidden layer in the discriminator
d_hidden_size = 32

# Generator hyperparams

# Size of latent vector to give to generator
z_size = 100
# Size of generator output (generated image, 28*28)
g_output_size = 784
# Size of *first* hidden layer in the generator
g_hidden_size = 32

# Instantiate discriminator and generator on the GPU when one is available,
# falling back to the CPU instead of crashing on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
D = discriminator(input_size, d_output_size, d_hidden_size).to(device)
G = generator(z_size, g_output_size, g_hidden_size).to(device)
126+
127+
import pickle as pkl

# Optimizers for both networks. The loop below uses them, but they were never
# created, so the first training step raised NameError.
# NOTE(review): Adam with lr=0.002 is a starting choice -- tune as needed.
lr = 0.002
d_optimizer = torch.optim.Adam(D.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(G.parameters(), lr=lr)

# Run on whatever device the generator lives on instead of assuming CUDA.
device = next(G.parameters()).device

# training hyperparams
num_epochs = 40

# keep track of loss and generated, "fake" samples
samples = []
losses = []

# print loss stats every `print_every` batches
print_every = 400

# Get some fixed data for sampling. These latent vectors are held
# constant throughout training, and allow us to inspect the model's performance
sample_size = 16
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float().to(device)

# train the network
D.train()
G.train()
for epoch in range(num_epochs):

    for batch_i, (real_images, _) in enumerate(train_loader):

        # the last batch of an epoch may be smaller than the loader's batch size
        batch_size = real_images.size(0)

        ## Important rescaling step ##
        # rescale input images from [0, 1] to [-1, 1] to match the generator's tanh output
        real_images = (real_images * 2 - 1).to(device)

        # ============================================
        #            TRAIN THE DISCRIMINATOR
        # ============================================
        d_optimizer.zero_grad()

        # 1. Train with real images
        # Compute the discriminator losses on real images, using smoothed labels
        D_real = D(real_images)
        d_real_loss = real_loss(D_real, smooth=True)

        # 2. Train with fake images
        # Generate fake images; detach so this step does not backpropagate
        # into the generator, which is not being updated here
        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float().to(device)
        fake_images = G(z).detach()

        # Compute the discriminator losses on fake images
        D_fake = D(fake_images)
        d_fake_loss = fake_loss(D_fake)

        # add up real and fake losses and perform backprop
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # =========================================
        #            TRAIN THE GENERATOR
        # =========================================
        g_optimizer.zero_grad()

        # 1. Train with fake images and flipped labels
        # Generate fake images
        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float().to(device)
        fake_images = G(z)

        # Compute the discriminator losses on fake images
        # using flipped labels: the generator wants them classified as real
        D_fake = D(fake_images)
        g_loss = real_loss(D_fake)

        # perform backprop
        g_loss.backward()
        g_optimizer.step()

        # Print some loss stats
        if batch_i % print_every == 0:
            # print discriminator and generator loss
            print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
                epoch+1, num_epochs, d_loss.item(), g_loss.item()))

    ## AFTER EACH EPOCH ##
    # append the last batch's discriminator loss and generator loss
    losses.append((d_loss.item(), g_loss.item()))

    # generate and save sample, fake images
    G.eval()  # eval mode (dropout off) for generating samples
    with torch.no_grad():  # sampling needs no autograd graph
        samples_z = G(fixed_z)
    # keep samples on the CPU so the pickle can be loaded without a GPU
    samples.append(samples_z.cpu())
    G.train()  # back to train mode


# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)
222+
223+
# Plot the per-epoch discriminator and generator losses on a single set of axes.
fig, ax = plt.subplots()
losses = np.array(losses)  # one row per epoch: (d_loss, g_loss)
ax.plot(losses[:, 0], label='Discriminator')
ax.plot(losses[:, 1], label='Generator')
ax.set_title("Training Losses")
ax.legend()
plt.show()
231+
232+
233+
#Viewing the results of the GAN
234+
def view_samples(epoch, samples):
    """Draw the generated digits saved for one epoch on a 4x4 grid.

    Args:
        epoch: Index into ``samples`` (negative indices count from the end).
        samples: Sequence with one batch of generated images per epoch; each
            image is a tensor holding 28*28 values.

    At most 16 images are shown, one per grid cell. The figure is created
    but not displayed; the caller is expected to call ``plt.show()``.
    """
    fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(7, 7),
                             sharex=True, sharey=True)
    fig.suptitle("Generated Digits")
    epoch_images = samples[epoch]
    for axis, image in zip(axes.flatten(), epoch_images):
        # move to host memory and restore the 2-D image layout
        pixels = image.detach().cpu().numpy().reshape((28, 28))
        axis.xaxis.set_visible(False)
        axis.yaxis.set_visible(False)
        axis.imshow(pixels, cmap='Greys_r')
243+
244+
# Reload the pickled per-epoch samples and display the last epoch's digits.
# NOTE: pickle executes code on load -- only open files this script wrote.
with open('train_samples.pkl', 'rb') as sample_file:
    samples = pkl.load(sample_file)

view_samples(epoch=-1, samples=samples)
plt.show()
249+

0 commit comments

Comments
 (0)