Skip to content

Commit c857aca

Browse files
committed
fix: refactored util function
1 parent a538bef commit c857aca

File tree

4 files changed

+60
-70
lines changed

4 files changed

+60
-70
lines changed

src/sagemaker/model.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@
3737
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
3838
from sagemaker.utils import (
3939
unique_name_from_base,
40-
inference_recommender_params_exist,
41-
update_container_object,
40+
update_container_with_inference_params,
4241
)
4342
from sagemaker.async_inference import AsyncInferenceConfig
4443
from sagemaker.predictor_async import AsyncPredictor
@@ -374,14 +373,13 @@ def register(
374373

375374
if model_package_group_name is not None:
376375
container_def = self.prepare_container_def()
377-
if inference_recommender_params_exist(
378-
framework, framework_version, nearest_model_name, data_input_configuration
379-
):
380-
container_def.update(
381-
update_container_object(
382-
framework, framework_version, nearest_model_name, data_input_configuration
383-
)
384-
)
376+
update_container_with_inference_params(
377+
framework=framework,
378+
framework_version=framework_version,
379+
nearest_model_name=nearest_model_name,
380+
data_input_configuration=data_input_configuration,
381+
container_obj=container_def,
382+
)
385383
else:
386384
container_def = {
387385
"Image": self.image_uri,

src/sagemaker/pipeline.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
from sagemaker.session import Session
2323
from sagemaker.utils import (
2424
name_from_image,
25-
inference_recommender_params_exist,
26-
update_container_object,
25+
update_container_with_inference_params,
2726
)
2827
from sagemaker.transformer import Transformer
2928
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
@@ -341,18 +340,13 @@ def register(
341340
container_def = self.pipeline_container_def(
342341
inference_instances[0] if inference_instances else None
343342
)
344-
if inference_recommender_params_exist(
345-
framework, framework_version, nearest_model_name, data_input_configuration
346-
):
347-
for container_obj in container_def:
348-
container_obj.update(
349-
update_container_object(
350-
framework,
351-
framework_version,
352-
nearest_model_name,
353-
data_input_configuration,
354-
)
355-
)
343+
update_container_with_inference_params(
344+
framework=framework,
345+
framework_version=framework_version,
346+
nearest_model_name=nearest_model_name,
347+
data_input_configuration=data_input_configuration,
348+
container_list=container_def,
349+
)
356350
else:
357351
container_def = [
358352
{

src/sagemaker/utils.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -724,10 +724,15 @@ def get_data_bucket(self, region_requested=None):
724724
get_ecr_image_uri_prefix = deprecations.removed_function("get_ecr_image_uri_prefix")
725725

726726

727-
def inference_recommender_params_exist(
728-
framework=None, framework_version=None, nearest_model_name=None, data_input_configuration=None
727+
def update_container_with_inference_params(
728+
framework=None,
729+
framework_version=None,
730+
nearest_model_name=None,
731+
data_input_configuration=None,
732+
container_obj=None,
733+
container_list=None,
729734
):
730-
"""Function to check if inference recommender parameters exist.
735+
"""Function to check if inference recommender parameters exist and update container.
731736
732737
Args:
733738
framework (str): Machine learning framework of the model package container image
@@ -737,42 +742,39 @@ def inference_recommender_params_exist(
737742
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
738743
Amazon SageMaker Inference Recommender (default: None).
739744
data_input_configuration (str): Input object for the model (default: None).
745+
container_obj (dict): object to be updated.
746+
container_list (list): list to be updated.
740747
741748
Returns:
742-
bool: all required fields exist or not
749+
dict: dict with inference recommender params
743750
"""
751+
744752
if (
745753
framework is not None
746754
and framework_version is not None
747755
and nearest_model_name is not None
748756
and data_input_configuration is not None
749757
):
750-
return True
751-
return False
752-
753-
754-
def update_container_object(
755-
framework=None, framework_version=None, nearest_model_name=None, data_input_configuration=None
756-
):
757-
"""Update the container_def object with inference recommedender parameters.
758-
759-
Args:
760-
framework (str): Machine learning framework of the model package container image
761-
(default: None).
762-
framework_version (str): Framework version of the Model Package Container Image
763-
(default: None).
764-
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
765-
Amazon SageMaker Inference Recommender (default: None).
766-
data_input_configuration (str): Input object for the model (default: None).
767-
768-
Returns:
769-
dict: inference recommender key, value pairs which updates the object.
770-
"""
771-
return {
772-
"Framework": framework,
773-
"FrameworkVersion": framework_version,
774-
"NearestModelName": nearest_model_name,
775-
"ModelInput": {
776-
"DataInputConfig": data_input_configuration,
777-
},
778-
}
758+
if container_list is not None:
759+
for obj in container_list:
760+
obj.update(
761+
{
762+
"Framework": framework,
763+
"FrameworkVersion": framework_version,
764+
"NearestModelName": nearest_model_name,
765+
"ModelInput": {
766+
"DataInputConfig": data_input_configuration,
767+
},
768+
}
769+
)
770+
if container_obj is not None:
771+
container_obj.update(
772+
{
773+
"Framework": framework,
774+
"FrameworkVersion": framework_version,
775+
"NearestModelName": nearest_model_name,
776+
"ModelInput": {
777+
"DataInputConfig": data_input_configuration,
778+
},
779+
}
780+
)

src/sagemaker/workflow/step_collections.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from sagemaker.workflow.steps import Step, CreateModelStep, TransformStep
2828
from sagemaker.workflow._utils import _RegisterModelStep, _RepackModelStep
2929
from sagemaker.workflow.retry import RetryPolicy
30-
from sagemaker.utils import inference_recommender_params_exist, update_container_object
30+
from sagemaker.utils import update_container_with_inference_params
3131

3232

3333
@attr.s
@@ -246,18 +246,14 @@ def __init__(
246246
inference_instances[0] if inference_instances else None
247247
)
248248
]
249-
if inference_recommender_params_exist(
250-
framework, framework_version, nearest_model_name, data_input_configuration
251-
):
252-
for container_obj in self.container_def_list:
253-
container_obj.update(
254-
update_container_object(
255-
framework,
256-
framework_version,
257-
nearest_model_name,
258-
data_input_configuration,
259-
)
260-
)
249+
250+
update_container_with_inference_params(
251+
framework=framework,
252+
framework_version=framework_version,
253+
nearest_model_name=nearest_model_name,
254+
data_input_configuration=data_input_configuration,
255+
container_list=self.container_def_list,
256+
)
261257

262258
register_model_step = _RegisterModelStep(
263259
name=name,

0 commit comments

Comments
 (0)