Skip to content

Commit aa8733d

Browse files
author
Azure Pipelines
committed
Merge remote-tracking branch 'origin/main' into publication
2 parents 0f20211 + d21c66d commit aa8733d

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

course_UvA-DL/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
numpy <2.0 # needed for older Torch
1+
numpy <3.0 # needed for older Torch
22
torch >=1.8.1,<2.5
33
pytorch-lightning >=2.0,<2.4
44
torchmetrics >=1.0,<1.5

lightning_examples/basic-gan/gan.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def block(in_feat, out_feat, normalize=True):
8888
layers = [nn.Linear(in_feat, out_feat)]
8989
if normalize:
9090
layers.append(nn.BatchNorm1d(out_feat, 0.8))
91-
layers.append(nn.LeakyReLU(0.2, inplace=True))
91+
layers.append(nn.LeakyReLU(0.01, inplace=True))
9292
return layers
9393

9494
self.model = nn.Sequential(
@@ -193,15 +193,15 @@ def training_step(self, batch):
193193
# log sampled images
194194
sample_imgs = self.generated_imgs[:6]
195195
grid = torchvision.utils.make_grid(sample_imgs)
196-
self.logger.experiment.add_image("generated_images", grid, 0)
196+
self.logger.experiment.add_image("train/generated_images", grid, self.current_epoch)
197197

198198
# ground truth result (ie: all fake)
199199
# put on GPU because we created this tensor inside training_loop
200200
valid = torch.ones(imgs.size(0), 1)
201201
valid = valid.type_as(imgs)
202202

203203
# adversarial loss is binary cross-entropy
204-
g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
204+
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
205205
self.log("g_loss", g_loss, prog_bar=True)
206206
self.manual_backward(g_loss)
207207
optimizer_g.step()
@@ -222,7 +222,7 @@ def training_step(self, batch):
222222
fake = torch.zeros(imgs.size(0), 1)
223223
fake = fake.type_as(imgs)
224224

225-
fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)
225+
fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake)
226226

227227
# discriminator loss is the average of these
228228
d_loss = (real_loss + fake_loss) / 2
@@ -232,6 +232,9 @@ def training_step(self, batch):
232232
optimizer_d.zero_grad()
233233
self.untoggle_optimizer(optimizer_d)
234234

235+
def validation_step(self, batch, batch_idx):
236+
pass
237+
235238
def configure_optimizers(self):
236239
lr = self.hparams.lr
237240
b1 = self.hparams.b1
@@ -247,7 +250,7 @@ def on_validation_epoch_end(self):
247250
# log sampled images
248251
sample_imgs = self(z)
249252
grid = torchvision.utils.make_grid(sample_imgs)
250-
self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
253+
self.logger.experiment.add_image("validation/generated_images", grid, self.current_epoch)
251254

252255

253256
# %%
@@ -263,4 +266,4 @@ def on_validation_epoch_end(self):
263266
# %%
264267
# Start tensorboard.
265268
# %load_ext tensorboard
266-
# %tensorboard --logdir lightning_logs/
269+
# %tensorboard --logdir lightning_logs/ --samples_per_plugin=images=60

lightning_examples/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
numpy <2.0 # needed for older Torch
1+
numpy <3.0 # needed for older Torch
22
torch>=1.8.1, <2.5
33
pytorch-lightning >=2.0,<2.4
44
torchmetrics>=1.0, <1.5

0 commit comments

Comments
 (0)