
Commit 6a01101

Update efficientnet.py and convnext.py to multi-weight, add ImageNet-12k pretrained EfficientNet-B5 and ConvNeXt-Nano.
1 parent e7da205 commit 6a01101

7 files changed: +686, -729 lines changed


timm/models/_builder.py

Lines changed: 16 additions & 9 deletions
@@ -1,5 +1,6 @@
 import dataclasses
 import logging
+import os
 from copy import deepcopy
 from typing import Optional, Dict, Callable, Any, Tuple

@@ -9,7 +10,7 @@
 from timm.models._features import FeatureListNet, FeatureHookNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
-from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
+from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
 from timm.models._manipulate import adapt_input_conv
 from timm.models._pretrained import PretrainedCfg
 from timm.models._prune import adapt_model_from_file
@@ -32,6 +33,7 @@ def _resolve_pretrained_source(pretrained_cfg):
     pretrained_url = pretrained_cfg.get('url', None)
     pretrained_file = pretrained_cfg.get('file', None)
     hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
+
     # resolve where to load pretrained weights from
     load_from = ''
     pretrained_loc = ''
@@ -43,15 +45,20 @@ def _resolve_pretrained_source(pretrained_cfg):
     else:
         # default source == timm or unspecified
         if pretrained_file:
+            # file load override is the highest priority if set
             load_from = 'file'
             pretrained_loc = pretrained_file
-        elif pretrained_url:
-            load_from = 'url'
-            pretrained_loc = pretrained_url
-        elif hf_hub_id and has_hf_hub(necessary=True):
-            # hf-hub available as alternate weight source in default_cfg
-            load_from = 'hf-hub'
-            pretrained_loc = hf_hub_id
+        else:
+            # next, HF hub is prioritized unless a valid cached version of weights exists already
+            cached_url_valid = check_cached_file(pretrained_url) if pretrained_url else False
+            if hf_hub_id and has_hf_hub(necessary=True) and not cached_url_valid:
+                # hf-hub available as alternate weight source in default_cfg
+                load_from = 'hf-hub'
+                pretrained_loc = hf_hub_id
+            elif pretrained_url:
+                load_from = 'url'
+                pretrained_loc = pretrained_url
+
     if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
         # if a filename override is set, return tuple for location w/ (hub_id, filename)
         pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
@@ -105,7 +112,7 @@ def load_custom_pretrained(
         pretrained_loc = download_cached_file(
             pretrained_loc,
             check_hash=_CHECK_HASH,
-            progress=_DOWNLOAD_PROGRESS
+            progress=_DOWNLOAD_PROGRESS,
         )

     if load_fn is not None:
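The net effect of the `_resolve_pretrained_source` change: an explicit `file` entry in the pretrained cfg still wins, but the HF hub now takes priority over a plain `url` unless a valid (hash-verified) cached download of that URL already exists locally. A minimal standalone sketch of that decision order, using a plain dict and a boolean stand-in for timm's `check_cached_file` result (names here are illustrative, not the timm internals themselves):

# Illustrative sketch of the new source priority, not the timm source itself.
def resolve_source(cfg: dict, hf_hub_available: bool, cached_url_valid: bool):
    file, url, hub_id = cfg.get('file'), cfg.get('url'), cfg.get('hf_hub_id')
    if file:
        return 'file', file      # explicit file override is highest priority
    if hub_id and hf_hub_available and not cached_url_valid:
        return 'hf-hub', hub_id  # prefer the hub unless a valid cached URL download exists
    if url:
        return 'url', url        # otherwise fall back to the original URL
    return '', ''

# e.g. both url and hf_hub_id set, nothing cached yet:
# resolve_source({'url': 'https://.../weights.pth', 'hf_hub_id': 'timm/convnext_nano.d1h_in1k'}, True, False)
# -> ('hf-hub', 'timm/convnext_nano.d1h_in1k')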

timm/models/_hub.py

Lines changed: 24 additions & 1 deletion
@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import os
@@ -67,6 +68,26 @@ def download_cached_file(url, check_hash=True, progress=False):
     return cached_file


+def check_cached_file(url, check_hash=True):
+    if isinstance(url, (list, tuple)):
+        url, filename = url
+    else:
+        parts = urlparse(url)
+        filename = os.path.basename(parts.path)
+    cached_file = os.path.join(get_cache_dir(), filename)
+    if os.path.exists(cached_file):
+        if check_hash:
+            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
+            hash_prefix = r.group(1) if r else None
+            if hash_prefix:
+                with open(cached_file, 'rb') as f:
+                    hd = hashlib.sha256(f.read()).hexdigest()
+                if hd[:len(hash_prefix)] != hash_prefix:
+                    return False
+        return True
+    return False
+
+
 def has_hf_hub(necessary=False):
     if not _has_hf_hub and necessary:
         # if no HF Hub module installed, and it is necessary to continue, raise error
@@ -145,7 +166,9 @@ def save_for_hf(model, save_directory, model_config=None):
     hf_config['architecture'] = pretrained_cfg.pop('architecture')
     hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
     hf_config['num_features'] = model_config.get('num_features', model.num_features)
-    hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
+    global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
+    if isinstance(global_pool_type, str) and global_pool_type:
+        hf_config['global_pool'] = global_pool_type

     if 'label' in model_config:
         _logger.warning(
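`check_cached_file` answers "do we already have a trustworthy local copy of this URL's download?": it looks for the file in the torch hub checkpoint cache and, when the filename carries a hash suffix (e.g. `convnext_nano_d1h-7eb4bdea.pth`), verifies that the file's SHA-256 digest starts with that prefix. A rough self-contained equivalent, assuming the default ~/.cache/torch/hub/checkpoints location instead of timm's `get_cache_dir()` (the function name and cache path here are illustrative assumptions):

import hashlib
import os
import re
from urllib.parse import urlparse

# same convention torch.hub uses to embed a hash prefix in checkpoint filenames
_HASH_RE = re.compile(r'-([a-f0-9]*)\.')

def cached_weights_look_valid(url, cache_dir=os.path.expanduser('~/.cache/torch/hub/checkpoints')):
    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(cache_dir, filename)
    if not os.path.exists(cached_file):
        return False
    m = _HASH_RE.search(filename)
    if m and m.group(1):
        with open(cached_file, 'rb') as f:
            digest = hashlib.sha256(f.read()).hexdigest()
        return digest.startswith(m.group(1))  # hash prefix must match file contents
    return True  # no hash in the filename, existence is enough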

timm/models/_pretrained.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ class PretrainedCfg:

     source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
     architecture: Optional[str] = None  # architecture variant can be set when not implicit
+    tag: Optional[str] = None  # pretrained tag of source
     custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)

     # input / data config
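The new `tag` field means each registered pretrained cfg records both its architecture and the pretrained tag that together form its registry key (e.g. `convnext_nano.in12k_ft_in1k`). timm's own `split_model_name_tag` helper does the splitting; a tiny sketch of the idea:

def split_arch_tag(model_name):
    # 'convnext_nano.in12k_ft_in1k' -> ('convnext_nano', 'in12k_ft_in1k'); no dot -> empty tag
    arch, _, tag = model_name.partition('.')
    return arch, tag

assert split_arch_tag('convnext_nano.in12k_ft_in1k') == ('convnext_nano', 'in12k_ft_in1k')
assert split_arch_tag('convnext_nano') == ('convnext_nano', '')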

timm/models/_registry.py

Lines changed: 18 additions & 10 deletions
@@ -7,6 +7,7 @@
 import sys
 from collections import defaultdict, deque
 from copy import deepcopy
+from dataclasses import replace
 from typing import List, Optional, Union, Tuple

 from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@@ -20,7 +21,7 @@
 _model_entrypoints = {}  # mapping of model names to architecture entrypoint fns
 _model_has_pretrained = set()  # set of model names that have pretrained weight url present
 _model_default_cfgs = dict()  # central repo for model arch -> default cfg objects
-_model_pretrained_cfgs = dict()  # central repo for model arch + tag -> pretrained cfgs
+_model_pretrained_cfgs = dict()  # central repo for model arch.tag -> pretrained cfgs
 _model_with_tags = defaultdict(list)  # shortcut to map each model arch to all model + tag names


@@ -48,24 +49,31 @@ def register_model(fn):
     if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
         # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
         # entrypoints or non-matching combos
-        cfg = mod.default_cfgs[model_name]
-        if not isinstance(cfg, DefaultCfg):
+        default_cfg = mod.default_cfgs[model_name]
+        if not isinstance(default_cfg, DefaultCfg):
             # new style default cfg dataclass w/ multiple entries per model-arch
-            assert isinstance(cfg, dict)
+            assert isinstance(default_cfg, dict)
             # old style cfg dict per model-arch
-            cfg = PretrainedCfg(**cfg)
-            cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg})
+            pretrained_cfg = PretrainedCfg(**default_cfg)
+            default_cfg = DefaultCfg(tags=deque(['']), cfgs={'': pretrained_cfg})

-        for tag_idx, tag in enumerate(cfg.tags):
+        for tag_idx, tag in enumerate(default_cfg.tags):
             is_default = tag_idx == 0
-            pretrained_cfg = cfg.cfgs[tag]
+            pretrained_cfg = default_cfg.cfgs[tag]
+            model_name_tag = '.'.join([model_name, tag]) if tag else model_name
+            replace_items = dict(architecture=model_name, tag=tag if tag else None)
+            if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/':
+                # auto-complete hub name w/ architecture.tag
+                replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag
+            pretrained_cfg = replace(pretrained_cfg, **replace_items)
+
             if is_default:
                 _model_pretrained_cfgs[model_name] = pretrained_cfg
                 if pretrained_cfg.has_weights:
                     # add tagless entry if it's default and has weights
                     _model_has_pretrained.add(model_name)
+
             if tag:
-                model_name_tag = '.'.join([model_name, tag])
                 _model_pretrained_cfgs[model_name_tag] = pretrained_cfg
                 if pretrained_cfg.has_weights:
                     # add model w/ tag if tag is valid
@@ -74,7 +82,7 @@ def register_model(fn):
         else:
             _model_with_tags[model_name].append(model_name)  # has empty tag (to slowly remove these instances)

-    _model_default_cfgs[model_name] = cfg
+    _model_default_cfgs[model_name] = default_cfg

     return fn

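The key piece of the `register_model` change is the `hf_hub_id='timm/'` shorthand: at registration time a bare organization prefix is expanded to the full `timm/<architecture>.<tag>` repo name via `dataclasses.replace`, so each per-tag cfg records its exact hub location. A small sketch with a stand-in dataclass (not timm's actual `PretrainedCfg`):

from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class Cfg:  # stand-in for timm's PretrainedCfg
    hf_hub_id: Optional[str] = None
    architecture: Optional[str] = None
    tag: Optional[str] = None

def finalize(cfg, model_name, tag):
    items = dict(architecture=model_name, tag=tag or None)
    if cfg.hf_hub_id == 'timm/':
        # auto-complete the hub repo name as '<org>/<arch>.<tag>'
        items['hf_hub_id'] = cfg.hf_hub_id + ('.'.join([model_name, tag]) if tag else model_name)
    return replace(cfg, **items)

print(finalize(Cfg(hf_hub_id='timm/'), 'convnext_nano', 'in12k'))
# Cfg(hf_hub_id='timm/convnext_nano.in12k', architecture='convnext_nano', tag='in12k')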

timm/models/convnext.py

Lines changed: 60 additions & 21 deletions
@@ -361,7 +361,6 @@ def _create_convnext(variant, pretrained=False, **kwargs):
     return model


-
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -375,90 +374,130 @@ def _cfg(url='', **kwargs):

 default_cfgs = generate_default_cfgs({
     # timm specific variants
-    'convnext_atto.timm_in1k': _cfg(
+    'convnext_atto.d2_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_atto_ols.timm_in1k': _cfg(
+    'convnext_atto_ols.a2_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_femto.timm_in1k': _cfg(
+    'convnext_femto.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_femto_ols.timm_in1k': _cfg(
+    'convnext_femto_ols.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_pico.timm_in1k': _cfg(
+    'convnext_pico.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_pico_ols.timm_in1k': _cfg(
+    'convnext_pico_ols.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
+        hf_hub_id='timm/',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_nano.in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_nano.timm_in1k': _cfg(
+    'convnext_nano.d1h_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_nano_ols.timm_in1k': _cfg(
+    'convnext_nano_ols.d1h_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_tiny_hnf.timm_in1k': _cfg(
+    'convnext_tiny_hnf.a2h_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),

+    'convnext_nano.in12k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95, num_classes=11821),
+
     'convnext_tiny.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_small.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_base.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_large.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_xlarge.untrained': _cfg(),

     'convnext_tiny.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_small.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_base.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_large.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),

     'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
-    'convnext_small..fb_in22k_ft_in1k_384': _cfg(
+    'convnext_small.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'convnext_base.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'convnext_large.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

-    'convnext_tiny_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
-    'convnext_small_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
-    'convnext_base_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
-    'convnext_large_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
-    'convnext_xlarge_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
+    'convnext_tiny.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_small.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_base.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_large.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_xlarge.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
 })


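With the retagged cfgs, the new ConvNeXt weights are addressed as `<arch>.<tag>` names in `timm.create_model`. A quick usage sketch (assumes a timm build that includes this commit and that the weights are reachable on the HF hub or via the fallback URLs):

import timm
import torch

# ImageNet-12k pretrained, fine-tuned to ImageNet-1k (1000 classes)
model = timm.create_model('convnext_nano.in12k_ft_in1k', pretrained=True).eval()

# or the raw ImageNet-12k classifier (num_classes=11821 per its cfg)
model_12k = timm.create_model('convnext_nano.in12k', pretrained=True).eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])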