Skip to content

Commit d5f1525

Browse files
a-r-r-o-wrwightman
andcommitted
include suggestions from review
Co-Authored-By: Ross Wightman <[email protected]>
1 parent 5f14bdd commit d5f1525

File tree

3 files changed

+12
-16
lines changed

3 files changed

+12
-16
lines changed

timm/layers/typing.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import functools
2-
import types
3-
from typing import Tuple, Union
1+
from typing import Callable, Tuple, Type, Union
42

5-
import torch.nn
3+
import torch
64

75

8-
LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
6+
LayerType = Union[str, Callable, Type[torch.nn.Module]]
97
PadType = Union[str, int, Tuple[int, int]]

timm/models/mobilenetv3.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import torch
1313
import torch.nn as nn
1414
import torch.nn.functional as F
15-
from torch import Tensor
1615
from torch.utils.checkpoint import checkpoint
1716

1817
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@@ -151,7 +150,7 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
151150
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
152151
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
153152

154-
def forward_features(self, x: Tensor) -> Tensor:
153+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
155154
x = self.conv_stem(x)
156155
x = self.bn1(x)
157156
if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -160,7 +159,7 @@ def forward_features(self, x: Tensor) -> Tensor:
160159
x = self.blocks(x)
161160
return x
162161

163-
def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
162+
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
164163
x = self.global_pool(x)
165164
x = self.conv_head(x)
166165
x = self.act2(x)
@@ -171,7 +170,7 @@ def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
171170
x = F.dropout(x, p=self.drop_rate, training=self.training)
172171
return self.classifier(x)
173172

174-
def forward(self, x: Tensor) -> Tensor:
173+
def forward(self, x: torch.Tensor) -> torch.Tensor:
175174
x = self.forward_features(x)
176175
x = self.forward_head(x)
177176
return x
@@ -262,7 +261,7 @@ def __init__(
262261
def set_grad_checkpointing(self, enable: bool = True):
263262
self.grad_checkpointing = enable
264263

265-
def forward(self, x: Tensor) -> List[Tensor]:
264+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
266265
x = self.conv_stem(x)
267266
x = self.bn1(x)
268267
x = self.act1(x)

timm/models/resnet.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import torch
1515
import torch.nn as nn
1616
import torch.nn.functional as F
17-
from torch import Tensor
1817

1918
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2019
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
@@ -112,7 +111,7 @@ def zero_init_last(self):
112111
if getattr(self.bn2, 'weight', None) is not None:
113112
nn.init.zeros_(self.bn2.weight)
114113

115-
def forward(self, x: Tensor) -> Tensor:
114+
def forward(self, x: torch.Tensor) -> torch.Tensor:
116115
shortcut = x
117116

118117
x = self.conv1(x)
@@ -212,7 +211,7 @@ def zero_init_last(self):
212211
if getattr(self.bn3, 'weight', None) is not None:
213212
nn.init.zeros_(self.bn3.weight)
214213

215-
def forward(self, x: Tensor) -> Tensor:
214+
def forward(self, x: torch.Tensor) -> torch.Tensor:
216215
shortcut = x
217216

218217
x = self.conv1(x)
@@ -554,7 +553,7 @@ def reset_classifier(self, num_classes, global_pool='avg'):
554553
self.num_classes = num_classes
555554
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
556555

557-
def forward_features(self, x: Tensor) -> Tensor:
556+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
558557
x = self.conv1(x)
559558
x = self.bn1(x)
560559
x = self.act1(x)
@@ -569,13 +568,13 @@ def forward_features(self, x: Tensor) -> Tensor:
569568
x = self.layer4(x)
570569
return x
571570

572-
def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
571+
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
573572
x = self.global_pool(x)
574573
if self.drop_rate:
575574
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
576575
return x if pre_logits else self.fc(x)
577576

578-
def forward(self, x: Tensor) -> Tensor:
577+
def forward(self, x: torch.Tensor) -> torch.Tensor:
579578
x = self.forward_features(x)
580579
x = self.forward_head(x)
581580
return x

0 commit comments

Comments
 (0)