Skip to content

Commit d07977f

Browse files
authored
Update mnist_classifier.py
modified mnist_classifier.py
1 parent 488d56c commit d07977f

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

computer_vision/mnist_classifier.py

+23-23
Original file line numberDiff line numberDiff line change
@@ -9,42 +9,42 @@
99

1010
import sys
1111
import torch
12-
import torch.nn
12+
import torch.nn as n
1313
import torchvision.datasets as dset
1414
import torchvision.transforms
1515
from torch.autograd import Variable
1616
import torch.nn.functional as f
1717
import torch.optim
1818

1919

20-
class AlexNet(nn.Module):
20+
class AlexNet(n.Module):
2121
def __init__(self, num):
2222
super().__init__()
23-
self.feature = nn.Sequential(
23+
self.feature = n.Sequential(
2424
# Define feature extractor here...
25-
nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=1),
26-
nn.ReLU(inplace=True),
27-
nn.Conv2d(32, 64, kernel_size=3, padding=1),
28-
nn.ReLU(inplace=True),
29-
nn.MaxPool2d(kernel_size=2, stride=2),
30-
nn.Conv2d(64, 96, kernel_size=3, padding=1),
31-
nn.ReLU(inplace=True),
32-
nn.Conv2d(96, 64, kernel_size=3, padding=1),
33-
nn.ReLU(inplace=True),
34-
nn.Conv2d(64, 32, kernel_size=3, padding=1),
35-
nn.ReLU(inplace=True),
36-
nn.MaxPool2d(kernel_size=2, stride=1),
25+
n.Conv2d(1, 32, kernel_size=5, stride=1, padding=1),
26+
n.ReLU(inplace=True),
27+
n.Conv2d(32, 64, kernel_size=3, padding=1),
28+
n.ReLU(inplace=True),
29+
n.MaxPool2d(kernel_size=2, stride=2),
30+
n.Conv2d(64, 96, kernel_size=3, padding=1),
31+
n.ReLU(inplace=True),
32+
n.Conv2d(96, 64, kernel_size=3, padding=1),
33+
n.ReLU(inplace=True),
34+
n.Conv2d(64, 32, kernel_size=3, padding=1),
35+
n.ReLU(inplace=True),
36+
n.MaxPool2d(kernel_size=2, stride=1),
3737
)
3838

39-
self.classifier = nn.Sequential(
39+
self.classifier = n.Sequential(
4040
# Define classifier here...
41-
nn.Dropout(),
42-
nn.Linear(32 * 12 * 12, 2048),
43-
nn.ReLU(inplace=True),
44-
nn.Dropout(),
45-
nn.Linear(2048, 1024),
46-
nn.ReLU(inplace=True),
47-
nn.Linear(1024, 10),
41+
n.Dropout(),
42+
n.Linear(32 * 12 * 12, 2048),
43+
n.ReLU(inplace=True),
44+
n.Dropout(),
45+
n.Linear(2048, 1024),
46+
n.ReLU(inplace=True),
47+
n.Linear(1024, 10),
4848
)
4949

5050
def forward(self, x):

0 commit comments

Comments
 (0)