
A collection of fixes, model experiments, etc #880


Merged: 20 commits, Oct 4, 2021
Commits
8e11da0 Add experimental RegNetZ(ish) models for training / perf trials. (rwightman, Sep 23, 2021)
da06cc6 ResNetV2 seems to work best without zero_init residual (rwightman, Sep 23, 2021)
515121c Use reshape instead of view in std_conv, causing issues in recent PyT… (rwightman, Sep 23, 2021)
f8a63a3 Add worker_init_fn to loader for numpy seed per worker (rwightman, Sep 23, 2021)
5d6983c Batch validate a list of files if model is a text file with model per… (rwightman, Sep 23, 2021)
0387e60 Update binary cross ent impl to use thresholding as an option (conver… (rwightman, Sep 23, 2021)
6478bcd Fix regnetz_d conv layer name, use inception mean/std (rwightman, Sep 26, 2021)
80075b0 Add worker_seeding arg to allow selecting old vs updated data loader … (rwightman, Sep 28, 2021)
b81e79a Fix bottleneck attn transpose typo, hopefully these train better now.. (rwightman, Sep 28, 2021)
0ca687f Make 'regnetz' model experiments closer to actual RegNetZ, bottleneck… (rwightman, Oct 1, 2021)
d657e2c Remove dead code line from efficientnet (rwightman, Oct 1, 2021)
b49630a Add relative pos embed option to LambdaLayer, fix last transpose/resh… (rwightman, Oct 1, 2021)
b1c2e3e Match rel_pos_indices attr rename in conv branch (rwightman, Oct 1, 2021)
d9abfa4 Make broadcast_buffers disable its own flag for now (needs more testi… (rwightman, Oct 1, 2021)
007bc39 Some halo and bottleneck attn code cleanup, add halonet50ts weights, … (rwightman, Oct 2, 2021)
b2094f4 support bits checkpoints in avg/load (rwightman, Oct 4, 2021)
6449550 Add updated lambda resnet26 and botnet26 checkpoints with fixes applied (rwightman, Oct 4, 2021)
cc9bedf Add initial ResNet Strikes Back weights for ResNet50 and ResNetV2-50 … (rwightman, Oct 4, 2021)
da0d39b Update default crop_pct for byoanet (rwightman, Oct 4, 2021)
93901e9 Version bump to 0.5.0 for pending release post RSB and ATTN updates (rwightman, Oct 4, 2021)
4 changes: 4 additions & 0 deletions avg_checkpoints.py
@@ -41,6 +41,10 @@ def checkpoint_metric(checkpoint_path):
     metric = None
     if 'metric' in checkpoint:
         metric = checkpoint['metric']
+    elif 'metrics' in checkpoint and 'metric_name' in checkpoint:
+        metrics = checkpoint['metrics']
+        print(metrics)
+        metric = metrics[checkpoint['metric_name']]
     return metric
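For reference, a minimal sketch of the two checkpoint layouts the updated checkpoint_metric() accepts; the file names and metric values here are hypothetical, and the second layout is inferred from the "support bits checkpoints in avg/load" commit:

import torch

# Legacy layout: a single scalar stored under 'metric'.
torch.save({'state_dict': {}, 'metric': 80.1}, 'legacy.pth')

# Bits-style layout: a 'metrics' dict plus a 'metric_name' key selecting
# which entry checkpoint_metric() should return.
torch.save(
    {'state_dict': {}, 'metrics': {'top1': 80.1, 'top5': 95.0}, 'metric_name': 'top1'},
    'bits.pth')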
27 changes: 24 additions & 3 deletions timm/data/loader.py
@@ -3,8 +3,11 @@
 Prefetcher and Fast Collate inspired by NVIDIA APEX example at
 https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf

-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2021 Ross Wightman
 """
+import random
+from functools import partial
+from typing import Callable

 import torch.utils.data
 import numpy as np
@@ -125,6 +128,22 @@ def mixup_enabled(self, x):
         self.loader.collate_fn.mixup_enabled = x


+def _worker_init(worker_id, worker_seeding='all'):
+    worker_info = torch.utils.data.get_worker_info()
+    assert worker_info.id == worker_id
+    if isinstance(worker_seeding, Callable):
+        seed = worker_seeding(worker_info)
+        random.seed(seed)
+        torch.manual_seed(seed)
+        np.random.seed(seed % (2 ** 32 - 1))
+    else:
+        assert worker_seeding in ('all', 'part')
+        # random / torch seed already called in dataloader iter class w/ worker_info.seed
+        # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
+        if worker_seeding == 'all':
+            np.random.seed(worker_info.seed % (2 ** 32 - 1))
+
+
 def create_loader(
         dataset,
         input_size,
@@ -156,6 +175,7 @@ def create_loader(
         tf_preprocessing=False,
         use_multi_epochs_loader=False,
         persistent_workers=True,
+        worker_seeding='all',
 ):
     re_num_splits = 0
     if re_split:
@@ -202,7 +222,6 @@ def create_loader(
     collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

     loader_class = torch.utils.data.DataLoader
-
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader
@@ -214,7 +233,9 @@ def create_loader(
         collate_fn=collate_fn,
         pin_memory=pin_memory,
         drop_last=is_training,
-        persistent_workers=persistent_workers)
+        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
+        persistent_workers=persistent_workers
+    )
     try:
         loader = loader_class(dataset, **loader_args)
     except TypeError as e:
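A usage sketch for the new worker_seeding argument (dataset root and hyperparameters below are made up for illustration). 'all' (the default) re-seeds numpy in each worker from worker_info.seed; 'part' skips the numpy re-seed so results from older versions (same seed + hparam combo) can be reproduced; a callable receives the worker_info and returns the seed to use:

from timm.data import create_dataset, create_loader

dataset = create_dataset('', root='/data/imagenet', split='train')
loader = create_loader(
    dataset,
    input_size=(3, 224, 224),
    batch_size=128,
    is_training=True,
    num_workers=4,
    worker_seeding='all',  # or 'part', or e.g. lambda worker_info: worker_info.seed
)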
2 changes: 1 addition & 1 deletion timm/loss/__init__.py
@@ -1,4 +1,4 @@
 from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
-from .binary_cross_entropy import DenseBinaryCrossEntropy
+from .binary_cross_entropy import BinaryCrossEntropy
 from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from .jsd import JsdCrossEntropy
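Downstream code importing the removed name needs a one-line update; a minimal before/after sketch:

# before this PR:
# from timm.loss import DenseBinaryCrossEntropy
# after this PR:
from timm.loss import BinaryCrossEntropy

criterion = BinaryCrossEntropy(smoothing=0.1)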
50 changes: 37 additions & 13 deletions timm/loss/binary_cross_entropy.py
@@ -1,23 +1,47 @@
 """ Binary Cross Entropy w/ a few extras

 Hacked together by / Copyright 2021 Ross Wightman
 """
+from typing import Optional
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F


-class DenseBinaryCrossEntropy(nn.Module):
-    """ BCE using one-hot from dense targets w/ label smoothing
+class BinaryCrossEntropy(nn.Module):
+    """ BCE with optional one-hot from dense targets, label smoothing, thresholding
+    NOTE for experiments comparing CE to BCE /w label smoothing, may remove
     """
-    def __init__(self, smoothing=0.1):
-        super(DenseBinaryCrossEntropy, self).__init__()
+    def __init__(
+            self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None,
+            reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None):
+        super(BinaryCrossEntropy, self).__init__()
         assert 0. <= smoothing < 1.0
         self.smoothing = smoothing
-        self.bce = nn.BCEWithLogitsLoss()
+        self.target_threshold = target_threshold
+        self.reduction = reduction
+        self.register_buffer('weight', weight)
+        self.register_buffer('pos_weight', pos_weight)

-    def forward(self, x, target):
-        num_classes = x.shape[-1]
-        off_value = self.smoothing / num_classes
-        on_value = 1. - self.smoothing + off_value
-        target = target.long().view(-1, 1)
-        target = torch.full(
-            (target.size()[0], num_classes), off_value, device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
-        return self.bce(x, target)
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        assert x.shape[0] == target.shape[0]
+        if target.shape != x.shape:
+            # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
+            num_classes = x.shape[-1]
+            # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
+            off_value = self.smoothing / num_classes
+            on_value = 1. - self.smoothing + off_value
+            target = target.long().view(-1, 1)
+            target = torch.full(
+                (target.size()[0], num_classes),
+                off_value,
+                device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
+        if self.target_threshold is not None:
+            # Make target 0, or 1 if threshold set
+            target = target.gt(self.target_threshold).to(dtype=target.dtype)
+        return F.binary_cross_entropy_with_logits(
+            x, target,
+            self.weight,
+            pos_weight=self.pos_weight,
+            reduction=self.reduction)
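A quick sketch of the new options (shapes and values are illustrative): dense integer targets are one-hot encoded with smoothing as before, and target_threshold then binarizes the soft target to 0/1 per the thresholding commit above; targets already matching the logit shape skip the one-hot step:

import torch
from timm.loss import BinaryCrossEntropy

criterion = BinaryCrossEntropy(smoothing=0.1, target_threshold=0.2)
logits = torch.randn(8, 1000)            # (batch, num_classes)
targets = torch.randint(0, 1000, (8,))   # dense class indices
loss = criterion(logits, targets)        # one-hot + smooth, then threshold to 0/1

soft_targets = torch.rand(8, 1000)       # e.g. mixup output, same shape as logits
loss_soft = criterion(logits, soft_targets)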
16 changes: 8 additions & 8 deletions timm/loss/cross_entropy.py
@@ -1,23 +1,23 @@
 """ Cross Entropy w/ smoothing or soft targets

 Hacked together by / Copyright 2021 Ross Wightman
 """

 import torch
 import torch.nn as nn
 import torch.nn.functional as F


 class LabelSmoothingCrossEntropy(nn.Module):
-    """
-    NLL loss with label smoothing.
+    """ NLL loss with label smoothing.
     """
     def __init__(self, smoothing=0.1):
-        """
-        Constructor for the LabelSmoothing module.
-        :param smoothing: label smoothing factor
-        """
         super(LabelSmoothingCrossEntropy, self).__init__()
         assert smoothing < 1.0
         self.smoothing = smoothing
         self.confidence = 1. - smoothing

-    def forward(self, x, target):
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         logprobs = F.log_softmax(x, dim=-1)
         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
         nll_loss = nll_loss.squeeze(1)
@@ -31,6 +31,6 @@ class SoftTargetCrossEntropy(nn.Module):
     def __init__(self):
         super(SoftTargetCrossEntropy, self).__init__()

-    def forward(self, x, target):
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
         return loss.mean()
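For completeness, a small sketch pairing each loss with the target format it expects (shapes illustrative):

import torch
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

logits = torch.randn(4, 10)
hard_targets = torch.randint(0, 10, (4,))                 # class indices
soft_targets = torch.softmax(torch.randn(4, 10), dim=-1)  # per-class probabilities

loss_hard = LabelSmoothingCrossEntropy(smoothing=0.1)(logits, hard_targets)
loss_soft = SoftTargetCrossEntropy()(logits, soft_targets)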
65 changes: 57 additions & 8 deletions timm/models/byoanet.py
@@ -3,7 +3,7 @@
 A flexible network w/ dataclass based config for stacking NN blocks including
 self-attention (or similar) layers.

-Currently used to implement experimential variants of:
+Currently used to implement experimental variants of:
  * Bottleneck Transformers
  * Lambda ResNets
  * HaloNets
@@ -23,7 +23,7 @@
 def _cfg(url='', **kwargs):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'crop_pct': 0.95, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
         'fixed_input_size': False, 'min_input_size': (3, 224, 224),
@@ -34,7 +34,7 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
     # GPU-Efficient (ResNet) weights
     'botnet26t_256': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
         fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
     'botnet50ts_256': _cfg(
         url='',
@@ -46,19 +46,26 @@ def _cfg(url='', **kwargs):
     'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
     'halonet26t': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'sehalonet33ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'halonet50ts': _cfg(
-        url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_256_ra3-f07eab9f.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'eca_halonext26ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_256-1e55880b.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),

     'lambda_resnet26t': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_256-b040fce6.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_a2h_256-25ded63d.pth',
         min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
+    'lambda_resnet50ts': _cfg(
+        url='',
+        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
+    'lambda_resnet26rpt_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_a2h_256-482adad8.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
 }
@@ -198,6 +205,33 @@ def _cfg(url='', **kwargs):
         self_attn_layer='lambda',
         self_attn_kwargs=dict(r=9)
     ),
+    lambda_resnet50ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        act_layer='silu',
+        self_attn_layer='lambda',
+        self_attn_kwargs=dict(r=9)
+    ),
+    lambda_resnet26rpt_256=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        self_attn_layer='lambda',
+        self_attn_kwargs=dict(r=None)
+    ),
 )
@@ -275,6 +309,21 @@ def eca_halonext26ts(pretrained=False, **kwargs):

 @register_model
 def lambda_resnet26t(pretrained=False, **kwargs):
-    """ Lambda-ResNet-26T. Lambda layers in last two stages.
+    """ Lambda-ResNet-26-T. Lambda layers w/ conv pos in last two stages.
     """
     return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)


+@register_model
+def lambda_resnet50ts(pretrained=False, **kwargs):
+    """ Lambda-ResNet-50-TS. SiLU act. Lambda layers w/ conv pos in last two stages.
+    """
+    return _create_byoanet('lambda_resnet50ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def lambda_resnet26rpt_256(pretrained=False, **kwargs):
+    """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages.
+    """
+    kwargs.setdefault('img_size', 256)
+    return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)
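As a usage sketch, the new and re-weighted models register under the names above and are created through the standard factory; pretrained=True pulls the release URLs from the configs in this diff (weight availability assumed per the PR):

import timm
import torch

model = timm.create_model('lambda_resnet26rpt_256', pretrained=True)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))  # rel-pos variant is fixed at 256x256
print(out.shape)  # torch.Size([1, 1000])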