
Commit cd638d5

Merge pull request #880 from rwightman/fixes_bce_regnet
A collection of fixes, model experiments, etc
2 parents b5bf4dc + 93901e9

18 files changed, +425 -205 lines

avg_checkpoints.py

+4

@@ -41,6 +41,10 @@ def checkpoint_metric(checkpoint_path):
     metric = None
     if 'metric' in checkpoint:
         metric = checkpoint['metric']
+    elif 'metrics' in checkpoint and 'metric_name' in checkpoint:
+        metrics = checkpoint['metrics']
+        print(metrics)
+        metric = metrics[checkpoint['metric_name']]
     return metric
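
For reference, the new elif branch reads checkpoints that store a dict of metrics together with the name of the metric used for ranking. A minimal sketch of such a checkpoint layout (the values here are hypothetical):

# Hypothetical checkpoint layout handled by the new branch: a 'metrics' dict
# plus a 'metric_name' key selecting the entry used to rank checkpoints.
checkpoint = {
    'state_dict': {},  # model weights, omitted for brevity
    'metrics': {'top1': 81.2, 'top5': 95.6, 'loss': 0.91},
    'metric_name': 'top1',
}
# checkpoint_metric() would return 81.2 for this checkpoint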
timm/data/loader.py

+24 -3

@@ -3,8 +3,11 @@
 Prefetcher and Fast Collate inspired by NVIDIA APEX example at
 https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
 
-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2021 Ross Wightman
 """
+import random
+from functools import partial
+from typing import Callable
 
 import torch.utils.data
 import numpy as np
@@ -125,6 +128,22 @@ def mixup_enabled(self, x):
             self.loader.collate_fn.mixup_enabled = x
 
 
+def _worker_init(worker_id, worker_seeding='all'):
+    worker_info = torch.utils.data.get_worker_info()
+    assert worker_info.id == worker_id
+    if isinstance(worker_seeding, Callable):
+        seed = worker_seeding(worker_info)
+        random.seed(seed)
+        torch.manual_seed(seed)
+        np.random.seed(seed % (2 ** 32 - 1))
+    else:
+        assert worker_seeding in ('all', 'part')
+        # random / torch seed already called in dataloader iter class w/ worker_info.seed
+        # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
+        if worker_seeding == 'all':
+            np.random.seed(worker_info.seed % (2 ** 32 - 1))
+
+
 def create_loader(
         dataset,
         input_size,
@@ -156,6 +175,7 @@ def create_loader(
         tf_preprocessing=False,
         use_multi_epochs_loader=False,
         persistent_workers=True,
+        worker_seeding='all',
 ):
     re_num_splits = 0
     if re_split:
@@ -202,7 +222,6 @@
     collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
 
     loader_class = torch.utils.data.DataLoader
-
     if use_multi_epochs_loader:
         loader_class = MultiEpochsDataLoader
 
@@ -214,7 +233,9 @@
         collate_fn=collate_fn,
         pin_memory=pin_memory,
         drop_last=is_training,
-        persistent_workers=persistent_workers)
+        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
+        persistent_workers=persistent_workers
+    )
     try:
         loader = loader_class(dataset, **loader_args)
     except TypeError as e:
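
The new worker_seeding argument accepts 'all', 'part', or a callable that maps a torch worker_info to a seed. A minimal usage sketch of the callable form, assuming a timm-compatible dataset object (the dataset and argument values are illustrative, not from this commit):

from timm.data import create_loader

# 'dataset' stands in for any dataset yielding PIL images, e.g. one built
# via timm.data.create_dataset(...)
loader = create_loader(
    dataset,
    input_size=(3, 224, 224),
    batch_size=64,
    is_training=True,
    use_prefetcher=False,  # keep the sketch CPU-friendly
    # callable form: _worker_init seeds random/torch/numpy from the returned value
    worker_seeding=lambda worker_info: (worker_info.seed + 1) % (2 ** 32 - 1),
)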

timm/loss/__init__.py

+1 -1

@@ -1,4 +1,4 @@
 from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
-from .binary_cross_entropy import DenseBinaryCrossEntropy
+from .binary_cross_entropy import BinaryCrossEntropy
 from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from .jsd import JsdCrossEntropy
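
Since the exported class is renamed, downstream imports need a matching one-line update, e.g.:

# before this commit: from timm.loss import DenseBinaryCrossEntropy
from timm.loss import BinaryCrossEntropy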

timm/loss/binary_cross_entropy.py

+37 -13

@@ -1,23 +1,47 @@
+""" Binary Cross Entropy w/ a few extras
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from typing import Optional
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
-class DenseBinaryCrossEntropy(nn.Module):
-    """ BCE using one-hot from dense targets w/ label smoothing
+class BinaryCrossEntropy(nn.Module):
+    """ BCE with optional one-hot from dense targets, label smoothing, thresholding
     NOTE for experiments comparing CE to BCE /w label smoothing, may remove
     """
-    def __init__(self, smoothing=0.1):
-        super(DenseBinaryCrossEntropy, self).__init__()
+    def __init__(
+            self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None,
+            reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None):
+        super(BinaryCrossEntropy, self).__init__()
         assert 0. <= smoothing < 1.0
         self.smoothing = smoothing
-        self.bce = nn.BCEWithLogitsLoss()
+        self.target_threshold = target_threshold
+        self.reduction = reduction
+        self.register_buffer('weight', weight)
+        self.register_buffer('pos_weight', pos_weight)
 
-    def forward(self, x, target):
-        num_classes = x.shape[-1]
-        off_value = self.smoothing / num_classes
-        on_value = 1. - self.smoothing + off_value
-        target = target.long().view(-1, 1)
-        target = torch.full(
-            (target.size()[0], num_classes), off_value, device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
-        return self.bce(x, target)
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        assert x.shape[0] == target.shape[0]
+        if target.shape != x.shape:
+            # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
+            num_classes = x.shape[-1]
+            # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
+            off_value = self.smoothing / num_classes
+            on_value = 1. - self.smoothing + off_value
+            target = target.long().view(-1, 1)
+            target = torch.full(
+                (target.size()[0], num_classes),
+                off_value,
+                device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
+        if self.target_threshold is not None:
+            # Make target 0, or 1 if threshold set
+            target = target.gt(self.target_threshold).to(dtype=target.dtype)
+        return F.binary_cross_entropy_with_logits(
+            x, target,
+            self.weight,
+            pos_weight=self.pos_weight,
+            reduction=self.reduction)
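
A short usage sketch of the reworked loss (shapes and values are illustrative): dense index targets are one-hot encoded and smoothed inside forward, and target_threshold re-binarizes smoothed or mixed targets to {0, 1}.

import torch
from timm.loss import BinaryCrossEntropy

loss_fn = BinaryCrossEntropy(smoothing=0.1, target_threshold=0.2)
logits = torch.randn(8, 1000)                 # (batch, num_classes)
dense_targets = torch.randint(0, 1000, (8,))  # class indices, one-hot encoded in forward
loss = loss_fn(logits, dense_targets)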

timm/loss/cross_entropy.py

+8 -8

@@ -1,23 +1,23 @@
+""" Cross Entropy w/ smoothing or soft targets
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
 class LabelSmoothingCrossEntropy(nn.Module):
-    """
-    NLL loss with label smoothing.
+    """ NLL loss with label smoothing.
     """
     def __init__(self, smoothing=0.1):
-        """
-        Constructor for the LabelSmoothing module.
-        :param smoothing: label smoothing factor
-        """
         super(LabelSmoothingCrossEntropy, self).__init__()
         assert smoothing < 1.0
         self.smoothing = smoothing
         self.confidence = 1. - smoothing
 
-    def forward(self, x, target):
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         logprobs = F.log_softmax(x, dim=-1)
         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
         nll_loss = nll_loss.squeeze(1)
@@ -31,6 +31,6 @@ class SoftTargetCrossEntropy(nn.Module):
     def __init__(self):
         super(SoftTargetCrossEntropy, self).__init__()
 
-    def forward(self, x, target):
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
         return loss.mean()
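
For comparison, the two losses here expect different target formats; a brief sketch with illustrative shapes:

import torch
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

logits = torch.randn(4, 10)
hard_targets = torch.randint(0, 10, (4,))                 # class indices
soft_targets = torch.softmax(torch.randn(4, 10), dim=-1)  # e.g. mixup/cutmix output

ls_loss = LabelSmoothingCrossEntropy(smoothing=0.1)(logits, hard_targets)
st_loss = SoftTargetCrossEntropy()(logits, soft_targets)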

timm/models/byoanet.py

+57 -8

@@ -3,7 +3,7 @@
 A flexible network w/ dataclass based config for stacking NN blocks including
 self-attention (or similar) layers.
 
-Currently used to implement experimential variants of:
+Currently used to implement experimental variants of:
 * Bottleneck Transformers
 * Lambda ResNets
 * HaloNets
@@ -23,7 +23,7 @@
 def _cfg(url='', **kwargs):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'crop_pct': 0.95, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
         'fixed_input_size': False, 'min_input_size': (3, 224, 224),
@@ -34,7 +34,7 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
     # GPU-Efficient (ResNet) weights
     'botnet26t_256': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
         fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
     'botnet50ts_256': _cfg(
         url='',
@@ -46,19 +46,26 @@
     'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
     'halonet26t': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'sehalonet33ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'halonet50ts': _cfg(
-        url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_256_ra3-f07eab9f.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'eca_halonext26ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_256-1e55880b.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
 
     'lambda_resnet26t': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_256-b040fce6.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_a2h_256-25ded63d.pth',
+        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
+    'lambda_resnet50ts': _cfg(
+        url='',
         min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
+    'lambda_resnet26rpt_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_a2h_256-482adad8.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
 }
@@ -198,6 +205,33 @@ def _cfg(url='', **kwargs):
         self_attn_layer='lambda',
         self_attn_kwargs=dict(r=9)
     ),
+    lambda_resnet50ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        act_layer='silu',
+        self_attn_layer='lambda',
+        self_attn_kwargs=dict(r=9)
+    ),
+    lambda_resnet26rpt_256=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        self_attn_layer='lambda',
+        self_attn_kwargs=dict(r=None)
+    ),
 )
@@ -275,6 +309,21 @@ def eca_halonext26ts(pretrained=False, **kwargs):
 
 @register_model
 def lambda_resnet26t(pretrained=False, **kwargs):
-    """ Lambda-ResNet-26T. Lambda layers in last two stages.
+    """ Lambda-ResNet-26-T. Lambda layers w/ conv pos in last two stages.
     """
     return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def lambda_resnet50ts(pretrained=False, **kwargs):
+    """ Lambda-ResNet-50-TS. SiLU act. Lambda layers w/ conv pos in last two stages.
+    """
+    return _create_byoanet('lambda_resnet50ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def lambda_resnet26rpt_256(pretrained=False, **kwargs):
+    """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages.
+    """
+    kwargs.setdefault('img_size', 256)
+    return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)
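
The new variants are registered models, so they build through the usual factory; a quick sketch (lambda_resnet26rpt_256 has pretrained weights in this commit, while the lambda_resnet50ts URL is still empty):

import timm

# rel pos embed makes the input size fixed at 256x256 for this variant
model = timm.create_model('lambda_resnet26rpt_256', pretrained=True)
model.eval()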
