36 | 36 |     AutoModelForCausalLM,
37 | 37 |     AutoModelForSequenceClassification,
38 | 38 |     PretrainedConfig,
| 39 | +    PreTrainedModel,
39 | 40 |     is_torch_available,
40 | 41 |     logging,
| 42 | +    set_seed,
41 | 43 | )
42 | 44 | from transformers.models.auto import get_values
43 | 45 | from transformers.models.auto.modeling_auto import (

85 | 87 |     is_torch_fx_available,
86 | 88 |     is_torch_sdpa_available,
87 | 89 | )
88 | | -from transformers.utils.generic import ModelOutput
| 90 | +from transformers.utils.generic import ContextManagers, ModelOutput
89 | 91 |
90 | 92 |
91 | 93 | if is_accelerate_available():

99 | 101 |     from torch import nn
100 | 102 |
101 | 103 |     from transformers import MODEL_MAPPING, AdaptiveEmbedding
| 104 | +    from transformers.modeling_utils import no_init_weights
102 | 105 |     from transformers.pytorch_utils import id_tensor_storage
103 | 106 |
104 | 107 |
@@ -428,6 +431,56 @@ class CopyClass(model_class):
428 | 431 |                 max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
429 | 432 |                 self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
430 | 433 |
| 434 | +    def test_fast_init_context_manager(self):
| 435 | +        # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
| 436 | +        class MyClass(PreTrainedModel):
| 437 | +            config_class = PretrainedConfig
| 438 | +
| 439 | +            def __init__(self, config=None):
| 440 | +                super().__init__(config if config is not None else PretrainedConfig())
| 441 | +                self.linear = nn.Linear(10, 10, bias=True)
| 442 | +                self.embedding = nn.Embedding(10, 10)
| 443 | +                self.std = 1
| 444 | +
| 445 | +            def _init_weights(self, module):
| 446 | +                if isinstance(module, nn.Linear):
| 447 | +                    module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
| 448 | +                    if module.bias is not None:
| 449 | +                        module.bias.data.normal_(mean=0.0, std=self.std)
| 450 | +
| 451 | +        # 2. Make sure a linear layer's reset params is properly skipped:
| 452 | +        with ContextManagers([no_init_weights(True)]):
| 453 | +            no_init_instance = MyClass()
| 454 | +
| 455 | +        set_seed(0)
| 456 | +        expected_bias = torch.tensor(
| 457 | +            ([0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475])
| 458 | +        )
| 459 | +        init_instance = MyClass()
| 460 | +        torch.testing.assert_allclose(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4)
| 461 | +
| 462 | +        set_seed(0)
| 463 | +        torch.testing.assert_allclose(
| 464 | +            init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5))
| 465 | +        )
| 466 | +
| 467 | +        # 3. Make sure weights that are not present use init_weight_ and get expected values
| 468 | +        with tempfile.TemporaryDirectory() as tmpdirname:
| 469 | +            state_dict = init_instance.state_dict()
| 470 | +            del state_dict["linear.weight"]
| 471 | +
| 472 | +            init_instance.config.save_pretrained(tmpdirname)
| 473 | +            torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
| 474 | +            set_seed(0)
| 475 | +            model_fast_init = MyClass.from_pretrained(tmpdirname)
| 476 | +
| 477 | +            set_seed(0)
| 478 | +            model_slow_init = MyClass.from_pretrained(tmpdirname, _fast_init=False)
| 479 | +
| 480 | +            for key in model_fast_init.state_dict().keys():
| 481 | +                max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
| 482 | +                self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
| 483 | +
431 | 484 |     def test_save_load_fast_init_to_base(self):
432 | 485 |         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
433 | 486 |         if config.__class__ not in MODEL_MAPPING:
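For reference, the pattern exercised by the new test_fast_init_context_manager test can be sketched outside the test suite. The snippet below is an illustration, not part of the change: it assumes, as step 2 of the test does, that while ContextManagers([no_init_weights(True)]) is active the default PyTorch parameter initialization (nn.Linear.reset_parameters) is effectively skipped, so re-running the same seeded init afterwards reproduces a normally constructed layer. no_init_weights, ContextManagers and set_seed are the helpers imported in the diff above; everything else is illustrative.

# Hedged sketch of the behavior the new test checks; not the library's implementation.
import math

import torch
from torch import nn

from transformers import set_seed
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers

# Assumption (matching step 2 of the test): inside the context manager the
# default init that nn.Linear runs in its constructor is a no-op, so the
# weight is only allocated, not initialized.
with ContextManagers([no_init_weights(True)]):
    skipped = nn.Linear(10, 10)

# Outside the context manager the usual Kaiming-uniform init runs.
set_seed(0)
initialized = nn.Linear(10, 10)

# Re-applying the same seeded init to the "skipped" layer reproduces the
# normally initialized weight, mirroring the assertion in the test above.
set_seed(0)
nn.init.kaiming_uniform_(skipped.weight, a=math.sqrt(5))
print(torch.allclose(skipped.weight, initialized.weight))  # expected: True

This is the mechanism from_pretrained leans on for _fast_init: parameters that the checkpoint will overwrite are never randomly initialized, and step 3 of the test then checks that a parameter missing from the checkpoint still ends up identical whether the model is loaded with _fast_init=True or _fast_init=False.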