Skip to content

Commit 1f54684

Browse files
rahven14 authored and JoseJuan98 committed
feat: infer framework and version (aws#3247)
1 parent fb5344c commit 1f54684

File tree

19 files changed

+700
-43
lines changed

19 files changed

+700
-43
lines changed

src/sagemaker/chainer/model.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,108 @@ def __init__(
140140

141141
self.model_server_workers = model_server_workers
142142

143+
def register(
    self,
    content_types,
    response_types,
    inference_instances,
    transform_instances,
    model_package_name=None,
    model_package_group_name=None,
    image_uri=None,
    model_metrics=None,
    metadata_properties=None,
    marketplace_cert=False,
    approval_status=None,
    description=None,
    drift_check_baselines=None,
    customer_metadata_properties=None,
    domain=None,
    sample_payload_url=None,
    task=None,
    framework=None,
    framework_version=None,
    nearest_model_name=None,
    data_input_configuration=None,
):
    """Creates a model package for creating SageMaker models or listing on Marketplace.

    When ``framework`` or ``framework_version`` are not supplied, they are inferred
    from this model (``self._framework_name`` upper-cased and
    ``self.framework_version`` respectively) before delegating to the parent
    class' ``register``.

    Args:
        content_types (list): The supported MIME types for the input data.
        response_types (list): The supported MIME types for the output data.
        inference_instances (list): A list of the instance types that are used to
            generate inferences in real-time. The first entry, if any, is used to
            initialize the session and to resolve the serving image.
        transform_instances (list): A list of the instance types on which a transformation
            job can be run or on which an endpoint can be deployed.
        model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
            using `model_package_name` makes the Model Package un-versioned (default: None).
        model_package_group_name (str): Model Package Group name, exclusive to
            `model_package_name`, using `model_package_group_name` makes the Model Package
            versioned (default: None).
        image_uri (str): Inference image uri for the container. If given, it also
            overwrites ``self.image_uri``; if neither is set, the serving image is
            resolved from the session region and instance type (default: None).
        model_metrics (ModelMetrics): ModelMetrics object (default: None).
        metadata_properties (MetadataProperties): MetadataProperties (default: None).
        marketplace_cert (bool): A boolean value indicating if the Model Package is certified
            for AWS Marketplace (default: False).
        approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
            or "PendingManualApproval" (default: None, which is presumably treated as
            "PendingManualApproval" downstream -- confirm against the parent class).
        description (str): Model Package description (default: None).
        drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
        customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
            metadata properties (default: None).
        domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
            "MACHINE_LEARNING" (default: None).
        sample_payload_url (str): The S3 path where the sample payload is stored
            (default: None).
        task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
            "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
            "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
        framework (str): Machine learning framework of the model package container image.
            Falls back to ``self._framework_name``; always upper-cased (default: None).
        framework_version (str): Framework version of the Model Package Container Image.
            Falls back to ``self.framework_version`` (default: None).
        nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
            Amazon SageMaker Inference Recommender (default: None).
        data_input_configuration (str): Input object for the model (default: None).

    Returns:
        str: A string of SageMaker Model Package ARN.
    """
    # Tolerate a missing or empty instance list instead of failing with
    # TypeError/IndexError before any meaningful validation can happen.
    instance_type = inference_instances[0] if inference_instances else None
    self._init_sagemaker_session_if_does_not_exist(instance_type)

    if image_uri:
        self.image_uri = image_uri
    if not self.image_uri:
        # No explicit image anywhere: resolve the framework serving image.
        self.image_uri = self.serving_image_uri(
            region_name=self.sagemaker_session.boto_session.region_name,
            instance_type=instance_type,
        )
    return super(ChainerModel, self).register(
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name,
        model_package_group_name,
        image_uri,
        model_metrics,
        metadata_properties,
        marketplace_cert,
        approval_status,
        description,
        drift_check_baselines=drift_check_baselines,
        customer_metadata_properties=customer_metadata_properties,
        domain=domain,
        sample_payload_url=sample_payload_url,
        task=task,
        # Infer framework details from the model when the caller omits them.
        framework=(framework or self._framework_name).upper(),
        framework_version=framework_version or self.framework_version,
        nearest_model_name=nearest_model_name,
        data_input_configuration=data_input_configuration,
    )
)
244+
143245
def prepare_container_def(
144246
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
145247
):

src/sagemaker/huggingface/model.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ def _validate_pt_tf_versions(pytorch_version, tensorflow_version, image_uri):
8585
)
8686

8787

88+
def fetch_framework_and_framework_version(tensorflow_version, pytorch_version):
    """Determine which framework and version a HuggingFace model is built on.

    Args:
        tensorflow_version: TensorFlow version, or None when TensorFlow is not used.
        pytorch_version: PyTorch version, returned when ``tensorflow_version`` is None.

    Returns:
        tuple: ``("tensorflow", tensorflow_version)`` when ``tensorflow_version`` is
        not None, otherwise ``("pytorch", pytorch_version)``.
    """
    # TensorFlow takes precedence whenever its version is set, even if a
    # PyTorch version was supplied as well.
    if tensorflow_version is None:
        return ("pytorch", pytorch_version)
    return ("tensorflow", tensorflow_version)
94+
95+
8896
class HuggingFaceModel(FrameworkModel):
8997
"""A Hugging Face SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
9098

@@ -387,8 +395,16 @@ def register(
387395
domain=domain,
388396
sample_payload_url=sample_payload_url,
389397
task=task,
390-
framework=framework,
391-
framework_version=framework_version,
398+
framework=(
399+
framework
400+
or fetch_framework_and_framework_version(
401+
self.tensorflow_version, self.pytorch_version
402+
)[0]
403+
).upper(),
404+
framework_version=framework_version
405+
or fetch_framework_and_framework_version(self.tensorflow_version, self.pytorch_version)[
406+
1
407+
],
392408
nearest_model_name=nearest_model_name,
393409
data_input_configuration=data_input_configuration,
394410
)

src/sagemaker/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,12 +374,12 @@ def register(
374374

375375
if model_package_group_name is not None:
376376
container_def = self.prepare_container_def()
377-
update_container_with_inference_params(
377+
container_def = update_container_with_inference_params(
378378
framework=framework,
379379
framework_version=framework_version,
380380
nearest_model_name=nearest_model_name,
381381
data_input_configuration=data_input_configuration,
382-
container_obj=container_def,
382+
container_def=container_def,
383383
)
384384
else:
385385
container_def = {

src/sagemaker/mxnet/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ def register(
238238
domain=domain,
239239
sample_payload_url=sample_payload_url,
240240
task=task,
241-
framework=framework,
242-
framework_version=framework_version,
241+
framework=(framework or self._framework_name).upper(),
242+
framework_version=framework_version or self.framework_version,
243243
nearest_model_name=nearest_model_name,
244244
data_input_configuration=data_input_configuration,
245245
)

src/sagemaker/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,12 +340,12 @@ def register(
340340
container_def = self.pipeline_container_def(
341341
inference_instances[0] if inference_instances else None
342342
)
343-
update_container_with_inference_params(
343+
container_def = update_container_with_inference_params(
344344
framework=framework,
345345
framework_version=framework_version,
346346
nearest_model_name=nearest_model_name,
347347
data_input_configuration=data_input_configuration,
348-
container_list=container_def,
348+
container_def=container_def,
349349
)
350350
else:
351351
container_def = [

src/sagemaker/pytorch/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def register(
239239
domain=domain,
240240
sample_payload_url=sample_payload_url,
241241
task=task,
242-
framework=framework,
243-
framework_version=framework_version,
242+
framework=(framework or self._framework_name).upper(),
243+
framework_version=framework_version or self.framework_version,
244244
nearest_model_name=nearest_model_name,
245245
data_input_configuration=data_input_configuration,
246246
)

src/sagemaker/sklearn/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,8 @@ def register(
233233
domain=domain,
234234
sample_payload_url=sample_payload_url,
235235
task=task,
236-
framework=framework,
237-
framework_version=framework_version,
236+
framework=(framework or self._framework_name).upper(),
237+
framework_version=framework_version or self.framework_version,
238238
nearest_model_name=nearest_model_name,
239239
data_input_configuration=data_input_configuration,
240240
)

src/sagemaker/tensorflow/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,8 @@ def register(
285285
domain=domain,
286286
sample_payload_url=sample_payload_url,
287287
task=task,
288-
framework=framework,
289-
framework_version=framework_version,
288+
framework=(framework or self._framework_name).upper(),
289+
framework_version=framework_version or self.framework_version,
290290
nearest_model_name=nearest_model_name,
291291
data_input_configuration=data_input_configuration,
292292
)

src/sagemaker/utils.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ def update_container_with_inference_params(
739739
framework_version=None,
740740
nearest_model_name=None,
741741
data_input_configuration=None,
742-
container_obj=None,
742+
container_def=None,
743743
container_list=None,
744744
):
745745
"""Function to check if inference recommender parameters exist and update container.
@@ -752,28 +752,30 @@ def update_container_with_inference_params(
752752
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
753753
Amazon SageMaker Inference Recommender (default: None).
754754
data_input_configuration (str): Input object for the model (default: None).
755-
container_obj (dict): object to be updated.
755+
container_def (dict): object to be updated.
756756
container_list (list): list to be updated.
757757
758758
Returns:
759759
dict: dict with inference recommender params
760760
"""
761761

762-
if framework is not None and framework_version is not None and nearest_model_name is not None:
763-
if container_list is not None:
764-
for obj in container_list:
765-
construct_container_object(
766-
obj, data_input_configuration, framework, framework_version, nearest_model_name
767-
)
768-
if container_obj is not None:
762+
if container_list is not None:
763+
for obj in container_list:
769764
construct_container_object(
770-
container_obj,
771-
data_input_configuration,
772-
framework,
773-
framework_version,
774-
nearest_model_name,
765+
obj, data_input_configuration, framework, framework_version, nearest_model_name
775766
)
776767

768+
if container_def is not None:
769+
construct_container_object(
770+
container_def,
771+
data_input_configuration,
772+
framework,
773+
framework_version,
774+
nearest_model_name,
775+
)
776+
777+
return container_list or container_def
778+
777779

778780
def construct_container_object(
779781
obj, data_input_configuration, framework, framework_version, nearest_model_name
@@ -788,20 +790,32 @@ def construct_container_object(
788790
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
789791
Amazon SageMaker Inference Recommender (default: None).
790792
data_input_configuration (str): Input object for the model (default: None).
791-
container_obj (dict): object to be updated.
792-
container_list (list): list to be updated.
793+
obj (dict): object to be updated.
793794
794795
Returns:
795796
dict: container object
796797
"""
797798

798-
obj.update(
799-
{
800-
"Framework": framework,
801-
"FrameworkVersion": framework_version,
802-
"NearestModelName": nearest_model_name,
803-
}
804-
)
799+
if framework is not None:
800+
obj.update(
801+
{
802+
"Framework": framework,
803+
}
804+
)
805+
806+
if framework_version is not None:
807+
obj.update(
808+
{
809+
"FrameworkVersion": framework_version,
810+
}
811+
)
812+
813+
if nearest_model_name is not None:
814+
obj.update(
815+
{
816+
"NearestModelName": nearest_model_name,
817+
}
818+
)
805819

806820
if data_input_configuration is not None:
807821
obj.update(
@@ -811,3 +825,5 @@ def construct_container_object(
811825
},
812826
}
813827
)
828+
829+
return obj

src/sagemaker/workflow/step_collections.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def __init__(
250250
)
251251
]
252252

253-
update_container_with_inference_params(
253+
self.container_def_list = update_container_with_inference_params(
254254
framework=framework,
255255
framework_version=framework_version,
256256
nearest_model_name=nearest_model_name,

0 commit comments

Comments
 (0)