28
28
verify_model_region_and_return_specs ,
29
29
)
30
30
from sagemaker .session import Session
31
+ from sagemaker .jumpstart .types import JumpStartModelSpecs
32
+
33
+
34
def _retrieve_hosting_prepacked_artifact_key(
    model_specs: JumpStartModelSpecs, instance_type: Optional[str]
) -> str:
    """Returns instance specific hosting prepacked artifact key or default one as fallback.

    Args:
        model_specs (JumpStartModelSpecs): Specs object that is read for
            ``hosting_instance_type_variants`` and ``hosting_prepacked_artifact_key``.
        instance_type (Optional[str]): Instance type used to look up an instance
            specific key. When falsy (``None`` or ``""``), the variants lookup is skipped.

    Returns:
        str: The instance specific prepacked artifact key when the variants lookup
            yields a truthy value, otherwise ``model_specs.hosting_prepacked_artifact_key``.
    """
    instance_specific_key: Optional[str] = None
    if instance_type:
        # Specs without variant metadata simply fall through to the default key.
        variants = getattr(model_specs, "hosting_instance_type_variants", None)
        if variants is not None:
            instance_specific_key = variants.get_instance_specific_prepacked_artifact_key(
                instance_type=instance_type
            )

    # Use an explicit ``None`` default so a spec lacking the attribute is treated as
    # "no prepacked key" instead of raising AttributeError, matching the guarded
    # lookup of the variants attribute above.
    # NOTE(review): the result is ``None`` when neither key is available even though
    # the signature says ``str`` - confirm callers only use this for prepacked models.
    default_key: Optional[str] = getattr(model_specs, "hosting_prepacked_artifact_key", None)

    return instance_specific_key or default_key
54
+
55
+
56
def _retrieve_hosting_artifact_key(
    model_specs: JumpStartModelSpecs, instance_type: Optional[str]
) -> str:
    """Returns instance specific hosting artifact key or default one as fallback.

    Args:
        model_specs (JumpStartModelSpecs): Specs object that is read for
            ``hosting_instance_type_variants`` and ``hosting_artifact_key``.
        instance_type (Optional[str]): Instance type used to look up an instance
            specific key. When falsy (``None`` or ``""``), the variants lookup is skipped.

    Returns:
        str: The instance specific artifact key when the variants lookup yields a
            truthy value, otherwise ``model_specs.hosting_artifact_key``.
    """
    instance_specific_key: Optional[str] = None
    if instance_type:
        # Specs without variant metadata simply fall through to the default key.
        variants = getattr(model_specs, "hosting_instance_type_variants", None)
        if variants is not None:
            instance_specific_key = variants.get_instance_specific_artifact_key(
                instance_type=instance_type
            )

    # Read unconditionally (as before) so a spec missing the default key still fails loudly.
    default_key: str = model_specs.hosting_artifact_key

    return instance_specific_key or default_key
70
+
71
+
72
def _retrieve_training_artifact_key(
    model_specs: JumpStartModelSpecs, instance_type: Optional[str]
) -> str:
    """Returns instance specific training artifact key or default one as fallback.

    Args:
        model_specs (JumpStartModelSpecs): Specs object that is read for
            ``training_instance_type_variants`` and ``training_artifact_key``.
        instance_type (Optional[str]): Instance type used to look up an instance
            specific key. When falsy (``None`` or ``""``), the variants lookup is skipped.

    Returns:
        str: The instance specific artifact key when the variants lookup yields a
            truthy value, otherwise ``model_specs.training_artifact_key``.
    """
    instance_specific_key: Optional[str] = None
    if instance_type:
        # Specs without variant metadata simply fall through to the default key.
        variants = getattr(model_specs, "training_instance_type_variants", None)
        if variants is not None:
            instance_specific_key = variants.get_instance_specific_artifact_key(
                instance_type=instance_type
            )

    # Read unconditionally (as before) so a spec missing the default key still fails loudly.
    default_key: str = model_specs.training_artifact_key

    return instance_specific_key or default_key
31
86
32
87
33
88
def _retrieve_model_uri (
34
89
model_id : str ,
35
90
model_version : str ,
36
91
model_scope : Optional [str ] = None ,
92
+ instance_type : Optional [str ] = None ,
37
93
region : Optional [str ] = None ,
38
94
tolerate_vulnerable_model : bool = False ,
39
95
tolerate_deprecated_model : bool = False ,
@@ -50,6 +106,7 @@ def _retrieve_model_uri(
50
106
artifact S3 URI.
51
107
model_scope (str): The model type, i.e. what it is used for.
52
108
Valid values: "training" and "inference".
109
+ instance_type (str): The ML compute instance type for the specified scope. (Default: None).
53
110
region (str): Region for which to retrieve model S3 URI. (Default: None).
54
111
tolerate_vulnerable_model (bool): True if vulnerable versions of model
55
112
specifications should be tolerated (exception not raised). If False, raises an
@@ -84,14 +141,21 @@ def _retrieve_model_uri(
84
141
sagemaker_session = sagemaker_session ,
85
142
)
86
143
144
+ model_artifact_key : str
145
+
87
146
if model_scope == JumpStartScriptScope .INFERENCE :
147
+
148
+ is_prepacked = not model_specs .use_inference_script_uri ()
149
+
88
150
model_artifact_key = (
89
- getattr (model_specs , "hosting_prepacked_artifact_key" , None )
90
- or model_specs .hosting_artifact_key
151
+ _retrieve_hosting_prepacked_artifact_key (model_specs , instance_type )
152
+ if is_prepacked
153
+ else _retrieve_hosting_artifact_key (model_specs , instance_type )
91
154
)
92
155
93
156
elif model_scope == JumpStartScriptScope .TRAINING :
94
- model_artifact_key = model_specs .training_artifact_key
157
+
158
+ model_artifact_key = _retrieve_training_artifact_key (model_specs , instance_type )
95
159
96
160
bucket = os .environ .get (
97
161
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE
0 commit comments