diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index afa484d81..b6d1622a8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -49,7 +49,7 @@ This is YES: } ``` -When there is descrepancy in a given source file (there are many origins for various bits of code and not all have been updated to what I consider current goal), please follow the style in a given file. +When there is discrepancy in a given source file (there are many origins for various bits of code and not all have been updated to what I consider current goal), please follow the style in a given file. In general, if you add new code, formatting it with black using the following options should result in a style that is compatible with the rest of the code base: diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 2cdcfd980..e4d1499f4 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -52,4 +52,5 @@ from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .trace_utils import _assert, _float_to_int +from .typing import LayerType, PadType from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/timm/layers/typing.py b/timm/layers/typing.py new file mode 100644 index 000000000..593fa5cc8 --- /dev/null +++ b/timm/layers/typing.py @@ -0,0 +1,7 @@ +from typing import Callable, Tuple, Type, Union + +import torch + + +LayerType = Union[str, Callable, Type[torch.nn.Module]] +PadType = Union[str, int, Tuple[int, int]] diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index b5dbeaae3..1e3161d6b 100644 --- a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -11,6 +11,7 @@ import re from copy import deepcopy from functools import partial +from typing import Any, Dict, List import torch.nn as nn @@ -34,6 +35,8 @@ BN_EPS_TF_DEFAULT = 1e-3 _BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) +BlockArgs = List[List[Dict[str, Any]]] + def get_bn_args_tf(): return _BN_ARGS_TF.copy() diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 8de94f7e2..21f0d3f1f 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -7,7 +7,7 @@ Hacked together by / Copyright 2019, Ross Wightman """ from functools import partial -from typing import List +from typing import Callable, List, Optional, Tuple import torch import torch.nn as nn @@ -15,10 +15,10 @@ from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer +from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer from ._builder import build_model_with_cfg, pretrained_cfg_for_features from ._efficientnet_blocks import SqueezeExcite -from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ +from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from ._features import FeatureInfo, FeatureHooks from ._manipulate import checkpoint_seq @@ -44,23 +44,42 @@ class MobileNetV3(nn.Module): def __init__( self, - block_args, - num_classes=1000, - in_chans=3, - stem_size=16, - fix_stem=False, - 
num_features=1280, - head_bias=True, - pad_type='', - act_layer=None, - norm_layer=None, - se_layer=None, - se_from_exp=True, - round_chs_fn=round_channels, - drop_rate=0., - drop_path_rate=0., - global_pool='avg', + block_args: BlockArgs, + num_classes: int = 1000, + in_chans: int = 3, + stem_size: int = 16, + fix_stem: bool = False, + num_features: int = 1280, + head_bias: bool = True, + pad_type: PadType = '', + act_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + se_layer: Optional[LayerType] = None, + se_from_exp: bool = True, + round_chs_fn: Callable = round_channels, + drop_rate: float = 0., + drop_path_rate: float = 0., + global_pool: str = 'avg', ): + """ + Args: + block_args: Arguments for blocks of the network. + num_classes: Number of classes for classification head. + in_chans: Number of input image channels. + stem_size: Number of output channels of the initial stem convolution. + fix_stem: If True, don't scale stem by round_chs_fn. + num_features: Number of output channels of the conv head layer. + head_bias: If True, add a learnable bias to the conv head layer. + pad_type: Type of padding to use for convolution layers. + act_layer: Type of activation layer. + norm_layer: Type of normalization layer. + se_layer: Type of Squeeze-and-Excite layer. + se_from_exp: If True, calculate SE channel reduction from expanded mid channels. + round_chs_fn: Callable to round number of filters based on depth multiplier. + drop_rate: Dropout rate. + drop_path_rate: Stochastic depth rate. + global_pool: Type of pooling to use for global pooling features of the FC head. + """ super(MobileNetV3, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d @@ -110,28 +129,28 @@ def as_sequential(self): return nn.Sequential(*layers) @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False): return dict( stem=r'^conv_stem|bn1', blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)' ) @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True): self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self): return self.classifier - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.conv_stem(x) x = self.bn1(x) if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -140,7 +159,7 @@ def forward_features(self, x): x = self.blocks(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: x = self.global_pool(x) x = self.conv_head(x) x = self.act2(x) @@ -151,7 +170,7 @@ def forward_head(self, x, pre_logits: bool = False): x = F.dropout(x, p=self.drop_rate, training=self.training) return self.classifier(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) x = self.forward_head(x) return x @@ -166,22 +185,40 @@ class 
MobileNetV3Features(nn.Module): def __init__( self, - block_args, - out_indices=(0, 1, 2, 3, 4), - feature_location='bottleneck', - in_chans=3, - stem_size=16, - fix_stem=False, - output_stride=32, - pad_type='', - round_chs_fn=round_channels, - se_from_exp=True, - act_layer=None, - norm_layer=None, - se_layer=None, - drop_rate=0., - drop_path_rate=0., + block_args: BlockArgs, + out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4), + feature_location: str = 'bottleneck', + in_chans: int = 3, + stem_size: int = 16, + fix_stem: bool = False, + output_stride: int = 32, + pad_type: PadType = '', + round_chs_fn: Callable = round_channels, + se_from_exp: bool = True, + act_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + se_layer: Optional[LayerType] = None, + drop_rate: float = 0., + drop_path_rate: float = 0., ): + """ + Args: + block_args: Arguments for blocks of the network. + out_indices: Output from stages at indices. + feature_location: Location of feature before/after each block, must be in ['bottleneck', 'expansion'] + in_chans: Number of input image channels. + stem_size: Number of output channels of the initial stem convolution. + fix_stem: If True, don't scale stem by round_chs_fn. + output_stride: Output stride of the network. + pad_type: Type of padding to use for convolution layers. + round_chs_fn: Callable to round number of filters based on depth multiplier. + se_from_exp: If True, calculate SE channel reduction from expanded mid channels. + act_layer: Type of activation layer. + norm_layer: Type of normalization layer. + se_layer: Type of Squeeze-and-Excite layer. + drop_rate: Dropout rate. + drop_path_rate: Stochastic depth rate. + """ super(MobileNetV3Features, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d @@ -221,10 +258,10 @@ def __init__( self.feature_hooks = FeatureHooks(hooks, self.named_modules()) @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True): self.grad_checkpointing = enable - def forward(self, x) -> List[torch.Tensor]: + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) @@ -246,7 +283,7 @@ def forward(self, x) -> List[torch.Tensor]: return list(out.values()) -def _create_mnv3(variant, pretrained=False, **kwargs): +def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV3: features_mode = '' model_cls = MobileNetV3 kwargs_filter = None @@ -272,7 +309,7 @@ def _create_mnv3(variant, pretrained=False, **kwargs): return model -def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): +def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3: """Creates a MobileNet-V3 model. Ref impl: ? @@ -310,7 +347,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw return model -def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): +def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3: """Creates a MobileNet-V3 model. Ref impl: ? 
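The aliases exported above from `timm.layers.typing` (`LayerType`, `PadType`) and the `BlockArgs` alias added to `_efficientnet_builder` are what the annotated MobileNetV3 constructors consume. Below is a minimal sketch (not part of the diff) of how they read at a call site; `make_stem` is a hypothetical helper named only for illustration, and the arch-definition strings are just examples of the format `decode_arch_def` accepts.

```python
from typing import Optional

import torch.nn as nn

from timm.layers import LayerType, PadType, create_conv2d, get_norm_act_layer
from timm.models._efficientnet_builder import BlockArgs, decode_arch_def  # private module


def make_stem(
        in_chs: int,
        stem_size: int,
        pad_type: PadType = '',                  # '', 'same', an int, or a 2-tuple of ints
        act_layer: Optional[LayerType] = None,   # layer name, callable, or nn.Module class
        norm_layer: Optional[LayerType] = None,
) -> nn.Module:
    # Mirrors how MobileNetV3.__init__ resolves its (now annotated) layer arguments.
    act_layer = act_layer or nn.ReLU
    norm_layer = norm_layer or nn.BatchNorm2d
    norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
    return nn.Sequential(
        create_conv2d(in_chs, stem_size, 3, stride=2, padding=pad_type),
        norm_act_layer(stem_size, inplace=True),
    )


# BlockArgs is the nested per-stage / per-block kwargs structure produced by decode_arch_def,
# i.e. the type now declared for the `block_args` constructor argument.
arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k3_s2_e4_c24_nre']]
block_args: BlockArgs = decode_arch_def(arch_def)
```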
@@ -407,7 +444,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg return model -def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): +def _gen_fbnetv3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs): """ FBNetV3 Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining` - https://arxiv.org/abs/2006.02049 @@ -468,7 +505,7 @@ def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): return model -def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): +def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs): """ LCNet Essentially a MobileNet-V3 crossed with a MobileNet-V1 @@ -506,7 +543,7 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): return model -def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): +def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs): """ LCNet Essentially a MobileNet-V3 crossed with a MobileNet-V1 @@ -544,7 +581,7 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): return model -def _cfg(url='', **kwargs): +def _cfg(url: str = '', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', @@ -649,42 +686,42 @@ def _cfg(url='', **kwargs): @register_model -def mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3: +def mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3: +def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_small_050(pretrained=False, **kwargs) -> MobileNetV3: +def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ model = _gen_mobilenet_v3('mobilenetv3_small_050', 0.50, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3: +def mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3: +def mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_rw(pretrained=False, **kwargs) -> MobileNetV3: +def mobilenetv3_rw(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ if pretrained: # pretrained model trained with non-default BN epsilon @@ -694,7 +731,7 @@ def mobilenetv3_rw(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -703,7 +740,7 @@ def 
tf_mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -712,7 +749,7 @@ def tf_mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_large_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -721,7 +758,7 @@ def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -730,7 +767,7 @@ def tf_mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -739,7 +776,7 @@ def tf_mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_small_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -748,56 +785,56 @@ def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def fbnetv3_b(pretrained=False, **kwargs) -> MobileNetV3: +def fbnetv3_b(pretrained: bool = False, **kwargs) -> MobileNetV3: """ FBNetV3-B """ model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs) return model @register_model -def fbnetv3_d(pretrained=False, **kwargs) -> MobileNetV3: +def fbnetv3_d(pretrained: bool = False, **kwargs) -> MobileNetV3: """ FBNetV3-D """ model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs) return model @register_model -def fbnetv3_g(pretrained=False, **kwargs) -> MobileNetV3: +def fbnetv3_g(pretrained: bool = False, **kwargs) -> MobileNetV3: """ FBNetV3-G """ model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs) return model @register_model -def lcnet_035(pretrained=False, **kwargs) -> MobileNetV3: +def lcnet_035(pretrained: bool = False, **kwargs) -> MobileNetV3: """ PP-LCNet 0.35""" model = _gen_lcnet('lcnet_035', 0.35, pretrained=pretrained, **kwargs) return model @register_model -def lcnet_050(pretrained=False, **kwargs) -> MobileNetV3: +def lcnet_050(pretrained: bool = False, **kwargs) -> MobileNetV3: """ PP-LCNet 0.5""" model = _gen_lcnet('lcnet_050', 0.5, pretrained=pretrained, **kwargs) return model @register_model -def lcnet_075(pretrained=False, **kwargs) -> MobileNetV3: +def lcnet_075(pretrained: bool = False, **kwargs) -> MobileNetV3: """ PP-LCNet 1.0""" model = _gen_lcnet('lcnet_075', 0.75, pretrained=pretrained, **kwargs) return model @register_model -def lcnet_100(pretrained=False, **kwargs) -> MobileNetV3: +def lcnet_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ PP-LCNet 1.0""" model = _gen_lcnet('lcnet_100', 1.0, pretrained=pretrained, 
**kwargs) return model @register_model -def lcnet_150(pretrained=False, **kwargs) -> MobileNetV3: +def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3: """ PP-LCNet 1.5""" model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs) return model diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 3c1197cb9..2549eb153 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -9,14 +9,15 @@ """ import math from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \ - get_act_layer, get_norm_layer, create_classifier +from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \ + get_attn, get_act_layer, get_norm_layer, create_classifier from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs, register_model_deprecations @@ -24,12 +25,12 @@ __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this -def get_padding(kernel_size, stride, dilation=1): +def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int: padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 return padding -def create_aa(aa_layer, channels, stride=2, enable=True): +def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module: if not aa_layer or not enable: return nn.Identity() if issubclass(aa_layer, nn.AvgPool2d): @@ -43,22 +44,40 @@ class BasicBlock(nn.Module): def __init__( self, - inplanes, - planes, - stride=1, - downsample=None, - cardinality=1, - base_width=64, - reduce_first=1, - dilation=1, - first_dilation=None, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - attn_layer=None, - aa_layer=None, - drop_block=None, - drop_path=None, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + cardinality: int = 1, + base_width: int = 64, + reduce_first: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_layer: Optional[Type[nn.Module]] = None, + aa_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[Type[nn.Module]] = None, + drop_path: Optional[nn.Module] = None, ): + """ + Args: + inplanes: Input channel dimensionality. + planes: Used to determine output channel dimensionalities. + stride: Stride used in convolution layers. + downsample: Optional downsample layer for residual path. + cardinality: Number of convolution groups. + base_width: Base width used to determine output channel dimensionality. + reduce_first: Reduction factor for first convolution output width of residual blocks. + dilation: Dilation rate for convolution layers. + first_dilation: Dilation rate for first convolution layer. + act_layer: Activation layer. + norm_layer: Normalization layer. + attn_layer: Attention layer. + aa_layer: Anti-aliasing layer. + drop_block: Class for DropBlock layer. + drop_path: Optional DropPath layer. 
+ """ super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' @@ -92,7 +111,7 @@ def zero_init_last(self): if getattr(self.bn2, 'weight', None) is not None: nn.init.zeros_(self.bn2.weight) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.conv1(x) @@ -123,22 +142,40 @@ class Bottleneck(nn.Module): def __init__( self, - inplanes, - planes, - stride=1, - downsample=None, - cardinality=1, - base_width=64, - reduce_first=1, - dilation=1, - first_dilation=None, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - attn_layer=None, - aa_layer=None, - drop_block=None, - drop_path=None, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + cardinality: int = 1, + base_width: int = 64, + reduce_first: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_layer: Optional[Type[nn.Module]] = None, + aa_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[Type[nn.Module]] = None, + drop_path: Optional[nn.Module] = None, ): + """ + Args: + inplanes: Input channel dimensionality. + planes: Used to determine output channel dimensionalities. + stride: Stride used in convolution layers. + downsample: Optional downsample layer for residual path. + cardinality: Number of convolution groups. + base_width: Base width used to determine output channel dimensionality. + reduce_first: Reduction factor for first convolution output width of residual blocks. + dilation: Dilation rate for convolution layers. + first_dilation: Dilation rate for first convolution layer. + act_layer: Activation layer. + norm_layer: Normalization layer. + attn_layer: Attention layer. + aa_layer: Anti-aliasing layer. + drop_block: Class for DropBlock layer. + drop_path: Optional DropPath layer. 
+ """ super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) @@ -174,7 +211,7 @@ def zero_init_last(self): if getattr(self.bn3, 'weight', None) is not None: nn.init.zeros_(self.bn3.weight) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.conv1(x) @@ -205,14 +242,14 @@ def forward(self, x): def downsample_conv( - in_channels, - out_channels, - kernel_size, - stride=1, - dilation=1, - first_dilation=None, - norm_layer=None, -): + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + norm_layer: Optional[Type[nn.Module]] = None, +) -> nn.Module: norm_layer = norm_layer or nn.BatchNorm2d kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 @@ -226,14 +263,14 @@ def downsample_conv( def downsample_avg( - in_channels, - out_channels, - kernel_size, - stride=1, - dilation=1, - first_dilation=None, - norm_layer=None, -): + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + norm_layer: Optional[Type[nn.Module]] = None, +) -> nn.Module: norm_layer = norm_layer or nn.BatchNorm2d avg_stride = stride if dilation == 1 else 1 if stride == 1 and dilation == 1: @@ -249,7 +286,7 @@ def downsample_avg( ]) -def drop_blocks(drop_prob=0.): +def drop_blocks(drop_prob: float = 0.): return [ None, None, partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None, @@ -257,18 +294,18 @@ def drop_blocks(drop_prob=0.): def make_blocks( - block_fn, - channels, - block_repeats, - inplanes, - reduce_first=1, - output_stride=32, - down_kernel_size=1, - avg_down=False, - drop_block_rate=0., - drop_path_rate=0., + block_fn: Union[BasicBlock, Bottleneck], + channels: List[int], + block_repeats: List[int], + inplanes: int, + reduce_first: int = 1, + output_stride: int = 32, + down_kernel_size: int = 1, + avg_down: bool = False, + drop_block_rate: float = 0., + drop_path_rate: float = 0., **kwargs, -): +) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]: stages = [] feature_info = [] net_num_blocks = sum(block_repeats) @@ -356,28 +393,28 @@ class ResNet(nn.Module): def __init__( self, - block, - layers, - num_classes=1000, - in_chans=3, - output_stride=32, - global_pool='avg', - cardinality=1, - base_width=64, - stem_width=64, - stem_type='', - replace_stem_pool=False, - block_reduce_first=1, - down_kernel_size=1, - avg_down=False, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - aa_layer=None, - drop_rate=0.0, - drop_path_rate=0., - drop_block_rate=0., - zero_init_last=True, - block_args=None, + block: Union[BasicBlock, Bottleneck], + layers: List[int], + num_classes: int = 1000, + in_chans: int = 3, + output_stride: int = 32, + global_pool: str = 'avg', + cardinality: int = 1, + base_width: int = 64, + stem_width: int = 64, + stem_type: str = '', + replace_stem_pool: bool = False, + block_reduce_first: int = 1, + down_kernel_size: int = 1, + avg_down: bool = False, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[Type[nn.Module]] = None, + drop_rate: float = 0.0, + drop_path_rate: float = 0., + drop_block_rate: float = 0., + zero_init_last: bool = True, + block_args: Optional[Dict[str, Any]] = None, ): """ Args: @@ -490,7 +527,7 @@ def __init__( 
self.init_weights(zero_init_last=zero_init_last) @torch.jit.ignore - def init_weights(self, zero_init_last=True): + def init_weights(self, zero_init_last: bool = True): for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') @@ -500,23 +537,23 @@ def init_weights(self, zero_init_last=True): m.zero_init_last() @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False): matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)') return matcher @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self, name_only=False): + def get_classifier(self, name_only: bool = False): return 'fc' if name_only else self.fc def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.conv1(x) x = self.bn1(x) x = self.act1(x) @@ -531,19 +568,19 @@ def forward_features(self, x): x = self.layer4(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: x = self.global_pool(x) if self.drop_rate: x = F.dropout(x, p=float(self.drop_rate), training=self.training) return x if pre_logits else self.fc(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) x = self.forward_head(x) return x -def _create_resnet(variant, pretrained=False, **kwargs): +def _create_resnet(variant, pretrained: bool = False, **kwargs) -> ResNet: return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) @@ -1204,7 +1241,7 @@ def _gcfg(url='', **kwargs): @register_model -def resnet10t(pretrained=False, **kwargs) -> ResNet: +def resnet10t(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-10-T model. """ model_args = dict(block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True) @@ -1212,7 +1249,7 @@ def resnet10t(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet14t(pretrained=False, **kwargs) -> ResNet: +def resnet14t(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-14-T model. """ model_args = dict(block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True) @@ -1220,7 +1257,7 @@ def resnet14t(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet18(pretrained=False, **kwargs) -> ResNet: +def resnet18(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-18 model. """ model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2]) @@ -1228,7 +1265,7 @@ def resnet18(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet18d(pretrained=False, **kwargs) -> ResNet: +def resnet18d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-18-D model. 
""" model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True) @@ -1236,7 +1273,7 @@ def resnet18d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet34(pretrained=False, **kwargs) -> ResNet: +def resnet34(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-34 model. """ model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3]) @@ -1244,7 +1281,7 @@ def resnet34(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet34d(pretrained=False, **kwargs) -> ResNet: +def resnet34d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-34-D model. """ model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True) @@ -1252,7 +1289,7 @@ def resnet34d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet26(pretrained=False, **kwargs) -> ResNet: +def resnet26(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-26 model. """ model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2]) @@ -1260,7 +1297,7 @@ def resnet26(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet26t(pretrained=False, **kwargs) -> ResNet: +def resnet26t(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-26-T model. """ model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True) @@ -1268,7 +1305,7 @@ def resnet26t(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet26d(pretrained=False, **kwargs) -> ResNet: +def resnet26d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-26-D model. """ model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True) @@ -1276,7 +1313,7 @@ def resnet26d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet50(pretrained=False, **kwargs) -> ResNet: +def resnet50(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50 model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) @@ -1284,7 +1321,7 @@ def resnet50(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet50c(pretrained=False, **kwargs) -> ResNet: +def resnet50c(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-C model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep') @@ -1292,7 +1329,7 @@ def resnet50c(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet50d(pretrained=False, **kwargs) -> ResNet: +def resnet50d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True) @@ -1300,7 +1337,7 @@ def resnet50d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet50s(pretrained=False, **kwargs) -> ResNet: +def resnet50s(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-S model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep') @@ -1308,7 +1345,7 @@ def resnet50s(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet50t(pretrained=False, **kwargs) -> ResNet: +def resnet50t(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-T model. 
""" model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True) @@ -1316,7 +1353,7 @@ def resnet50t(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet101(pretrained=False, **kwargs) -> ResNet: +def resnet101(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101 model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3]) @@ -1324,7 +1361,7 @@ def resnet101(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet101c(pretrained=False, **kwargs) -> ResNet: +def resnet101c(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-C model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep') @@ -1332,7 +1369,7 @@ def resnet101c(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet101d(pretrained=False, **kwargs) -> ResNet: +def resnet101d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-D model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True) @@ -1340,7 +1377,7 @@ def resnet101d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet101s(pretrained=False, **kwargs) -> ResNet: +def resnet101s(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-S model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep') @@ -1348,7 +1385,7 @@ def resnet101s(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet152(pretrained=False, **kwargs) -> ResNet: +def resnet152(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-152 model. """ model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3]) @@ -1356,7 +1393,7 @@ def resnet152(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet152c(pretrained=False, **kwargs) -> ResNet: +def resnet152c(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-152-C model. """ model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep') @@ -1364,7 +1401,7 @@ def resnet152c(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet152d(pretrained=False, **kwargs) -> ResNet: +def resnet152d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-152-D model. """ model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True) @@ -1372,7 +1409,7 @@ def resnet152d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet152s(pretrained=False, **kwargs) -> ResNet: +def resnet152s(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-152-S model. """ model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep') @@ -1380,7 +1417,7 @@ def resnet152s(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet200(pretrained=False, **kwargs) -> ResNet: +def resnet200(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-200 model. """ model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3]) @@ -1388,7 +1425,7 @@ def resnet200(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet200d(pretrained=False, **kwargs) -> ResNet: +def resnet200d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-200-D model. 
""" model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True) @@ -1396,7 +1433,7 @@ def resnet200d(pretrained=False, **kwargs) -> ResNet: @register_model -def wide_resnet50_2(pretrained=False, **kwargs) -> ResNet: +def wide_resnet50_2(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a Wide ResNet-50-2 model. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 @@ -1408,7 +1445,7 @@ def wide_resnet50_2(pretrained=False, **kwargs) -> ResNet: @register_model -def wide_resnet101_2(pretrained=False, **kwargs) -> ResNet: +def wide_resnet101_2(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a Wide ResNet-101-2 model. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 @@ -1419,7 +1456,7 @@ def wide_resnet101_2(pretrained=False, **kwargs) -> ResNet: @register_model -def resnet50_gn(pretrained=False, **kwargs) -> ResNet: +def resnet50_gn(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50 model w/ GroupNorm """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) @@ -1427,7 +1464,7 @@ def resnet50_gn(pretrained=False, **kwargs) -> ResNet: @register_model -def resnext50_32x4d(pretrained=False, **kwargs) -> ResNet: +def resnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt50-32x4d model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4) @@ -1435,7 +1472,7 @@ def resnext50_32x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnext50d_32x4d(pretrained=False, **kwargs) -> ResNet: +def resnext50d_32x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample """ model_args = dict( @@ -1445,7 +1482,7 @@ def resnext50d_32x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnext101_32x4d(pretrained=False, **kwargs) -> ResNet: +def resnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt-101 32x4d model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4) @@ -1453,7 +1490,7 @@ def resnext101_32x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnext101_32x8d(pretrained=False, **kwargs) -> ResNet: +def resnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt-101 32x8d model. 
""" model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8) @@ -1461,7 +1498,7 @@ def resnext101_32x8d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnext101_32x16d(pretrained=False, **kwargs) -> ResNet: +def resnext101_32x16d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt-101 32x16d model """ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16) @@ -1469,7 +1506,7 @@ def resnext101_32x16d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnext101_32x32d(pretrained=False, **kwargs) -> ResNet: +def resnext101_32x32d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt-101 32x32d model """ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32) @@ -1477,7 +1514,7 @@ def resnext101_32x32d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnext101_64x4d(pretrained=False, **kwargs) -> ResNet: +def resnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt101-64x4d model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4) @@ -1485,7 +1522,7 @@ def resnext101_64x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def ecaresnet26t(pretrained=False, **kwargs) -> ResNet: +def ecaresnet26t(pretrained: bool = False, **kwargs) -> ResNet: """Constructs an ECA-ResNeXt-26-T model. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn. @@ -1497,7 +1534,7 @@ def ecaresnet26t(pretrained=False, **kwargs) -> ResNet: @register_model -def ecaresnet50d(pretrained=False, **kwargs) -> ResNet: +def ecaresnet50d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model with eca. """ model_args = dict( @@ -1507,7 +1544,7 @@ def ecaresnet50d(pretrained=False, **kwargs) -> ResNet: @register_model -def ecaresnet50d_pruned(pretrained=False, **kwargs) -> ResNet: +def ecaresnet50d_pruned(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model pruned with eca. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ @@ -1518,7 +1555,7 @@ def ecaresnet50d_pruned(pretrained=False, **kwargs) -> ResNet: @register_model -def ecaresnet50t(pretrained=False, **kwargs) -> ResNet: +def ecaresnet50t(pretrained: bool = False, **kwargs) -> ResNet: """Constructs an ECA-ResNet-50-T model. Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn. """ @@ -1529,7 +1566,7 @@ def ecaresnet50t(pretrained=False, **kwargs) -> ResNet: @register_model -def ecaresnetlight(pretrained=False, **kwargs) -> ResNet: +def ecaresnetlight(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-D light model with eca. """ model_args = dict( @@ -1539,7 +1576,7 @@ def ecaresnetlight(pretrained=False, **kwargs) -> ResNet: @register_model -def ecaresnet101d(pretrained=False, **kwargs) -> ResNet: +def ecaresnet101d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-D model with eca. """ model_args = dict( @@ -1549,7 +1586,7 @@ def ecaresnet101d(pretrained=False, **kwargs) -> ResNet: @register_model -def ecaresnet101d_pruned(pretrained=False, **kwargs) -> ResNet: +def ecaresnet101d_pruned(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-D model pruned with eca. 
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ @@ -1560,7 +1597,7 @@ def ecaresnet101d_pruned(pretrained=False, **kwargs) -> ResNet: @register_model -def ecaresnet200d(pretrained=False, **kwargs) -> ResNet: +def ecaresnet200d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-200-D model with ECA. """ model_args = dict( @@ -1570,7 +1607,7 @@ def ecaresnet200d(pretrained=False, **kwargs) -> ResNet: @register_model -def ecaresnet269d(pretrained=False, **kwargs) -> ResNet: +def ecaresnet269d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-269-D model with ECA. """ model_args = dict( @@ -1580,7 +1617,7 @@ def ecaresnet269d(pretrained=False, **kwargs) -> ResNet: @register_model -def ecaresnext26t_32x4d(pretrained=False, **kwargs) -> ResNet: +def ecaresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs an ECA-ResNeXt-26-T model. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem. This model replaces SE module with the ECA module @@ -1592,7 +1629,7 @@ def ecaresnext26t_32x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def ecaresnext50t_32x4d(pretrained=False, **kwargs) -> ResNet: +def ecaresnext50t_32x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs an ECA-ResNeXt-50-T model. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem. This model replaces SE module with the ECA module @@ -1604,25 +1641,25 @@ def ecaresnext50t_32x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnet18(pretrained=False, **kwargs) -> ResNet: +def seresnet18(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se')) return _create_resnet('seresnet18', pretrained, **dict(model_args, **kwargs)) @register_model -def seresnet34(pretrained=False, **kwargs) -> ResNet: +def seresnet34(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se')) return _create_resnet('seresnet34', pretrained, **dict(model_args, **kwargs)) @register_model -def seresnet50(pretrained=False, **kwargs) -> ResNet: +def seresnet50(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se')) return _create_resnet('seresnet50', pretrained, **dict(model_args, **kwargs)) @register_model -def seresnet50t(pretrained=False, **kwargs) -> ResNet: +def seresnet50t(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se')) @@ -1630,19 +1667,19 @@ def seresnet50t(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnet101(pretrained=False, **kwargs) -> ResNet: +def seresnet101(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se')) return _create_resnet('seresnet101', pretrained, **dict(model_args, **kwargs)) @register_model -def seresnet152(pretrained=False, **kwargs) -> ResNet: +def seresnet152(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se')) return _create_resnet('seresnet152', pretrained, **dict(model_args, **kwargs)) @register_model -def 
seresnet152d(pretrained=False, **kwargs) -> ResNet: +def seresnet152d(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se')) @@ -1650,7 +1687,7 @@ def seresnet152d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnet200d(pretrained=False, **kwargs) -> ResNet: +def seresnet200d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-200-D model with SE attn. """ model_args = dict( @@ -1660,7 +1697,7 @@ def seresnet200d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnet269d(pretrained=False, **kwargs) -> ResNet: +def seresnet269d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-269-D model with SE attn. """ model_args = dict( @@ -1670,7 +1707,7 @@ def seresnet269d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnext26d_32x4d(pretrained=False, **kwargs) -> ResNet: +def seresnext26d_32x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a SE-ResNeXt-26-D model.` This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for combination of deep stem and avg_pool in downsample. @@ -1682,7 +1719,7 @@ def seresnext26d_32x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnext26t_32x4d(pretrained=False, **kwargs) -> ResNet: +def seresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a SE-ResNet-26-T model. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem. @@ -1694,7 +1731,7 @@ def seresnext26t_32x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnext26tn_32x4d(pretrained=False, **kwargs) -> ResNet: +def seresnext26tn_32x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a SE-ResNeXt-26-T model. NOTE I deprecated previous 't' model defs and replaced 't' with 'tn', this was the only tn model of note so keeping this def for backwards compat with any uses out there. Old 't' model is lost. 
@@ -1703,7 +1740,7 @@ def seresnext26tn_32x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnext50_32x4d(pretrained=False, **kwargs) -> ResNet: +def seresnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, block_args=dict(attn_layer='se')) @@ -1711,7 +1748,7 @@ def seresnext50_32x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnext101_32x4d(pretrained=False, **kwargs) -> ResNet: +def seresnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, block_args=dict(attn_layer='se')) @@ -1719,7 +1756,7 @@ def seresnext101_32x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnext101_32x8d(pretrained=False, **kwargs) -> ResNet: +def seresnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, block_args=dict(attn_layer='se')) @@ -1727,7 +1764,7 @@ def seresnext101_32x8d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnext101d_32x8d(pretrained=False, **kwargs) -> ResNet: +def seresnext101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, stem_width=32, stem_type='deep', avg_down=True, @@ -1736,7 +1773,7 @@ def seresnext101d_32x8d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnext101_64x4d(pretrained=False, **kwargs) -> ResNet: +def seresnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, block_args=dict(attn_layer='se')) @@ -1744,7 +1781,7 @@ def seresnext101_64x4d(pretrained=False, **kwargs) -> ResNet: @register_model -def senet154(pretrained=False, **kwargs) -> ResNet: +def senet154(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se')) @@ -1752,7 +1789,7 @@ def senet154(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetblur18(pretrained=False, **kwargs) -> ResNet: +def resnetblur18(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-18 model with blur anti-aliasing """ model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d) @@ -1760,7 +1797,7 @@ def resnetblur18(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetblur50(pretrained=False, **kwargs) -> ResNet: +def resnetblur50(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50 model with blur anti-aliasing """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d) @@ -1768,7 +1805,7 @@ def resnetblur50(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetblur50d(pretrained=False, **kwargs) -> ResNet: +def resnetblur50d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model with blur anti-aliasing """ model_args = dict( @@ -1778,7 +1815,7 @@ def resnetblur50d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetblur101d(pretrained=False, **kwargs) -> ResNet: +def resnetblur101d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-D model with blur anti-aliasing """ model_args = dict( @@ -1788,7 +1825,7 @@ def 
resnetblur101d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetaa34d(pretrained=False, **kwargs) -> ResNet: +def resnetaa34d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-34-D model w/ avgpool anti-aliasing """ model_args = dict( @@ -1797,7 +1834,7 @@ def resnetaa34d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetaa50(pretrained=False, **kwargs) -> ResNet: +def resnetaa50(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50 model with avgpool anti-aliasing """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d) @@ -1805,7 +1842,7 @@ def resnetaa50(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetaa50d(pretrained=False, **kwargs) -> ResNet: +def resnetaa50d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model with avgpool anti-aliasing """ model_args = dict( @@ -1815,7 +1852,7 @@ def resnetaa50d(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetaa101d(pretrained=False, **kwargs) -> ResNet: +def resnetaa101d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-D model with avgpool anti-aliasing """ model_args = dict( @@ -1825,7 +1862,7 @@ def resnetaa101d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnetaa50d(pretrained=False, **kwargs) -> ResNet: +def seresnetaa50d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a SE=ResNet-50-D model with avgpool anti-aliasing """ model_args = dict( @@ -1835,7 +1872,7 @@ def seresnetaa50d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnextaa101d_32x8d(pretrained=False, **kwargs) -> ResNet: +def seresnextaa101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing """ model_args = dict( @@ -1846,7 +1883,7 @@ def seresnextaa101d_32x8d(pretrained=False, **kwargs) -> ResNet: @register_model -def seresnextaa201d_32x8d(pretrained=False, **kwargs): +def seresnextaa201d_32x8d(pretrained: bool = False, **kwargs): """Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing """ model_args = dict( @@ -1857,7 +1894,7 @@ def seresnextaa201d_32x8d(pretrained=False, **kwargs): @register_model -def resnetrs50(pretrained=False, **kwargs) -> ResNet: +def resnetrs50(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-RS-50 model. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs @@ -1870,7 +1907,7 @@ def resnetrs50(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetrs101(pretrained=False, **kwargs) -> ResNet: +def resnetrs101(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-RS-101 model. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs @@ -1883,7 +1920,7 @@ def resnetrs101(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetrs152(pretrained=False, **kwargs) -> ResNet: +def resnetrs152(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-RS-152 model. 
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs @@ -1896,7 +1933,7 @@ def resnetrs152(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetrs200(pretrained=False, **kwargs) -> ResNet: +def resnetrs200(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-RS-200 model. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs @@ -1909,7 +1946,7 @@ def resnetrs200(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetrs270(pretrained=False, **kwargs) -> ResNet: +def resnetrs270(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-RS-270 model. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs @@ -1923,7 +1960,7 @@ def resnetrs270(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetrs350(pretrained=False, **kwargs) -> ResNet: +def resnetrs350(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-RS-350 model. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs @@ -1936,7 +1973,7 @@ def resnetrs350(pretrained=False, **kwargs) -> ResNet: @register_model -def resnetrs420(pretrained=False, **kwargs) -> ResNet: +def resnetrs420(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-RS-420 model Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
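On the ResNet side, the constructor arguments and the registry entrypoints are now annotated (`-> ResNet`), so the concrete model type is visible to IDEs and static checkers. A short sketch under those assumptions follows; direct construction and importing from `timm.models.resnet` are shown only for illustration, and the layer counts are just the standard resnet50 configuration. In normal use, `timm.create_model('resnet50', ...)` remains the recommended entrypoint.

```python
import torch
import torch.nn as nn

from timm.models.resnet import ResNet, Bottleneck, resnet50

# Direct construction with the newly annotated signature: `block` is a block class,
# `layers` the per-stage repeat counts, `act_layer`/`norm_layer` accept LayerType values.
model: ResNet = ResNet(
    block=Bottleneck,
    layers=[3, 4, 6, 3],
    num_classes=10,
    act_layer=nn.ReLU,
    norm_layer=nn.BatchNorm2d,
)

# Registry entrypoints now declare their return type, e.g. resnet50(...) -> ResNet.
model2 = resnet50(pretrained=False, num_classes=10)
out = model2(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 10])
```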