Skip to content

MNIST CNN #11437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions computer_vision/image_classification_example_mnist_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""
This Script contains the default EMNIST code for comparison.

Execution Details are available here:
https://www.kaggle.com/code/dipuk0506/mnist-cnn

The code is improved from:
nextjournal.com/gkoehler/pytorch-mnist

@author: Dipu Kabir
"""

import torch
import torchvision
import torch.nn as nn

Check failure on line 15 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLR0402)

computer_vision/image_classification_example_mnist_cnn.py:15:8: PLR0402 Use `from torch import nn` in lieu of alias

Check failure on line 15 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLR0402)

computer_vision/image_classification_example_mnist_cnn.py:15:8: PLR0402 Use `from torch import nn` in lieu of alias

Check failure on line 15 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLR0402)

computer_vision/image_classification_example_mnist_cnn.py:15:8: PLR0402 Use `from torch import nn` in lieu of alias
import torch.nn.functional as functional

Check failure on line 16 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLR0402)

computer_vision/image_classification_example_mnist_cnn.py:16:8: PLR0402 Use `from torch.nn import functional` in lieu of alias

Check failure on line 16 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLR0402)

computer_vision/image_classification_example_mnist_cnn.py:16:8: PLR0402 Use `from torch.nn import functional` in lieu of alias

Check failure on line 16 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLR0402)

computer_vision/image_classification_example_mnist_cnn.py:16:8: PLR0402 Use `from torch.nn import functional` in lieu of alias
import torch.optim as optim

Check failure on line 17 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLR0402)

computer_vision/image_classification_example_mnist_cnn.py:17:8: PLR0402 Use `from torch import optim` in lieu of alias

Check failure on line 17 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLR0402)

computer_vision/image_classification_example_mnist_cnn.py:17:8: PLR0402 Use `from torch import optim` in lieu of alias

Check failure on line 17 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLR0402)

computer_vision/image_classification_example_mnist_cnn.py:17:8: PLR0402 Use `from torch import optim` in lieu of alias

n_epochs = 8

Check failure on line 19 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

computer_vision/image_classification_example_mnist_cnn.py:13:1: I001 Import block is un-sorted or un-formatted

Check failure on line 19 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

computer_vision/image_classification_example_mnist_cnn.py:13:1: I001 Import block is un-sorted or un-formatted

Check failure on line 19 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

computer_vision/image_classification_example_mnist_cnn.py:13:1: I001 Import block is un-sorted or un-formatted
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 100


torch.backends.cudnn.enabled = False


train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(
"/files/",
train=True,
download=True,
transform=torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,)),
]
),
),
batch_size=batch_size_train,
shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(
"/files/",
train=False,
download=True,
transform=torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,)),
]
),
),
batch_size=batch_size_test,
shuffle=True,
)

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

print(example_data.shape)

import matplotlib.pyplot as plt

Check failure on line 67 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E402)

computer_vision/image_classification_example_mnist_cnn.py:67:1: E402 Module level import not at top of file

Check failure on line 67 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E402)

computer_vision/image_classification_example_mnist_cnn.py:67:1: E402 Module level import not at top of file

Check failure on line 67 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E402)

computer_vision/image_classification_example_mnist_cnn.py:67:1: E402 Module level import not at top of file

fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap="gray", interpolation="none")
plt.title("Ground Truth: {}".format(example_targets[i]))

Check failure on line 74 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP032)

computer_vision/image_classification_example_mnist_cnn.py:74:15: UP032 Use f-string instead of `format` call

Check failure on line 74 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP032)

computer_vision/image_classification_example_mnist_cnn.py:74:15: UP032 Use f-string instead of `format` call

Check failure on line 74 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP032)

computer_vision/image_classification_example_mnist_cnn.py:74:15: UP032 Use f-string instead of `format` call
plt.xticks([])
plt.yticks([])
fig

Check failure on line 77 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (B018)

computer_vision/image_classification_example_mnist_cnn.py:77:1: B018 Found useless expression. Either assign it to a variable or remove it.

Check failure on line 77 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (B018)

computer_vision/image_classification_example_mnist_cnn.py:77:1: B018 Found useless expression. Either assign it to a variable or remove it.

Check failure on line 77 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (B018)

computer_vision/image_classification_example_mnist_cnn.py:77:1: B018 Found useless expression. Either assign it to a variable or remove it.


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()

Check failure on line 82 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP008)

computer_vision/image_classification_example_mnist_cnn.py:82:14: UP008 Use `super()` instead of `super(__class__, self)`

Check failure on line 82 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP008)

computer_vision/image_classification_example_mnist_cnn.py:82:14: UP008 Use `super()` instead of `super(__class__, self)`

Check failure on line 82 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP008)

computer_vision/image_classification_example_mnist_cnn.py:82:14: UP008 Use `super()` instead of `super(__class__, self)`
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)

def forward(self, x):
x = functional.relu(functional.max_pool2d(self.conv1(x), 2))
x = functional.relu(functional.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = functional.relu(self.fc1(x))
x = functional.dropout(x, training=self.training)
x = self.fc2(x)
return functional.log_softmax(x)


network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)


train_losses = []
train_counter = []
test_losses = []
test_counter = [i * len(train_loader.dataset) for i in range(n_epochs + 1)]


def train(epoch):
network.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = network(data)
loss = functional.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % log_interval == 0:
print(
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
epoch,
batch_idx * len(data),
len(train_loader.dataset),
100.0 * batch_idx / len(train_loader),
loss.item(),
)

Check failure on line 125 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP032)

computer_vision/image_classification_example_mnist_cnn.py:119:17: UP032 Use f-string instead of `format` call

Check failure on line 125 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP032)

computer_vision/image_classification_example_mnist_cnn.py:119:17: UP032 Use f-string instead of `format` call

Check failure on line 125 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP032)

computer_vision/image_classification_example_mnist_cnn.py:119:17: UP032 Use f-string instead of `format` call
)
train_losses.append(loss.item())
train_counter.append(
(batch_idx * 64) + ((epoch - 1) * len(train_loader.dataset))
)


def test():
network.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = network(data)
test_loss += functional.nll_loss(output, target, size_average=False).item()
pred = output.data.max(1, keepdim=True)[1]
correct += pred.eq(target.data.view_as(pred)).sum()
test_loss /= len(test_loader.dataset)
test_losses.append(test_loss)
print(
"\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
test_loss,
correct,
len(test_loader.dataset),
100.0 * correct / len(test_loader.dataset),
)

Check failure on line 151 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP032)

computer_vision/image_classification_example_mnist_cnn.py:146:9: UP032 Use f-string instead of `format` call

Check failure on line 151 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP032)

computer_vision/image_classification_example_mnist_cnn.py:146:9: UP032 Use f-string instead of `format` call

Check failure on line 151 in computer_vision/image_classification_example_mnist_cnn.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP032)

computer_vision/image_classification_example_mnist_cnn.py:146:9: UP032 Use f-string instead of `format` call
)


test()
for epoch in range(1, n_epochs + 1):
train(epoch)
test()
# %%

fig = plt.figure()
plt.plot(train_counter, train_losses, color="blue")
plt.scatter(test_counter, test_losses, color="red")
plt.legend(["Train Loss", "Test Loss"], loc="upper right")
plt.xlabel("number of training examples seen")
plt.ylabel("negative log likelihood loss")
fig
Loading