21
21
from collections import defaultdict
22
22
from contextlib import contextmanager
23
23
from typing import Dict , List , Optional , Tuple , Union
24
- from huggingface_hub import model_info
25
-
26
- import transformers
27
- from transformers import AutoConfig , AutoModel
28
-
29
- import torch
30
- import torch .nn as nn
31
24
32
25
logger = logging .getLogger (__name__ )
33
26
@@ -40,8 +33,54 @@ class CustomDtype(enum.Enum):
40
33
INT2 = "int2"
41
34
42
35
36
def import_torch_nn():
    """Lazily import and return the ``torch.nn`` module.

    Returns:
        The ``torch.nn`` module.

    Raises:
        ImportError: If torch is not installed. The original import error is
            chained (``from err``) so the root cause stays in the traceback;
            ``ImportError`` subclasses ``Exception``, so existing callers
            catching ``Exception`` are unaffected.
    """
    try:
        import torch.nn
        return torch.nn
    except ImportError as err:
        raise ImportError("Unable to import torch.nn, install dependency") from err
43
+
44
+
45
def import_torch():
    """Lazily import and return the ``torch`` module.

    Returns:
        The ``torch`` module.

    Raises:
        ImportError: If torch is not installed. Chained to the original
            error so the underlying cause is preserved; ``ImportError`` is a
            subclass of ``Exception``, so callers catching ``Exception``
            still work.
    """
    try:
        import torch
        return torch
    except ImportError as err:
        raise ImportError("Unable to import torch, install dependency") from err
52
+
53
+
54
def import_Auto_Config():
    """Lazily import and return ``transformers.AutoConfig``.

    Returns:
        The ``transformers.AutoConfig`` class.

    Raises:
        ImportError: If transformers is not installed. Chained to the
            original error; ``ImportError`` subclasses ``Exception``, so
            existing ``except Exception`` handlers are unaffected.
    """
    try:
        from transformers import AutoConfig
        return AutoConfig
    except ImportError as err:
        raise ImportError(
            "Unable to import transformers.AutoConfig, install Transformers dependency"
        ) from err
61
+
62
+
63
def import_Auto_Model():
    """Lazily import and return ``transformers.AutoModel``.

    Returns:
        The ``transformers.AutoModel`` class.

    Raises:
        ImportError: If transformers is not installed. Chained to the
            original error; ``ImportError`` subclasses ``Exception``, so
            existing ``except Exception`` handlers are unaffected.
    """
    try:
        from transformers import AutoModel
        return AutoModel
    except ImportError as err:
        raise ImportError(
            "Unable to import transformers.AutoModel, install Transformers dependency"
        ) from err
70
+
71
+
72
def import_model_info():
    """Lazily import and return ``huggingface_hub.model_info``.

    Returns:
        The ``model_info`` function from ``huggingface_hub``.

    Raises:
        ImportError: If huggingface_hub is not installed. Chained to the
            original error; ``ImportError`` subclasses ``Exception``, so
            existing ``except Exception`` handlers are unaffected.
    """
    try:
        from huggingface_hub import model_info

        return model_info
    except ImportError as err:
        raise ImportError(
            "Unable to import model_info, check if huggingface_hub is installed"
        ) from err
80
+
81
+
43
82
def get_max_layer_size (
44
- modules : List [Tuple [str , torch . nn .Module ]],
83
+ modules : List [Tuple [str , import_torch_nn () .Module ]],
45
84
module_sizes : Dict [str , int ],
46
85
no_split_module_classes : List [str ],
47
86
):
@@ -73,12 +112,9 @@ def get_max_layer_size(
73
112
while len (modules_to_treat ) > 0 :
74
113
module_name , module = modules_to_treat .pop (0 )
75
114
modules_children = (
76
- list (module .named_children ()) if isinstance (module , torch . nn .Module ) else []
115
+ list (module .named_children ()) if isinstance (module , import_torch_nn () .Module ) else []
77
116
)
78
- if (
79
- len (modules_children ) == 0
80
- or module .__class__ .__name__ in no_split_module_classes
81
- ):
117
+ if len (modules_children ) == 0 or module .__class__ .__name__ in no_split_module_classes :
82
118
# No splitting this one so we compare to the max_size
83
119
size = module_sizes [module_name ]
84
120
if size > max_size :
@@ -93,16 +129,16 @@ def get_max_layer_size(
93
129
return max_size , layer_names
94
130
95
131
96
def _get_proper_dtype(dtype: Union[str, "torch.dtype"]) -> "torch.dtype":
    """Just does torch.dtype(dtype) if necessary.

    Args:
        dtype: Either an actual dtype object (returned unchanged) or a
            string such as ``"torch.float16"`` or ``"float16"``.

    Returns:
        The corresponding ``torch.dtype``.

    Note:
        The annotations are string forward references on purpose: calling
        ``import_torch()`` inside the signature would trigger the torch
        import at module-load time, defeating the lazy-import helpers this
        file relies on.
    """
    if isinstance(dtype, str):
        # We accept "torch.float16" or just "float16"
        dtype = dtype.replace("torch.", "")
        dtype = getattr(import_torch(), dtype)
    return dtype
103
139
104
140
105
- def dtype_byte_size (dtype : torch .dtype ):
141
+ def dtype_byte_size (dtype : import_torch () .dtype ):
106
142
"""Returns the size (in bytes) occupied by one parameter of type `dtype`.
107
143
108
144
Example:
@@ -111,7 +147,7 @@ def dtype_byte_size(dtype: torch.dtype):
111
147
4
112
148
```
113
149
"""
114
- if dtype == torch .bool : # pylint: disable=R1705
150
+ if dtype == import_torch () .bool : # pylint: disable=R1705
115
151
return 1 / 8
116
152
elif dtype == CustomDtype .INT2 :
117
153
return 1 / 4
@@ -127,7 +163,7 @@ def dtype_byte_size(dtype: torch.dtype):
127
163
128
164
129
165
def named_module_tensors (
130
- module : nn .Module ,
166
+ module : import_torch_nn () .Module ,
131
167
include_buffers : bool = True ,
132
168
recurse : bool = False ,
133
169
remove_non_persistent : bool = False ,
@@ -162,7 +198,7 @@ def named_module_tensors(
162
198
yield named_buffer
163
199
164
200
165
- def get_non_persistent_buffers (module : nn .Module , recurse : bool = False ):
201
+ def get_non_persistent_buffers (module : import_torch_nn () .Module , recurse : bool = False ):
166
202
"""Gather all non persistent buffers of a given modules into a set
167
203
168
204
Args:
@@ -182,21 +218,17 @@ def get_non_persistent_buffers(module: nn.Module, recurse: bool = False):
182
218
183
219
184
220
def compute_module_sizes (
185
- model : nn .Module ,
186
- dtype : Optional [Union [str , torch .device ]] = None ,
187
- special_dtypes : Optional [Dict [str , Union [str , torch .device ]]] = None ,
221
+ model : import_torch_nn () .Module ,
222
+ dtype : Optional [Union [str , import_torch () .device ]] = None ,
223
+ special_dtypes : Optional [Dict [str , Union [str , import_torch () .device ]]] = None ,
188
224
):
189
225
"""Compute the size of each submodule of a given model."""
190
226
if dtype is not None :
191
227
dtype = _get_proper_dtype (dtype )
192
228
dtype_size = dtype_byte_size (dtype )
193
229
if special_dtypes is not None :
194
- special_dtypes = {
195
- key : _get_proper_dtype (dtyp ) for key , dtyp in special_dtypes .items ()
196
- }
197
- special_dtypes_size = {
198
- key : dtype_byte_size (dtyp ) for key , dtyp in special_dtypes .items ()
199
- }
230
+ special_dtypes = {key : _get_proper_dtype (dtyp ) for key , dtyp in special_dtypes .items ()}
231
+ special_dtypes_size = {key : dtype_byte_size (dtyp ) for key , dtyp in special_dtypes .items ()}
200
232
module_sizes = defaultdict (int )
201
233
for name , tensor in named_module_tensors (model , recurse = True ):
202
234
if special_dtypes is not None and name in special_dtypes :
@@ -216,7 +248,7 @@ def compute_module_sizes(
216
248
return module_sizes
217
249
218
250
219
- def calculate_maximum_sizes (model : torch . nn .Module ):
251
+ def calculate_maximum_sizes (model : import_torch_nn () .Module ):
220
252
"""Computes the total size of the model and its largest layer"""
221
253
sizes = compute_module_sizes (model )
222
254
# `transformers` models store this information for us
@@ -246,6 +278,7 @@ def convert_bytes(size):
246
278
247
279
def verify_on_hub (repo : str , token : str = None ):
248
280
"""Verifies that the model is on the hub and returns the model info."""
281
+ model_info = import_model_info ()
249
282
try :
250
283
return model_info (repo , token = token )
251
284
except ValueError :
@@ -279,7 +312,7 @@ def create_empty_model(
279
312
`torch.nn.Module`: The torch model that has been initialized on the `meta` device.
280
313
281
314
"""
282
- model_info = verify_on_hub (model_name , access_token ) # pylint: disable=W0621
315
+ model_info = verify_on_hub (model_name , access_token ) # pylint: disable=W0621
283
316
# Simplified errors
284
317
if model_info == "gated" : # pylint: disable=R1720
285
318
raise ValueError (
@@ -309,21 +342,25 @@ def create_empty_model(
309
342
)
310
343
311
344
auto_map = model_info .config .get ("auto_map" , False )
312
- config = AutoConfig .from_pretrained (
345
+ config = import_Auto_Config () .from_pretrained (
313
346
model_name , trust_remote_code = trust_remote_code , token = access_token
314
347
)
315
348
316
349
with init_empty_weights ():
317
350
# remote code could specify a specific `AutoModel` class in the `auto_map`
318
- constructor = AutoModel
351
+ constructor = import_Auto_Model ()
319
352
if isinstance (auto_map , dict ):
320
353
value = None
321
354
for key in auto_map .keys ():
322
355
if key .startswith ("AutoModelFor" ):
323
356
value = key
324
357
break
325
358
if value is not None :
326
- constructor = getattr (transformers , value )
359
+ try :
360
+ import transformers
361
+ constructor = getattr (transformers , value )
362
+ except ImportError :
363
+ raise Exception ("Unable to import transformers, install dependency" )
327
364
model = constructor .from_config (config , trust_remote_code = trust_remote_code )
328
365
else :
329
366
raise ValueError (
@@ -343,17 +380,16 @@ def init_empty_weights(include_buffers: bool = None):
343
380
"""
344
381
if include_buffers is None :
345
382
include_buffers = parse_flag_from_env ("ACCELERATE_INIT_INCLUDE_BUFFERS" , False )
346
- with init_on_device (torch .device ("meta" ), # pylint: disable=E1129
347
- include_buffers = include_buffers ) as f :
383
+ with init_on_device ( # pylint: disable=E1129
384
+ import_torch ().device ("meta" ), include_buffers = include_buffers
385
+ ) as f :
348
386
yield f
349
387
350
388
351
389
def parse_flag_from_env(key, default=False):
    """Returns truthy value for `key` from the env if available else the default."""
    raw = os.environ.get(key, str(default))
    # As its name indicates, `str_to_bool` actually returns an int (0 or 1),
    # hence the explicit comparison to produce a bool.
    return str_to_bool(raw) == 1
357
393
358
394
359
395
def str_to_bool (value ) -> int :
@@ -371,14 +407,15 @@ def str_to_bool(value) -> int:
371
407
raise ValueError (f"invalid truth value { value } " )
372
408
373
409
374
- def init_on_device (device : torch .device , include_buffers : bool = None ):
410
+ @contextmanager
411
+ def init_on_device (device : import_torch ().device , include_buffers : bool = None ):
375
412
"""A context manager under which models are initialized"""
376
413
if include_buffers is None :
377
414
include_buffers = parse_flag_from_env ("ACCELERATE_INIT_INCLUDE_BUFFERS" , False )
378
415
379
- old_register_parameter = nn .Module .register_parameter
416
+ old_register_parameter = import_torch_nn () .Module .register_parameter
380
417
if include_buffers :
381
- old_register_buffer = nn .Module .register_buffer
418
+ old_register_buffer = import_torch_nn () .Module .register_buffer
382
419
383
420
def register_empty_parameter (module , name , param ):
384
421
"""Doctype: register_empty_parameter"""
@@ -387,9 +424,7 @@ def register_empty_parameter(module, name, param):
387
424
param_cls = type (module ._parameters [name ])
388
425
kwargs = module ._parameters [name ].__dict__
389
426
kwargs ["requires_grad" ] = param .requires_grad
390
- module ._parameters [name ] = param_cls (
391
- module ._parameters [name ].to (device ), ** kwargs
392
- )
427
+ module ._parameters [name ] = param_cls (module ._parameters [name ].to (device ), ** kwargs )
393
428
394
429
def register_empty_buffer (module , name , buffer , persistent = True ):
395
430
"""Doctype: register_empty_buffer"""
@@ -400,40 +435,41 @@ def register_empty_buffer(module, name, buffer, persistent=True):
400
435
# Patch tensor creation
401
436
if include_buffers :
402
437
tensor_constructors_to_patch = {
403
- torch_function_name : getattr (torch , torch_function_name )
438
+ torch_function_name : getattr (import_torch () , torch_function_name )
404
439
for torch_function_name in ["empty" , "zeros" , "ones" , "full" ]
405
440
}
406
441
else :
407
442
tensor_constructors_to_patch = {}
408
443
409
444
def patch_tensor_constructor (fn ):
410
445
"""Doctype: patch_tensor_constructor"""
446
+
411
447
def wrapper (* args , ** kwargs ):
412
448
kwargs ["device" ] = device
413
449
return fn (* args , ** kwargs )
414
450
415
451
return wrapper
416
452
417
453
try :
418
- nn .Module .register_parameter = register_empty_parameter
454
+ import_torch_nn () .Module .register_parameter = register_empty_parameter
419
455
if include_buffers :
420
- nn .Module .register_buffer = register_empty_buffer
456
+ import_torch_nn () .Module .register_buffer = register_empty_buffer
421
457
for torch_function_name in tensor_constructors_to_patch .keys ():
422
458
setattr (
423
- torch ,
459
+ import_torch () ,
424
460
torch_function_name ,
425
- patch_tensor_constructor (getattr (torch , torch_function_name )),
461
+ patch_tensor_constructor (getattr (import_torch () , torch_function_name )),
426
462
)
427
463
yield
428
464
finally :
429
- nn .Module .register_parameter = old_register_parameter
465
+ import_torch_nn () .Module .register_parameter = old_register_parameter
430
466
if include_buffers :
431
- nn .Module .register_buffer = old_register_buffer
467
+ import_torch_nn () .Module .register_buffer = old_register_buffer
432
468
for (
433
469
torch_function_name ,
434
470
old_torch_function ,
435
471
) in tensor_constructors_to_patch .items ():
436
- setattr (torch , torch_function_name , old_torch_function )
472
+ setattr (import_torch () , torch_function_name , old_torch_function )
437
473
438
474
439
475
def create_ascii_table (headers : list , rows : list , title : str ):
@@ -451,8 +487,10 @@ def create_ascii_table(headers: list, rows: list, title: str):
451
487
diff = 0
452
488
453
489
def make_row (left_char , middle_char , right_char ):
454
- return f"{ left_char } { middle_char .join ([in_between * n for n in column_widths ])} " \
455
- f"{ in_between * diff } { right_char } "
490
+ return (
491
+ f"{ left_char } { middle_char .join ([in_between * n for n in column_widths ])} "
492
+ f"{ in_between * diff } { right_char } "
493
+ )
456
494
457
495
separator = make_row ("├" , "┼" , "┤" )
458
496
if len (title ) > sum (column_widths ):
@@ -487,9 +525,7 @@ def estimate_command_parser(subparsers=None):
487
525
description = "Model size estimator for fitting a model onto CUDA memory."
488
526
)
489
527
490
- parser .add_argument (
491
- "model_name" , type = str , help = "The model name on the Hugging Face Hub."
492
- )
528
+ parser .add_argument ("model_name" , type = str , help = "The model name on the Hugging Face Hub." )
493
529
parser .add_argument (
494
530
"--library_name" ,
495
531
type = str ,
@@ -501,9 +537,9 @@ def estimate_command_parser(subparsers=None):
501
537
"--dtypes" ,
502
538
type = str ,
503
539
nargs = "+" ,
504
- default = [' float32' , ' float16' , ' int8' , ' int4' ],
540
+ default = [" float32" , " float16" , " int8" , " int4" ],
505
541
help = "The dtypes to use for the model, must be one (or many) of "
506
- "`float32`, `float16`, `int8`, and `int4`" ,
542
+ "`float32`, `float16`, `int8`, and `int4`" ,
507
543
choices = ["float32" , "float16" , "int8" , "int4" ],
508
544
)
509
545
parser .add_argument (
0 commit comments