13
13
"""This module contains functions for obtaining JumpStart resoure requirements."""
14
14
from __future__ import absolute_import
15
15
16
- from typing import Optional
16
+ from typing import Dict , Optional , Tuple
17
17
18
18
from sagemaker .jumpstart .constants import (
19
19
DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
28
28
from sagemaker .session import Session
29
29
from sagemaker .compute_resource_requirements .resource_requirements import ResourceRequirements
30
30
31
+ REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP : Dict [
32
+ str , Dict [str , Tuple [str , str ]]
33
+ ] = {
34
+ "requests" : {
35
+ "num_accelerators" : ("num_accelerators" , "num_accelerators" ),
36
+ "num_cpus" : ("num_cpus" , "num_cpus" ),
37
+ "copies" : ("copies" , "copy_count" ),
38
+ "min_memory_mb" : ("memory" , "min_memory" ),
39
+ },
40
+ "limits" : {
41
+ "max_memory_mb" : ("memory" , "max_memory" ),
42
+ },
43
+ }
44
+
31
45
32
46
def _retrieve_default_resources (
33
47
model_id : str ,
@@ -37,6 +51,7 @@ def _retrieve_default_resources(
37
51
tolerate_vulnerable_model : bool = False ,
38
52
tolerate_deprecated_model : bool = False ,
39
53
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
54
+ instance_type : Optional [str ] = None ,
40
55
) -> ResourceRequirements :
41
56
"""Retrieves the default resource requirements for the model.
42
57
@@ -60,6 +75,8 @@ def _retrieve_default_resources(
60
75
object, used for SageMaker interactions. If not
61
76
specified, one is created using the default AWS configuration
62
77
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
78
+ instance_type (str): An instance type to optionally supply in order to get
79
+ host requirements specific for the instance type.
63
80
Returns:
64
81
str: The default resource requirements to use for the model or None.
65
82
@@ -87,23 +104,44 @@ def _retrieve_default_resources(
87
104
is_dynamic_container_deployment_supported = (
88
105
model_specs .dynamic_container_deployment_supported
89
106
)
90
- default_resource_requirements = model_specs .hosting_resource_requirements
107
+ default_resource_requirements : Dict [str , int ] = (
108
+ model_specs .hosting_resource_requirements or {}
109
+ )
91
110
else :
92
111
raise NotImplementedError (
93
112
f"Unsupported script scope for retrieving default resource requirements: '{ scope } '"
94
113
)
95
114
115
+ instance_specific_resource_requirements : Dict [str , int ] = (
116
+ model_specs .hosting_instance_type_variants .get_instance_specific_resource_requirements (
117
+ instance_type
118
+ )
119
+ if instance_type
120
+ and getattr (model_specs , "hosting_instance_type_variants" , None ) is not None
121
+ else {}
122
+ )
123
+
124
+ default_resource_requirements = {
125
+ ** default_resource_requirements ,
126
+ ** instance_specific_resource_requirements ,
127
+ }
128
+
96
129
if is_dynamic_container_deployment_supported :
97
- requests = {}
98
- if "num_accelerators" in default_resource_requirements :
99
- requests ["num_accelerators" ] = default_resource_requirements ["num_accelerators" ]
100
- if "min_memory_mb" in default_resource_requirements :
101
- requests ["memory" ] = default_resource_requirements ["min_memory_mb" ]
102
- if "num_cpus" in default_resource_requirements :
103
- requests ["num_cpus" ] = default_resource_requirements ["num_cpus" ]
104
-
105
- limits = {}
106
- if "max_memory_mb" in default_resource_requirements :
107
- limits ["memory" ] = default_resource_requirements ["max_memory_mb" ]
108
- return ResourceRequirements (requests = requests , limits = limits )
130
+
131
+ all_resource_requirement_kwargs = {}
132
+
133
+ for (
134
+ requirement_type ,
135
+ spec_field_to_resource_requirement_map ,
136
+ ) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP .items ():
137
+ requirement_kwargs = {}
138
+ for spec_field , resource_requirement in spec_field_to_resource_requirement_map .items ():
139
+ if spec_field in default_resource_requirements :
140
+ requirement_kwargs [resource_requirement [0 ]] = default_resource_requirements [
141
+ spec_field
142
+ ]
143
+
144
+ all_resource_requirement_kwargs [requirement_type ] = requirement_kwargs
145
+
146
+ return ResourceRequirements (** all_resource_requirement_kwargs )
109
147
return None
0 commit comments