Skip to content

Update efficientnet.py and convnext.py to multi-weight, add new 12k pretrained weights #1593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ output/
*.tar
*.pth
*.pt
*.torch
*.gz
Untitled.ipynb
Testing notebook.ipynb

# Root dir exclusions
/*.csv
/*.yaml
/*.json
/*.jpg
/*.png
/*.zip
/*.tar.*
2 changes: 1 addition & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*'
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*'
]
NUM_NON_STD = len(NON_STD_FILTERS)

Expand Down
2 changes: 1 addition & 1 deletion timm/data/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def create_dataset(
elif name.startswith('hfds/'):
# NOTE right now, HF datasets default arrow format is a random-access Dataset,
# There will be a IterableDataset variant too, TBD
ds = ImageDataset(root, reader=name, split=split, **kwargs)
ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
elif name.startswith('tfds/'):
ds = IterableImageDataset(
root,
Expand Down
2 changes: 1 addition & 1 deletion timm/data/readers/reader_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def create_reader(name, root, split='train', **kwargs):
name = name.lower()
name = name.split('/', 2)
name = name.split('/', 1)
prefix = ''
if len(name) > 1:
prefix = name[0]
Expand Down
20 changes: 15 additions & 5 deletions timm/data/readers/reader_hfds.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
except ImportError as e:
print("Please install Hugging Face datasets package `pip install datasets`.")
exit(1)
from .class_map import load_class_map
from .reader import Reader


def get_class_labels(info):
def get_class_labels(info, label_key='label'):
if 'label' not in info.features:
return {}
class_label = info.features['label']
class_label = info.features[label_key]
class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
return class_to_idx

Expand All @@ -32,6 +33,7 @@ def __init__(
name,
split='train',
class_map=None,
label_key='label',
download=False,
):
"""
Expand All @@ -43,12 +45,17 @@ def __init__(
name, # 'name' maps to path arg in hf datasets
split=split,
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
#use_auth_token=True,
)
# leave decode for caller, plus we want easy access to original path names...
self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))

self.class_to_idx = get_class_labels(self.dataset.info)
self.label_key = label_key
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
self.split_info = self.dataset.info.splits[split]
self.num_samples = self.split_info.num_examples

Expand All @@ -60,7 +67,10 @@ def __getitem__(self, index):
else:
assert 'path' in image and image['path']
image = open(image['path'], 'rb')
return image, item['label']
label = item[self.label_key]
if self.remap_class:
label = self.class_to_idx[label]
return image, label

def __len__(self):
return len(self.dataset)
Expand Down
7 changes: 6 additions & 1 deletion timm/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
Expand Down Expand Up @@ -30,8 +31,12 @@
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .patch_embed import PatchEmbed, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \
FourierEmbed, RotaryEmbedding
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
Expand Down
2 changes: 1 addition & 1 deletion timm/layers/attention_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch.nn as nn

from .helpers import to_2tuple
from .pos_embed import apply_rot_embed, RotaryEmbedding
from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
from .weight_init import trunc_normal_


Expand Down
2 changes: 1 addition & 1 deletion timm/layers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return x
return tuple(x)
return tuple(repeat(x, n))
return parse

Expand Down
130 changes: 129 additions & 1 deletion timm/layers/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,24 @@

A convolution based approach to patchifying a 2D image w/ embedding projection.

Based on the impl in https://github.com/google-research/vision_transformer
Based on code in:
* https://github.com/google-research/vision_transformer
* https://github.com/google-research/big_vision/tree/main/big_vision

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from typing import List

import torch
from torch import nn as nn
import torch.nn.functional as F

from .helpers import to_2tuple
from .trace_utils import _assert

_logger = logging.getLogger(__name__)


class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
Expand Down Expand Up @@ -46,3 +55,122 @@ def forward(self, x):
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x


def resample_patch_embed(
patch_embed,
new_size: List[int],
interpolation: str = 'bicubic',
antialias: bool = True,
verbose: bool = False,
):
"""Resample the weights of the patch embedding kernel to target resolution.
We resample the patch embedding kernel by approximately inverting the effect
of patch resizing.

Code based on:
https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py

With this resizing, we can for example load a B/8 filter into a B/16 model
and, on 2x larger input image, the result will match.

Args:
patch_embed: original parameter to be resized.
new_size (tuple(int, int): target shape (height, width)-only.
interpolation (str): interpolation for resize
antialias (bool): use anti-aliasing filter in resize
verbose (bool): log operation
Returns:
Resized patch embedding kernel.
"""
import numpy as np

assert len(patch_embed.shape) == 4, "Four dimensions expected"
assert len(new_size) == 2, "New shape should only be hw"
old_size = patch_embed.shape[-2:]
if tuple(old_size) == tuple(new_size):
return patch_embed

if verbose:
_logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")

def resize(x_np, _new_size):
x_tf = torch.Tensor(x_np)[None, None, ...]
x_upsampled = F.interpolate(
x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
return x_upsampled

def get_resize_mat(_old_size, _new_size):
mat = []
for i in range(np.prod(_old_size)):
basis_vec = np.zeros(_old_size)
basis_vec[np.unravel_index(i, _old_size)] = 1.
mat.append(resize(basis_vec, _new_size).reshape(-1))
return np.stack(mat).T

resize_mat = get_resize_mat(old_size, new_size)
resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))

def resample_kernel(kernel):
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
return resampled_kernel.reshape(new_size)

v_resample_kernel = torch.vmap(torch.vmap(resample_kernel, 0, 0), 1, 1)
return v_resample_kernel(patch_embed)


# def divs(n, m=None):
# m = m or n // 2
# if m == 1:
# return [1]
# if n % m == 0:
# return [m] + divs(n, m - 1)
# return divs(n, m - 1)
#
#
# class FlexiPatchEmbed(nn.Module):
# """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT)
# FIXME WIP
# """
# def __init__(
# self,
# img_size=240,
# patch_size=16,
# in_chans=3,
# embed_dim=768,
# base_img_size=240,
# base_patch_size=32,
# norm_layer=None,
# flatten=True,
# bias=True,
# ):
# super().__init__()
# self.img_size = to_2tuple(img_size)
# self.patch_size = to_2tuple(patch_size)
# self.num_patches = 0
#
# # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48)
# self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30)
#
# self.base_img_size = to_2tuple(base_img_size)
# self.base_patch_size = to_2tuple(base_patch_size)
# self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)])
# self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1]
#
# self.flatten = flatten
# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
# self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
#
# def forward(self, x):
# B, C, H, W = x.shape
#
# if self.patch_size == self.base_patch_size:
# weight = self.proj.weight
# else:
# weight = resample_patch_embed(self.proj.weight, self.patch_size)
# patch_size = self.patch_size
# x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
# if self.flatten:
# x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
# x = self.norm(x)
# return x
Loading