|
23 | 23 |
|
24 | 24 | Hacked together by / Copyright 2020, Ross Wightman
|
25 | 25 | """
|
| 26 | +import copy |
26 | 27 | import logging
|
27 | 28 | import math
|
28 | 29 | from collections import OrderedDict
|
@@ -1601,6 +1602,21 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
|
1601 | 1602 | hf_hub_filename='open_clip_pytorch_model.bin',
|
1602 | 1603 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
|
1603 | 1604 |
|
| 1605 | + 'vit_base_patch32_clip_224.laion400m_e32': _cfg( |
| 1606 | + hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', |
| 1607 | + notes=('natively QuickGELU, use quickgelu model variant for original results',), |
| 1608 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), |
| 1609 | + 'vit_base_patch16_clip_224.laion400m_e32': _cfg( |
| 1610 | + hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', |
| 1611 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), |
| 1612 | + 'vit_base_patch16_plus_clip_240.laion400m_e32': _cfg( |
| 1613 | + hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', |
| 1614 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, |
| 1615 | + input_size=(3, 240, 240), crop_pct=1.0, num_classes=512), |
| 1616 | + 'vit_large_patch14_clip_224.laion400m_e32': _cfg( |
| 1617 | + hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', |
| 1618 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), |
| 1619 | + |
1604 | 1620 | 'vit_base_patch32_clip_224.datacompxl': _cfg(
|
1605 | 1621 | hf_hub_id='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K',
|
1606 | 1622 | hf_hub_filename='open_clip_pytorch_model.bin',
|
@@ -1641,44 +1657,68 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
|
1641 | 1657 | crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
|
1642 | 1658 |
|
1643 | 1659 | 'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg(
|
1644 |
| - hf_hub_id='facebook/metaclip-b32-fullcc2.5b', |
1645 |
| - hf_hub_filename='metaclip_b32_fullcc2.5b.bin', |
| 1660 | + hf_hub_id='timm/', |
| 1661 | + hf_hub_filename='open_clip_pytorch_model.bin', |
1646 | 1662 | license='cc-by-nc-4.0',
|
1647 | 1663 | notes=('natively QuickGELU, use quickgelu model variant for original results',),
|
1648 | 1664 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
|
1649 | 1665 | 'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg(
|
1650 |
| - hf_hub_id='facebook/metaclip-b16-fullcc2.5b', |
1651 |
| - hf_hub_filename='metaclip_b16_fullcc2.5b.bin', |
| 1666 | + hf_hub_id='timm/', |
| 1667 | + hf_hub_filename='open_clip_pytorch_model.bin', |
1652 | 1668 | license='cc-by-nc-4.0',
|
1653 | 1669 | notes=('natively QuickGELU, use quickgelu model variant for original results',),
|
1654 | 1670 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
|
1655 | 1671 | 'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg(
|
1656 |
| - hf_hub_id='facebook/metaclip-l14-fullcc2.5b', |
1657 |
| - hf_hub_filename='metaclip_l14_fullcc2.5b.bin', |
| 1672 | + hf_hub_id='timm/', |
| 1673 | + hf_hub_filename='open_clip_pytorch_model.bin', |
1658 | 1674 | license='cc-by-nc-4.0',
|
1659 | 1675 | notes=('natively QuickGELU, use quickgelu model variant for original results',),
|
1660 | 1676 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
|
1661 | 1677 | 'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg(
|
1662 |
| - hf_hub_id='facebook/metaclip-h14-fullcc2.5b', |
1663 |
| - hf_hub_filename='metaclip_h14_fullcc2.5b.bin', |
| 1678 | + hf_hub_id='timm/', |
| 1679 | + hf_hub_filename='open_clip_pytorch_model.bin', |
1664 | 1680 | license='cc-by-nc-4.0',
|
1665 | 1681 | notes=('natively QuickGELU, use quickgelu model variant for original results',),
|
1666 | 1682 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
|
| 1683 | + 'vit_gigantic_patch14_clip_224.metaclip_2pt5b': _cfg( |
| 1684 | + hf_hub_id='timm/', |
| 1685 | + hf_hub_filename='open_clip_pytorch_model.bin', |
| 1686 | + license='cc-by-nc-4.0', |
| 1687 | + notes=('natively QuickGELU, use quickgelu model variant for original results',), |
| 1688 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280), |
| 1689 | + 'vit_base_patch32_clip_224.metaclip_400m': _cfg( |
| 1690 | + hf_hub_id='timm/', |
| 1691 | + hf_hub_filename='open_clip_pytorch_model.bin', |
| 1692 | + license='cc-by-nc-4.0', |
| 1693 | + notes=('natively QuickGELU, use quickgelu model variant for original results',), |
| 1694 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), |
| 1695 | + 'vit_base_patch16_clip_224.metaclip_400m': _cfg( |
| 1696 | + hf_hub_id='timm/', |
| 1697 | + hf_hub_filename='open_clip_pytorch_model.bin', |
| 1698 | + license='cc-by-nc-4.0', |
| 1699 | + notes=('natively QuickGELU, use quickgelu model variant for original results',), |
| 1700 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), |
| 1701 | + 'vit_large_patch14_clip_224.metaclip_400m': _cfg( |
| 1702 | + hf_hub_id='timm/', |
| 1703 | + hf_hub_filename='open_clip_pytorch_model.bin', |
| 1704 | + license='cc-by-nc-4.0', |
| 1705 | + notes=('natively QuickGELU, use quickgelu model variant for original results',), |
| 1706 | + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), |
1667 | 1707 |
|
1668 | 1708 | 'vit_base_patch32_clip_224.openai': _cfg(
|
1669 |
| - hf_hub_id='timm/vit_base_patch32_clip_224.openai', |
| 1709 | + hf_hub_id='timm/', |
1670 | 1710 | notes=('natively QuickGELU, use quickgelu model variant for original results',),
|
1671 | 1711 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
|
1672 | 1712 | 'vit_base_patch16_clip_224.openai': _cfg(
|
1673 |
| - hf_hub_id='timm/vit_base_patch16_clip_224.openai', |
| 1713 | + hf_hub_id='timm/', |
1674 | 1714 | notes=('natively QuickGELU, use quickgelu model variant for original results',),
|
1675 | 1715 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
|
1676 | 1716 | 'vit_large_patch14_clip_224.openai': _cfg(
|
1677 |
| - hf_hub_id='timm/vit_large_patch14_clip_224.openai', |
| 1717 | + hf_hub_id='timm/', |
1678 | 1718 | notes=('natively QuickGELU, use quickgelu model variant for original results',),
|
1679 | 1719 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
|
1680 | 1720 | 'vit_large_patch14_clip_336.openai': _cfg(
|
1681 |
| - hf_hub_id='timm/vit_large_patch14_clip_336.openai', hf_hub_filename='open_clip_pytorch_model.bin', |
| 1721 | + hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', |
1682 | 1722 | notes=('natively QuickGELU, use quickgelu model variant for original results',),
|
1683 | 1723 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
1684 | 1724 | crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),
|
@@ -2071,22 +2111,13 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
|
2071 | 2111 | input_size=(3, 160, 160), crop_pct=0.95),
|
2072 | 2112 | }
|
2073 | 2113 |
|
2074 |
| -_quick_gelu_cfgs = [ |
2075 |
| - 'vit_large_patch14_clip_224.dfn2b', |
2076 |
| - 'vit_huge_patch14_clip_224.dfn5b', |
2077 |
| - 'vit_huge_patch14_clip_378.dfn5b', |
2078 |
| - 'vit_base_patch32_clip_224.metaclip_2pt5b', |
2079 |
| - 'vit_base_patch16_clip_224.metaclip_2pt5b', |
2080 |
| - 'vit_large_patch14_clip_224.metaclip_2pt5b', |
2081 |
| - 'vit_huge_patch14_clip_224.metaclip_2pt5b', |
2082 |
| - 'vit_base_patch32_clip_224.openai', |
2083 |
| - 'vit_base_patch16_clip_224.openai', |
2084 |
| - 'vit_large_patch14_clip_224.openai', |
2085 |
| - 'vit_large_patch14_clip_336.openai', |
2086 |
| -] |
2087 |
| -default_cfgs.update({ |
2088 |
| - n.replace('_clip_', '_clip_quickgelu_'): default_cfgs[n] for n in _quick_gelu_cfgs |
2089 |
| -}) |
# Auto-generate QuickGELU variants for every default cfg whose `notes` field
# flags the weights as natively QuickGELU. Each generated cfg is registered
# under the same name with '_clip_' swapped for '_clip_quickgelu_', so e.g.
# 'vit_base_patch32_clip_224.openai' also resolves as
# 'vit_base_patch32_clip_quickgelu_224.openai'.
_quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]]
for n in _quick_gelu_cfgs:
    # generate quickgelu default cfgs based on contents of notes field
    c = copy.deepcopy(default_cfgs[n])
    # Use .get() so cfgs without an 'hf_hub_id' key (e.g. url-only cfgs) don't
    # raise KeyError here; they simply keep their cfg unchanged.
    if c.get('hf_hub_id') == 'timm/':
        c['hf_hub_id'] = 'timm/' + n  # need to use non-quickgelu model name for hub id
    default_cfgs[n.replace('_clip_', '_clip_quickgelu_')] = c
default_cfgs = generate_default_cfgs(default_cfgs)
|
2091 | 2122 |
|
2092 | 2123 |
|
@@ -2510,6 +2541,16 @@ def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTrans
|
2510 | 2541 | return model
|
2511 | 2542 |
|
2512 | 2543 |
|
@register_model
def vit_base_patch16_plus_clip_240(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-Base-plus (ViT-B/16+) CLIP image tower @ 240x240.
    """
    args = dict(
        patch_size=16,
        embed_dim=896,
        depth=12,
        num_heads=14,
        pre_norm=True,
        norm_layer=nn.LayerNorm,
    )
    # Caller-supplied kwargs take precedence over the defaults above.
    args.update(kwargs)
    return _create_vision_transformer('vit_base_patch16_plus_clip_240', pretrained=pretrained, **args)
| 2552 | + |
| 2553 | + |
2513 | 2554 | @register_model
|
2514 | 2555 | def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
2515 | 2556 | """ ViT-Large model (ViT-L/14) CLIP image tower
|
@@ -2656,6 +2697,18 @@ def vit_huge_patch14_clip_quickgelu_378(pretrained: bool = False, **kwargs) -> V
|
2656 | 2697 | return model
|
2657 | 2698 |
|
2658 | 2699 |
|
@register_model
def vit_gigantic_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-bigG model (ViT-G/14) CLIP image tower w/ QuickGELU act.
    """
    args = dict(
        patch_size=14,
        embed_dim=1664,
        mlp_ratio=64 / 13,
        depth=48,
        num_heads=16,
        pre_norm=True,
        norm_layer=nn.LayerNorm,
        act_layer='quick_gelu',
    )
    # Caller-supplied kwargs take precedence over the defaults above.
    args.update(kwargs)
    return _create_vision_transformer('vit_gigantic_patch14_clip_quickgelu_224', pretrained=pretrained, **args)
| 2710 | + |
| 2711 | + |
2659 | 2712 | # Experimental models below
|
2660 | 2713 |
|
2661 | 2714 | @register_model
|
|
0 commit comments