|
9 | 9 |
|
10 | 10 | import sys
|
11 | 11 | import torch
|
12 |
| -import torch.nn |
| 12 | +import torch.nn as n |
13 | 13 | import torchvision.datasets as dset
|
14 | 14 | import torchvision.transforms
|
15 | 15 | from torch.autograd import Variable
|
16 | 16 | import torch.nn.functional as f
|
17 | 17 | import torch.optim
|
18 | 18 |
|
19 | 19 |
|
20 |
| -class AlexNet(nn.Module): |
| 20 | +class AlexNet(n.Module): |
21 | 21 | def __init__(self, num):
|
22 | 22 | super().__init__()
|
23 |
| - self.feature = nn.Sequential( |
| 23 | + self.feature = n.Sequential( |
24 | 24 | # Define feature extractor here...
|
25 |
| - nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=1), |
26 |
| - nn.ReLU(inplace=True), |
27 |
| - nn.Conv2d(32, 64, kernel_size=3, padding=1), |
28 |
| - nn.ReLU(inplace=True), |
29 |
| - nn.MaxPool2d(kernel_size=2, stride=2), |
30 |
| - nn.Conv2d(64, 96, kernel_size=3, padding=1), |
31 |
| - nn.ReLU(inplace=True), |
32 |
| - nn.Conv2d(96, 64, kernel_size=3, padding=1), |
33 |
| - nn.ReLU(inplace=True), |
34 |
| - nn.Conv2d(64, 32, kernel_size=3, padding=1), |
35 |
| - nn.ReLU(inplace=True), |
36 |
| - nn.MaxPool2d(kernel_size=2, stride=1), |
| 25 | + n.Conv2d(1, 32, kernel_size=5, stride=1, padding=1), |
| 26 | + n.ReLU(inplace=True), |
| 27 | + n.Conv2d(32, 64, kernel_size=3, padding=1), |
| 28 | + n.ReLU(inplace=True), |
| 29 | + n.MaxPool2d(kernel_size=2, stride=2), |
| 30 | + n.Conv2d(64, 96, kernel_size=3, padding=1), |
| 31 | + n.ReLU(inplace=True), |
| 32 | + n.Conv2d(96, 64, kernel_size=3, padding=1), |
| 33 | + n.ReLU(inplace=True), |
| 34 | + n.Conv2d(64, 32, kernel_size=3, padding=1), |
| 35 | + n.ReLU(inplace=True), |
| 36 | + n.MaxPool2d(kernel_size=2, stride=1), |
37 | 37 | )
|
38 | 38 |
|
39 |
| - self.classifier = nn.Sequential( |
| 39 | + self.classifier = n.Sequential( |
40 | 40 | # Define classifier here...
|
41 |
| - nn.Dropout(), |
42 |
| - nn.Linear(32 * 12 * 12, 2048), |
43 |
| - nn.ReLU(inplace=True), |
44 |
| - nn.Dropout(), |
45 |
| - nn.Linear(2048, 1024), |
46 |
| - nn.ReLU(inplace=True), |
47 |
| - nn.Linear(1024, 10), |
| 41 | + n.Dropout(), |
| 42 | + n.Linear(32 * 12 * 12, 2048), |
| 43 | + n.ReLU(inplace=True), |
| 44 | + n.Dropout(), |
| 45 | + n.Linear(2048, 1024), |
| 46 | + n.ReLU(inplace=True), |
| 47 | + n.Linear(1024, 10), |
48 | 48 | )
|
49 | 49 |
|
50 | 50 | def forward(self, x):
|
|
0 commit comments