Commit c3e3c52: "Fix formatting"
1 parent: ac65024
1 file changed: src/sagemaker/serve/utils/estimate_parser.py (+95, -59 lines)
```diff
@@ -21,13 +21,6 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from typing import Dict, List, Optional, Tuple, Union
-from huggingface_hub import model_info
-
-import transformers
-from transformers import AutoConfig, AutoModel
-
-import torch
-import torch.nn as nn
 
 logger = logging.getLogger(__name__)
 
@@ -40,8 +33,54 @@ class CustomDtype(enum.Enum):
     INT2 = "int2"
 
 
+def import_torch_nn():
+    """Import torch.nn"""
+    try:
+        import torch.nn
+        return torch.nn
+    except ImportError:
+        raise Exception("Unable to import torch.nn, install dependency")
+
+
+def import_torch():
+    """Import torch"""
+    try:
+        import torch
+        return torch
+    except ImportError:
+        raise Exception("Unable to import torch, install dependency")
+
+
+def import_Auto_Config():
+    """Import transformers"""
+    try:
+        from transformers import AutoConfig
+        return AutoConfig
+    except ImportError:
+        raise Exception("Unable to import transformers.AutoConfig, install Transformers dependency")
+
+
+def import_Auto_Model():
+    """Import transformers"""
+    try:
+        from transformers import AutoModel
+        return AutoModel
+    except ImportError:
+        raise Exception("Unable to import transformers.AutoModel, install Transformers dependency")
+
+
+def import_model_info():
+    """Import model info from huggingface_hub"""
+    try:
+        from huggingface_hub import model_info
+
+        return model_info
+    except ImportError:
+        raise Exception("Unable to import model_info, check if huggingface_hub is installed")
+
+
 def get_max_layer_size(
-    modules: List[Tuple[str, torch.nn.Module]],
+    modules: List[Tuple[str, import_torch_nn().Module]],
     module_sizes: Dict[str, int],
     no_split_module_classes: List[str],
 ):
```
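The helpers above defer the heavy imports to call time, so the module can be loaded without torch, transformers, or huggingface_hub installed; the dependency error only surfaces on a code path that actually needs it. One caveat worth knowing: when such a helper is called inside a type annotation, Python still evaluates it at function-definition time unless `from __future__ import annotations` is in effect. A minimal sketch of the pattern; `tensor_nbytes` and its arguments are illustrative, not part of the commit:

```python
def import_torch():
    """Resolve torch lazily so importing this module never requires it."""
    try:
        import torch
        return torch
    except ImportError:
        raise Exception("Unable to import torch, install dependency")


def tensor_nbytes(shape, dtype_name="float32"):
    """Bytes needed by one tensor; torch is only imported on first call."""
    torch = import_torch()
    # "meta" tensors carry shape/dtype without allocating storage.
    t = torch.empty(shape, dtype=getattr(torch, dtype_name), device="meta")
    return t.element_size() * t.nelement()


print(tensor_nbytes((1024, 1024), "float16"))  # 2097152
```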
````diff
@@ -73,12 +112,9 @@ def get_max_layer_size(
     while len(modules_to_treat) > 0:
         module_name, module = modules_to_treat.pop(0)
         modules_children = (
-            list(module.named_children()) if isinstance(module, torch.nn.Module) else []
+            list(module.named_children()) if isinstance(module, import_torch_nn().Module) else []
         )
-        if (
-            len(modules_children) == 0
-            or module.__class__.__name__ in no_split_module_classes
-        ):
+        if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
             # No splitting this one so we compare to the max_size
             size = module_sizes[module_name]
             if size > max_size:
@@ -93,16 +129,16 @@ def get_max_layer_size(
     return max_size, layer_names
 
 
-def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype:
+def _get_proper_dtype(dtype: Union[str, import_torch().device]) -> import_torch().dtype:
     """Just does torch.dtype(dtype) if necessary."""
     if isinstance(dtype, str):
         # We accept "torch.float16" or just "float16"
         dtype = dtype.replace("torch.", "")
-        dtype = getattr(torch, dtype)
+        dtype = getattr(import_torch(), dtype)
     return dtype
 
 
-def dtype_byte_size(dtype: torch.dtype):
+def dtype_byte_size(dtype: import_torch().dtype):
     """Returns the size (in bytes) occupied by one parameter of type `dtype`.
 
     Example:
@@ -111,7 +147,7 @@ def dtype_byte_size(dtype: torch.dtype):
     4
     ```
     """
-    if dtype == torch.bool:  # pylint: disable=R1705
+    if dtype == import_torch().bool:  # pylint: disable=R1705
         return 1 / 8
     elif dtype == CustomDtype.INT2:
         return 1 / 4
````
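The diff only shows the edges of `dtype_byte_size`; the elided middle presumably parses the bit width out of the dtype name, as the common implementation of this helper does. A standalone sketch under that assumption (it may differ in detail from the file's actual body):

```python
import re
import torch

def bytes_per_param(dtype_name: str) -> float:
    """Sketch of dtype_byte_size: read the bit width off the dtype name."""
    dtype = getattr(torch, dtype_name.replace("torch.", ""))
    if dtype == torch.bool:
        return 1 / 8  # bools pack to one bit each
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))  # "torch.float16" -> "16"
    return int(bit_search.groups()[0]) / 8

print(bytes_per_param("float16"))        # 2.0
print(bytes_per_param("torch.float32"))  # 4.0
```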
```diff
@@ -127,7 +163,7 @@ def dtype_byte_size(dtype: torch.dtype):
 
 
 def named_module_tensors(
-    module: nn.Module,
+    module: import_torch_nn().Module,
     include_buffers: bool = True,
     recurse: bool = False,
     remove_non_persistent: bool = False,
@@ -162,7 +198,7 @@ def named_module_tensors(
                 yield named_buffer
 
 
-def get_non_persistent_buffers(module: nn.Module, recurse: bool = False):
+def get_non_persistent_buffers(module: import_torch_nn().Module, recurse: bool = False):
     """Gather all non persistent buffers of a given modules into a set
 
     Args:
@@ -182,21 +218,17 @@ def get_non_persistent_buffers(module: nn.Module, recurse: bool = False):
 
 
 def compute_module_sizes(
-    model: nn.Module,
-    dtype: Optional[Union[str, torch.device]] = None,
-    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
+    model: import_torch_nn().Module,
+    dtype: Optional[Union[str, import_torch().device]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, import_torch().device]]] = None,
 ):
     """Compute the size of each submodule of a given model."""
     if dtype is not None:
         dtype = _get_proper_dtype(dtype)
         dtype_size = dtype_byte_size(dtype)
     if special_dtypes is not None:
-        special_dtypes = {
-            key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()
-        }
-        special_dtypes_size = {
-            key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()
-        }
+        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
+        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
     module_sizes = defaultdict(int)
     for name, tensor in named_module_tensors(model, recurse=True):
         if special_dtypes is not None and name in special_dtypes:
@@ -216,7 +248,7 @@ def compute_module_sizes(
     return module_sizes
 
 
-def calculate_maximum_sizes(model: torch.nn.Module):
+def calculate_maximum_sizes(model: import_torch_nn().Module):
     """Computes the total size of the model and its largest layer"""
     sizes = compute_module_sizes(model)
     # `transformers` models store this information for us
```
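A toy illustration of what `compute_module_sizes` produces: each parameter's byte count is credited to its own module and to every parent prefix, so the empty key holds the whole model. This stand-in mirrors the helpers above rather than importing them, and assumes torch is installed:

```python
from collections import defaultdict
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))

module_sizes = defaultdict(int)
for name, param in model.named_parameters():
    size = param.numel() * param.element_size()
    parts = name.split(".")
    # Credit the size to "", "0", "0.weight", etc.
    for i in range(len(parts) + 1):
        module_sizes[".".join(parts[:i])] += size

print(module_sizes[""])   # whole model: (200 + 20 + 100 + 5) * 4 = 1300 bytes
print(module_sizes["0"])  # first Linear: (200 + 20) * 4 = 880 bytes
```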
```diff
@@ -246,6 +278,7 @@ def convert_bytes(size):
 
 def verify_on_hub(repo: str, token: str = None):
     """Verifies that the model is on the hub and returns the model info."""
+    model_info = import_model_info()
     try:
         return model_info(repo, token=token)
     except ValueError:
@@ -279,7 +312,7 @@ def create_empty_model(
         `torch.nn.Module`: The torch model that has been initialized on the `meta` device.
 
     """
-    model_info = verify_on_hub(model_name, access_token) # pylint: disable=W0621
+    model_info = verify_on_hub(model_name, access_token)  # pylint: disable=W0621
     # Simplified errors
     if model_info == "gated":  # pylint: disable=R1720
         raise ValueError(
@@ -309,21 +342,25 @@ def create_empty_model(
         )
 
     auto_map = model_info.config.get("auto_map", False)
-    config = AutoConfig.from_pretrained(
+    config = import_Auto_Config().from_pretrained(
         model_name, trust_remote_code=trust_remote_code, token=access_token
     )
 
     with init_empty_weights():
         # remote code could specify a specific `AutoModel` class in the `auto_map`
-        constructor = AutoModel
+        constructor = import_Auto_Model()
         if isinstance(auto_map, dict):
             value = None
             for key in auto_map.keys():
                 if key.startswith("AutoModelFor"):
                     value = key
                     break
             if value is not None:
-                constructor = getattr(transformers, value)
+                try:
+                    import transformers
+                    constructor = getattr(transformers, value)
+                except ImportError:
+                    raise Exception("Unable to import transformers, install dependency")
         model = constructor.from_config(config, trust_remote_code=trust_remote_code)
     else:
         raise ValueError(
```
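The `auto_map` branch above upgrades the generic `AutoModel` constructor to the first `AutoModelFor*` class a remote-code repo declares. A sketch of that resolution; the `auto_map` contents here are invented for illustration:

```python
import transformers
from transformers import AutoModel

# Hypothetical auto_map, as a remote-code model config might declare it.
auto_map = {
    "AutoConfig": "configuration_demo.DemoConfig",
    "AutoModelForCausalLM": "modeling_demo.DemoForCausalLM",
}

constructor = AutoModel
value = next((k for k in auto_map if k.startswith("AutoModelFor")), None)
if value is not None:
    # Resolve the class by name, e.g. transformers.AutoModelForCausalLM
    constructor = getattr(transformers, value)

print(constructor.__name__)  # AutoModelForCausalLM
```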
```diff
@@ -343,17 +380,16 @@ def init_empty_weights(include_buffers: bool = None):
     """
     if include_buffers is None:
         include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
-    with init_on_device(torch.device("meta"),  # pylint: disable=E1129
-                        include_buffers=include_buffers) as f:
+    with init_on_device(  # pylint: disable=E1129
+        import_torch().device("meta"), include_buffers=include_buffers
+    ) as f:
         yield f
 
 
 def parse_flag_from_env(key, default=False):
     """Returns truthy value for `key` from the env if available else the default."""
     value = os.environ.get(key, str(default))
-    return (
-        str_to_bool(value) == 1
-    )  # As its name indicates `str_to_bool` actually returns an int...
+    return str_to_bool(value) == 1  # As its name indicates `str_to_bool` actually returns an int...
 
 
 def str_to_bool(value) -> int:
@@ -371,14 +407,15 @@ def str_to_bool(value) -> int:
     raise ValueError(f"invalid truth value {value}")
 
 
-def init_on_device(device: torch.device, include_buffers: bool = None):
+@contextmanager
+def init_on_device(device: import_torch().device, include_buffers: bool = None):
     """A context manager under which models are initialized"""
     if include_buffers is None:
         include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
 
-    old_register_parameter = nn.Module.register_parameter
+    old_register_parameter = import_torch_nn().Module.register_parameter
     if include_buffers:
-        old_register_buffer = nn.Module.register_buffer
+        old_register_buffer = import_torch_nn().Module.register_buffer
 
     def register_empty_parameter(module, name, param):
         """Doctype: register_empty_parameter"""
@@ -387,9 +424,7 @@ def register_empty_parameter(module, name, param):
         param_cls = type(module._parameters[name])
         kwargs = module._parameters[name].__dict__
         kwargs["requires_grad"] = param.requires_grad
-        module._parameters[name] = param_cls(
-            module._parameters[name].to(device), **kwargs
-        )
+        module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
 
     def register_empty_buffer(module, name, buffer, persistent=True):
         """Doctype: register_empty_buffer"""
```
```diff
@@ -400,40 +435,41 @@ def register_empty_buffer(module, name, buffer, persistent=True):
     # Patch tensor creation
     if include_buffers:
         tensor_constructors_to_patch = {
-            torch_function_name: getattr(torch, torch_function_name)
+            torch_function_name: getattr(import_torch(), torch_function_name)
             for torch_function_name in ["empty", "zeros", "ones", "full"]
         }
     else:
         tensor_constructors_to_patch = {}
 
     def patch_tensor_constructor(fn):
         """Doctype: patch_tensor_constructor"""
+
         def wrapper(*args, **kwargs):
             kwargs["device"] = device
             return fn(*args, **kwargs)
 
         return wrapper
 
     try:
-        nn.Module.register_parameter = register_empty_parameter
+        import_torch_nn().Module.register_parameter = register_empty_parameter
         if include_buffers:
-            nn.Module.register_buffer = register_empty_buffer
+            import_torch_nn().Module.register_buffer = register_empty_buffer
         for torch_function_name in tensor_constructors_to_patch.keys():
             setattr(
-                torch,
+                import_torch(),
                 torch_function_name,
-                patch_tensor_constructor(getattr(torch, torch_function_name)),
+                patch_tensor_constructor(getattr(import_torch(), torch_function_name)),
             )
         yield
     finally:
-        nn.Module.register_parameter = old_register_parameter
+        import_torch_nn().Module.register_parameter = old_register_parameter
         if include_buffers:
-            nn.Module.register_buffer = old_register_buffer
+            import_torch_nn().Module.register_buffer = old_register_buffer
         for (
             torch_function_name,
             old_torch_function,
         ) in tensor_constructors_to_patch.items():
-            setattr(torch, torch_function_name, old_torch_function)
+            setattr(import_torch(), torch_function_name, old_torch_function)
 
 
 def create_ascii_table(headers: list, rows: list, title: str):
```
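Taken together, `init_on_device` swaps out parameter registration (and optionally buffer registration and the tensor constructors) so everything lands on the target device, then restores the originals in `finally`. A condensed, self-contained sketch of the same monkey-patching idea, assuming torch is installed:

```python
from contextlib import contextmanager
import torch
import torch.nn as nn

@contextmanager
def on_meta():
    """Route nn.Parameter registration to the meta device, then restore."""
    old_register_parameter = nn.Module.register_parameter

    def register_meta_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            # Re-register the parameter on meta, dropping the CPU storage.
            module._parameters[name] = nn.Parameter(
                module._parameters[name].to("meta"),
                requires_grad=param.requires_grad,
            )

    try:
        nn.Module.register_parameter = register_meta_parameter
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter

with on_meta():
    layer = nn.Linear(4096, 4096)  # weights hold shape/dtype only

print(layer.weight.device)  # meta
```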
```diff
@@ -451,8 +487,10 @@ def create_ascii_table(headers: list, rows: list, title: str):
     diff = 0
 
     def make_row(left_char, middle_char, right_char):
-        return f"{left_char}{middle_char.join([in_between * n for n in column_widths])}" \
-               f"{in_between * diff}{right_char}"
+        return (
+            f"{left_char}{middle_char.join([in_between * n for n in column_widths])}"
+            f"{in_between * diff}{right_char}"
+        )
 
     separator = make_row("├", "┼", "┤")
     if len(title) > sum(column_widths):
```
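For reference, the reflowed `make_row` joins one horizontal rule of the ASCII table from per-column widths. A standalone sketch with invented widths; `in_between` is not visible in the diff, so the box-drawing dash here is an assumption based on the `├`/`┼`/`┤` separators:

```python
column_widths = [8, 6, 10]  # invented for illustration
in_between = "─"            # assumed fill character
diff = 0

def make_row(left_char, middle_char, right_char):
    return (
        f"{left_char}{middle_char.join([in_between * n for n in column_widths])}"
        f"{in_between * diff}{right_char}"
    )

print(make_row("├", "┼", "┤"))  # ├────────┼──────┼──────────┤
```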
```diff
@@ -487,9 +525,7 @@ def estimate_command_parser(subparsers=None):
         description="Model size estimator for fitting a model onto CUDA memory."
     )
 
-    parser.add_argument(
-        "model_name", type=str, help="The model name on the Hugging Face Hub."
-    )
+    parser.add_argument("model_name", type=str, help="The model name on the Hugging Face Hub.")
     parser.add_argument(
         "--library_name",
         type=str,
@@ -501,9 +537,9 @@ def estimate_command_parser(subparsers=None):
         "--dtypes",
         type=str,
         nargs="+",
-        default=['float32', 'float16', 'int8', 'int4'],
+        default=["float32", "float16", "int8", "int4"],
         help="The dtypes to use for the model, must be one (or many) of "
-             "`float32`, `float16`, `int8`, and `int4`",
+        "`float32`, `float16`, `int8`, and `int4`",
         choices=["float32", "float16", "int8", "int4"],
     )
     parser.add_argument(
```
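Finally, a self-contained mirror of the parser options visible in this hunk, showing what the normalized `--dtypes` default accepts; the model id passed in is just an example:

```python
import argparse

# Mirror of the options shown in the diff, for illustration only.
parser = argparse.ArgumentParser(
    description="Model size estimator for fitting a model onto CUDA memory."
)
parser.add_argument("model_name", type=str, help="The model name on the Hugging Face Hub.")
parser.add_argument(
    "--dtypes",
    type=str,
    nargs="+",
    default=["float32", "float16", "int8", "int4"],
    choices=["float32", "float16", "int8", "int4"],
)

args = parser.parse_args(["bert-base-uncased", "--dtypes", "float16", "int8"])
print(args.model_name)  # bert-base-uncased
print(args.dtypes)      # ['float16', 'int8']
```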
