
Commit 0676d99

[from_pretrained] Make from_pretrained fast again (#27709)
* Skip nn.Module.reset_parameters
* Actually skip
* Check quality
* Maybe change all inits
* Fix init issues: only modify public functions
* Add a small test for now
* Style
* test updates
* style
* nice tes
* style
* make it even faster
* one more second
* remove fx icompatible
* Update tests/test_modeling_common.py

  Co-authored-by: Lysandre Debut <[email protected]>
* Update tests/test_modeling_common.py

  Co-authored-by: Lysandre Debut <[email protected]>
* skip
* fix quality
* protect the import

---------

Co-authored-by: Lysandre Debut <[email protected]>
1 parent 9f18cc6 commit 0676d99

File tree

2 files changed: +88 -2 lines changed

src/transformers/modeling_utils.py

+34 -1
@@ -154,6 +154,23 @@ def is_local_dist_rank_0():
 if is_peft_available():
     from .utils import find_adapter_config_file
 
+TORCH_INIT_FUNCTIONS = {
+    "uniform_": nn.init.uniform_,
+    "normal_": nn.init.normal_,
+    "trunc_normal_": nn.init.trunc_normal_,
+    "constant_": nn.init.constant_,
+    "xavier_uniform_": nn.init.xavier_uniform_,
+    "xavier_normal_": nn.init.xavier_normal_,
+    "kaiming_uniform_": nn.init.kaiming_uniform_,
+    "kaiming_normal_": nn.init.kaiming_normal_,
+    "uniform": nn.init.uniform,
+    "normal": nn.init.normal,
+    "xavier_uniform": nn.init.xavier_uniform,
+    "xavier_normal": nn.init.xavier_normal,
+    "kaiming_uniform": nn.init.kaiming_uniform,
+    "kaiming_normal": nn.init.kaiming_normal,
+}
+
 
 @contextmanager
 def no_init_weights(_enable=True):
@@ -164,12 +181,24 @@ def no_init_weights(_enable=True):
     """
     global _init_weights
     old_init_weights = _init_weights
+
     if _enable:
         _init_weights = False
+
+        def _skip_init(*args, **kwargs):
+            pass
+
+        # Replace the torch.nn.init functions with a no-op so module-level inits are skipped
+        for name, init_func in TORCH_INIT_FUNCTIONS.items():
+            setattr(torch.nn.init, name, _skip_init)
     try:
         yield
     finally:
         _init_weights = old_init_weights
+        if _enable:
+            # Restore the original initialization functions
+            for name, init_func in TORCH_INIT_FUNCTIONS.items():
+                setattr(torch.nn.init, name, init_func)
 
 
 def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
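The patching trick above is self-contained enough to reproduce outside the library. The following is a minimal standalone sketch of the same technique, assuming only PyTorch; `skip_torch_init` and `SAVED_INIT_FUNCTIONS` are illustrative names, not part of the transformers API, and the real `no_init_weights` additionally toggles the module-level `_init_weights` flag.

from contextlib import contextmanager

import torch
from torch import nn

# Illustrative subset of the init functions to save; transformers'
# TORCH_INIT_FUNCTIONS covers all public torch.nn.init entry points.
SAVED_INIT_FUNCTIONS = {
    "uniform_": nn.init.uniform_,
    "normal_": nn.init.normal_,
    "kaiming_uniform_": nn.init.kaiming_uniform_,
}


@contextmanager
def skip_torch_init():
    def _skip_init(*args, **kwargs):
        pass  # leave tensors with the uninitialized memory torch.empty allocated

    # Patch the module attributes: built-in layers call e.g. init.kaiming_uniform_
    # through torch.nn.init at construction time, so their reset_parameters
    # becomes (mostly) a no-op while the context is active.
    for name in SAVED_INIT_FUNCTIONS:
        setattr(torch.nn.init, name, _skip_init)
    try:
        yield
    finally:
        # Restore the originals so code outside the context initializes normally.
        for name, init_func in SAVED_INIT_FUNCTIONS.items():
            setattr(torch.nn.init, name, init_func)


with skip_torch_init():
    layer = nn.Linear(4096, 4096)  # allocation only; the expensive init is skipped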
@@ -1506,7 +1535,10 @@ def get_output_embeddings(self) -> nn.Module:
 
     def _init_weights(self, module):
         """
-        Initialize the weights. This method should be overridden by derived class.
+        Initialize the weights. This method should be overridden by a derived class and is
+        the only initialization method that will be called when loading a checkpoint
+        using `from_pretrained`. Any attempt to initialize outside of this function
+        will be useless, as the torch.nn.init functions are all replaced with a no-op.
         """
         pass
 
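The expanded docstring changes what model authors should do in practice. Below is a sketch of the intended pattern; `ToyModel` is a hypothetical class invented for illustration, not part of this commit.

import torch
from torch import nn

from transformers import PretrainedConfig, PreTrainedModel


class ToyModel(PreTrainedModel):  # hypothetical model used only for illustration
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(16, 16)
        # Do NOT call nn.init.* here and expect it to stick: under
        # no_init_weights those functions are temporarily no-ops.
        self.post_init()  # standard transformers hook; runs _init_weights when enabled

    def _init_weights(self, module):
        # The one place initialization is guaranteed to run; during
        # from_pretrained it only runs for weights missing from the checkpoint.
        # Tensor methods like normal_() and zero_() are unaffected by the patch.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()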
@@ -3414,6 +3446,7 @@ def from_pretrained(
         )
 
         with ContextManagers(init_contexts):
+            # Let's make sure we don't run the init function of buffer modules
             model = cls(config, *model_args, **model_kwargs)
 
         # make sure we use the model's config since the __init__ call might have copied it
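For context, this is roughly the user-facing effect of the change; the checkpoint name is only an example, and `_fast_init` is the pre-existing private keyword that the new test below exercises to compare against the slow path.

from transformers import AutoModelForCausalLM

# Default (fast) path: __init__-time torch.nn.init calls are skipped and
# _init_weights runs only for weights missing from the checkpoint.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Escape hatch: run the full, slower initialization before loading the state dict.
model_slow = AutoModelForCausalLM.from_pretrained("gpt2", _fast_init=False)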

tests/test_modeling_common.py

+54 -1
@@ -36,8 +36,10 @@
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     PretrainedConfig,
+    PreTrainedModel,
     is_torch_available,
     logging,
+    set_seed,
 )
 from transformers.models.auto import get_values
 from transformers.models.auto.modeling_auto import (
@@ -85,7 +87,7 @@
     is_torch_fx_available,
     is_torch_sdpa_available,
 )
-from transformers.utils.generic import ModelOutput
+from transformers.utils.generic import ContextManagers, ModelOutput
 
 
 if is_accelerate_available():
@@ -99,6 +101,7 @@
     from torch import nn
 
     from transformers import MODEL_MAPPING, AdaptiveEmbedding
+    from transformers.modeling_utils import no_init_weights
     from transformers.pytorch_utils import id_tensor_storage
 

@@ -428,6 +431,56 @@ class CopyClass(model_class):
             max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
             self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
+    def test_fast_init_context_manager(self):
+        # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
+        class MyClass(PreTrainedModel):
+            config_class = PretrainedConfig
+
+            def __init__(self, config=None):
+                super().__init__(config if config is not None else PretrainedConfig())
+                self.linear = nn.Linear(10, 10, bias=True)
+                self.embedding = nn.Embedding(10, 10)
+                self.std = 1
+
+            def _init_weights(self, module):
+                if isinstance(module, nn.Linear):
+                    module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
+                    if module.bias is not None:
+                        module.bias.data.normal_(mean=0.0, std=self.std)
+
+        # 2. Make sure a linear layer's reset_parameters is properly skipped:
+        with ContextManagers([no_init_weights(True)]):
+            no_init_instance = MyClass()
+
+        set_seed(0)
+        expected_bias = torch.tensor(
+            ([0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475])
+        )
+        init_instance = MyClass()
+        torch.testing.assert_allclose(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4)
+
+        set_seed(0)
+        torch.testing.assert_allclose(
+            init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5))
+        )
+
+        # 3. Make sure weights that are not present in the checkpoint use _init_weights and get the expected values
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            state_dict = init_instance.state_dict()
+            del state_dict["linear.weight"]
+
+            init_instance.config.save_pretrained(tmpdirname)
+            torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
+            set_seed(0)
+            model_fast_init = MyClass.from_pretrained(tmpdirname)
+
+            set_seed(0)
+            model_slow_init = MyClass.from_pretrained(tmpdirname, _fast_init=False)
+
+            for key in model_fast_init.state_dict().keys():
+                max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
+                self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
+
     def test_save_load_fast_init_to_base(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if config.__class__ not in MODEL_MAPPING:
