Skip to content

fix: make instance type fields as optional #3135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
8 changes: 4 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,8 +1266,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
image_uri=None,
model_package_name=None,
model_package_group_name=None,
Expand All @@ -1288,9 +1288,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
image_uri (str): The container image uri for Model Package, if not specified,
Estimator's training container image will be used (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
Expand Down
10 changes: 5 additions & 5 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -311,9 +311,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned.
Defaults to ``None``.
Expand All @@ -335,7 +335,7 @@ def register(
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
instance_type = inference_instances[0]
instance_type = inference_instances[0] if inference_instances else None
self._init_sagemaker_session_if_does_not_exist(instance_type)

if image_uri:
Expand Down
15 changes: 7 additions & 8 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -311,14 +311,14 @@ def register(
validation_specification=None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Args:
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand Down Expand Up @@ -348,12 +348,11 @@ def register(
container_def = self.prepare_container_def()
else:
container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data}

model_pkg_args = sagemaker.get_model_package_args(
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=inference_instances,
transform_instances=transform_instances,
model_package_name=model_package_name,
model_package_group_name=model_package_group_name,
model_metrics=model_metrics,
Expand Down
10 changes: 5 additions & 5 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -165,9 +165,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand All @@ -189,7 +189,7 @@ def register(
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
instance_type = inference_instances[0]
instance_type = inference_instances[0] if inference_instances else None
self._init_sagemaker_session_if_does_not_exist(instance_type)

if image_uri:
Expand Down
14 changes: 7 additions & 7 deletions src/sagemaker/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ def register(
self,
content_types: list,
response_types: list,
inference_instances: list,
transform_instances: list,
inference_instances: Optional[list] = None,
transform_instances: Optional[list] = None,
model_package_name: Optional[str] = None,
model_package_group_name: Optional[str] = None,
image_uri: Optional[str] = None,
Expand All @@ -285,9 +285,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand All @@ -313,7 +313,7 @@ def register(
if model.model_data is None:
raise ValueError("SageMaker Model Package cannot be created without model data.")
if model_package_group_name is not None:
container_def = self.pipeline_container_def(inference_instances[0])
container_def = self.pipeline_container_def(inference_instances[0] if inference_instances else None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also update the instance_type argument in pipeline_container_def and make its default value is None?
https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/pipeline.py#L87

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

else:
container_def = [
{"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data}
Expand All @@ -323,8 +323,8 @@ def register(
model_pkg_args = sagemaker.get_model_package_args(
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=inference_instances,
transform_instances=transform_instances,
model_package_name=model_package_name,
model_package_group_name=model_package_group_name,
model_metrics=model_metrics,
Expand Down
10 changes: 5 additions & 5 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -166,9 +166,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand All @@ -190,7 +190,7 @@ def register(
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
instance_type = inference_instances[0]
instance_type = inference_instances[0] if inference_instances else None
self._init_sagemaker_session_if_does_not_exist(instance_type)

if image_uri:
Expand Down
12 changes: 6 additions & 6 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4202,8 +4202,8 @@ def _intercept_create_request(
def get_model_package_args(
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
model_data=None,
Expand All @@ -4225,9 +4225,9 @@ def get_model_package_args(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand Down Expand Up @@ -4363,9 +4363,9 @@ def get_create_model_package_request(
if validation_specification:
request_dict["ValidationSpecification"] = validation_specification
if containers is not None:
if not all([content_types, response_types, inference_instances, transform_instances]):
if not all([content_types, response_types]):
raise ValueError(
"content_types, response_types, inference_inferences and transform_instances "
"content_types and response_types "
"must be provided if containers is present."
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This thrown error was added by sreedes@. Did you check with her to see if she agree on this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed the same with Saumitra and got approval for this change

inference_specification = {
Expand Down
10 changes: 5 additions & 5 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -158,9 +158,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand All @@ -179,7 +179,7 @@ def register(
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
instance_type = inference_instances[0]
instance_type = inference_instances[0] if inference_instances else None
self._init_sagemaker_session_if_does_not_exist(instance_type)

if image_uri:
Expand Down
10 changes: 5 additions & 5 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -212,9 +212,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand All @@ -237,7 +237,7 @@ def register(
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
instance_type = inference_instances[0]
instance_type = inference_instances[0] if inference_instances else None
self._init_sagemaker_session_if_does_not_exist(instance_type)

if image_uri:
Expand Down
8 changes: 4 additions & 4 deletions src/sagemaker/workflow/step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def __init__(
name: str,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
estimator: EstimatorBase = None,
model_data=None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
Expand Down Expand Up @@ -217,9 +217,9 @@ def __init__(
kwargs.pop("output_kms_key", None)

if isinstance(model, PipelineModel):
self.container_def_list = model.pipeline_container_def(inference_instances[0])
self.container_def_list = model.pipeline_container_def(inference_instances[0] if inference_instances else None)
elif isinstance(model, Model):
self.container_def_list = [model.prepare_container_def(inference_instances[0])]
self.container_def_list = [model.prepare_container_def(inference_instances[0] if inference_instances else None)]

register_model_step = _RegisterModelStep(
name=name,
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/sagemaker/workflow/test_pipeline_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,39 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock):
assert not register_step_args.create_model_request
assert register_step_args.create_model_package_request
assert len(register_step_args.need_runtime_repack) == 0


def test_pipeline_session_context_for_model_step_without_instance_types(pipeline_session_mock):
model = Model(
name="MyModel",
image_uri="fakeimage",
model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"),
sagemaker_session=pipeline_session_mock,
entry_point=f"{DATA_DIR}/dummy_script.py",
source_dir=f"{DATA_DIR}",
role=_ROLE,
)
# CreateModelStep requires runtime repack
create_step_args = model.create(
instance_type="c4.4xlarge",
accelerator_type="ml.eia1.medium",
)
# The context should be cleaned up before return
assert pipeline_session_mock.context is None
assert create_step_args.create_model_request
assert not create_step_args.create_model_package_request
assert len(create_step_args.need_runtime_repack) == 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: seems that you only update the model.register() and this test should focus on your changes - "without instance_types". Thus I guess we don't need to test the model.create case here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.


# _RegisterModelStep does not require runtime repack
model.entry_point = None
model.source_dir = None
register_step_args = model.register(
content_types=["text/csv"],
response_types=["text/csv"],
model_package_group_name="MyModelPackageGroup",
)
# The context should be cleaned up before return
assert not pipeline_session_mock.context
assert not register_step_args.create_model_request
assert register_step_args.create_model_package_request
assert len(register_step_args.need_runtime_repack) == 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: same above, these validations should focus on your changes, which is how removing instance_type impacting the register_step_args.create_model_package_request

Please remove the validations of 157, 158 and 160. And elaborate a little bit on line 159 if possible

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as per request

Loading