Can't register a model in model registry without specifying inference_instances and transform_instances #3222

Closed
acere opened this issue Jul 10, 2022 · 1 comment
Labels
component: pipelines Relates to the SageMaker Pipeline Platform type: bug

acere commented Jul 10, 2022

Describe the bug
Registering a model using SageMaker SDK model.register() without specifying inference_instances and transform_instances fails with the following error:

ParamValidationError: Parameter validation failed:
Invalid type for parameter InferenceSpecification.SupportedRealtimeInferenceInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>
Invalid type for parameter InferenceSpecification.SupportedTransformInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>

The same operation performed directly through boto3 with sagemaker_client.create_model_package() completes successfully.

I traced the issue to these two lines of code:

"SupportedRealtimeInferenceInstanceTypes": inference_instances,
"SupportedTransformInstanceTypes": transform_instances,

When inference_instances and transform_instances are None in model.register(), they should not be included in the inference_specification dictionary at all.
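A minimal sketch of the behaviour proposed above, assuming the dictionary is assembled in a single helper. The function name build_inference_specification and its signature are hypothetical and are not taken from the SDK source; only the dictionary keys come from the CreateModelPackage request shape.

```python
from typing import Any, Dict, List, Optional


def build_inference_specification(
    containers: List[Dict[str, Any]],
    content_types: List[str],
    response_types: List[str],
    inference_instances: Optional[List[str]] = None,
    transform_instances: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: build InferenceSpecification without None-valued keys."""
    spec: Dict[str, Any] = {
        "Containers": containers,
        "SupportedContentTypes": content_types,
        "SupportedResponseMIMETypes": response_types,
    }
    # botocore rejects None for list-typed parameters, so only add the
    # instance-type keys when the caller actually supplied a value.
    if inference_instances is not None:
        spec["SupportedRealtimeInferenceInstanceTypes"] = inference_instances
    if transform_instances is not None:
        spec["SupportedTransformInstanceTypes"] = transform_instances
    return spec


# With both arguments left unset, neither key appears in the request.
spec = build_inference_specification(
    containers=[{"Image": "example-image-uri", "ModelDataUrl": "s3://bucket/model.tar.gz"}],
    content_types=["application/json"],
    response_types=["application/json"],
)
```

Leaving the keys out, rather than substituting an empty list, matches the request that succeeds when sent through boto3 directly.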

To reproduce
The bug can be reproduced by running this code block in a SageMaker (Studio) notebook:

import sagemaker

model_data = "test.tar.gz"
model_package_group_name = "test-model-group-name"

!touch {model_data}

model_data_uri = sagemaker.s3.S3Uploader.upload(
    local_path=model_data,
    desired_s3_uri=f"s3://{sagemaker.Session().default_bucket()}/{model_package_group_name}",
)


inference_image = sagemaker.image_uris.retrieve(
    framework="xgboost",
    region=sagemaker.Session().boto_region_name,
    image_scope="inference",
    version="latest",
)

model = sagemaker.Model(
    image_uri=inference_image,
    model_data=model_data_uri,
    sagemaker_session=sagemaker.Session(),
)

model.register(
    content_types=["application/json"],
    response_types=["application/json"],
    model_package_group_name=model_package_group_name,
)

Expected behavior
I would expect the model to be registered.

Screenshots or logs

This is the error returned when running the code block above:

---------------------------------------------------------------------------
ParamValidationError                      Traceback (most recent call last)
Input In [2], in <cell line: 27>()
     14 inference_image = sagemaker.image_uris.retrieve(
     15     framework="xgboost",
     16     region=sagemaker.Session().boto_region_name,
     17     image_scope="inference",
     18     version="latest",
     19 )
     21 model = sagemaker.Model(
     22     image_uri=inference_image,
     23     model_data=model_data_uri,
     24     sagemaker_session=sagemaker.Session(),
     25 )
---> 27 model.register(
     28     content_types=["application/json"],
     29     response_types=["application/json"],
     30     model_package_group_name=model_package_group_name,
     31 )

File /opt/conda/lib/python3.8/site-packages/sagemaker/workflow/pipeline_context.py:209, in runnable_by_pipeline.<locals>.wrapper(*args, **kwargs)
    206     run_func(*args, **kwargs)
    207     return self_instance.sagemaker_session.context
--> 209 return run_func(*args, **kwargs)

File /opt/conda/lib/python3.8/site-packages/sagemaker/model.py:409, in Model.register(self, content_types, response_types, inference_instances, transform_instances, model_package_name, model_package_group_name, image_uri, model_metrics, metadata_properties, marketplace_cert, approval_status, description, drift_check_baselines, customer_metadata_properties, validation_specification, domain, task, sample_payload_url, framework, framework_version, nearest_model_name, data_input_configuration)
    384     container_def = {
    385         "Image": self.image_uri,
    386         "ModelDataUrl": self.model_data,
    387     }
    389 model_pkg_args = sagemaker.get_model_package_args(
    390     content_types,
    391     response_types,
   (...)
    407     task=task,
    408 )
--> 409 model_package = self.sagemaker_session.create_model_package_from_containers(
    410     **model_pkg_args
    411 )
    412 if isinstance(self.sagemaker_session, PipelineSession):
    413     return None

File /opt/conda/lib/python3.8/site-packages/sagemaker/session.py:2896, in Session.create_model_package_from_containers(self, containers, content_types, response_types, inference_instances, transform_instances, model_package_name, model_package_group_name, model_metrics, metadata_properties, marketplace_cert, approval_status, description, drift_check_baselines, customer_metadata_properties, validation_specification, domain, sample_payload_url, task)
   2891             self.sagemaker_client.create_model_package_group(
   2892                 ModelPackageGroupName=request["ModelPackageGroupName"]
   2893             )
   2894     return self.sagemaker_client.create_model_package(**request)
-> 2896 return self._intercept_create_request(
   2897     model_pkg_request, submit, self.create_model_package_from_containers.__name__
   2898 )

File /opt/conda/lib/python3.8/site-packages/sagemaker/session.py:4230, in Session._intercept_create_request(self, request, create, func_name)
   4217 def _intercept_create_request(
   4218     self, request: typing.Dict, create, func_name: str = None  # pylint: disable=unused-argument
   4219 ):
   4220     """This function intercepts the create job request.
   4221 
   4222     PipelineSession inherits this Session class and will override
   (...)
   4228         func_name (str): the name of the function needed intercepting
   4229     """
-> 4230     return create(request)

File /opt/conda/lib/python3.8/site-packages/sagemaker/session.py:2894, in Session.create_model_package_from_containers.<locals>.submit(request)
   2890     except ClientError:
   2891         self.sagemaker_client.create_model_package_group(
   2892             ModelPackageGroupName=request["ModelPackageGroupName"]
   2893         )
-> 2894 return self.sagemaker_client.create_model_package(**request)

File /opt/conda/lib/python3.8/site-packages/botocore/client.py:508, in ClientCreator._create_api_method.<locals>._api_call(self, *args, **kwargs)
    504     raise TypeError(
    505         f"{py_operation_name}() only accepts keyword arguments."
    506     )
    507 # The "self" in this scope is referring to the BaseClient.
--> 508 return self._make_api_call(operation_name, kwargs)

File /opt/conda/lib/python3.8/site-packages/botocore/client.py:874, in BaseClient._make_api_call(self, operation_name, api_params)
    865     logger.debug(
    866         'Warning: %s.%s() is deprecated', service_name, operation_name
    867     )
    868 request_context = {
    869     'client_region': self.meta.region_name,
    870     'client_config': self.meta.config,
    871     'has_streaming_input': operation_model.has_streaming_input,
    872     'auth_type': operation_model.auth_type,
    873 }
--> 874 request_dict = self._convert_to_request_dict(
    875     api_params, operation_model, context=request_context
    876 )
    877 resolve_checksum_context(request_dict, operation_model, api_params)
    879 service_id = self._service_model.service_id.hyphenize()

File /opt/conda/lib/python3.8/site-packages/botocore/client.py:935, in BaseClient._convert_to_request_dict(self, api_params, operation_model, context)
    929 def _convert_to_request_dict(
    930     self, api_params, operation_model, context=None
    931 ):
    932     api_params = self._emit_api_params(
    933         api_params, operation_model, context
    934     )
--> 935     request_dict = self._serializer.serialize_to_request(
    936         api_params, operation_model
    937     )
    938     if not self._client_config.inject_host_prefix:
    939         request_dict.pop('host_prefix', None)

File /opt/conda/lib/python3.8/site-packages/botocore/validate.py:381, in ParamValidationDecorator.serialize_to_request(self, parameters, operation_model)
    377     report = self._param_validator.validate(
    378         parameters, operation_model.input_shape
    379     )
    380     if report.has_errors():
--> 381         raise ParamValidationError(report=report.generate_report())
    382 return self._serializer.serialize_to_request(
    383     parameters, operation_model
    384 )

ParamValidationError: Parameter validation failed:
Invalid type for parameter InferenceSpecification.SupportedRealtimeInferenceInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>
Invalid type for parameter InferenceSpecification.SupportedTransformInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>

System information
  • SageMaker Python SDK version: 2.99.0
  • Framework name (e.g. PyTorch) or algorithm (e.g. KMeans): Any
  • Framework version: Any
  • Python version: 3.8
  • CPU or GPU: Any
  • Custom Docker image (Y/N): N

Additional context
@navaj0 navaj0 added the component: pipelines Relates to the SageMaker Pipeline Platform label Jul 11, 2022
@qidewenwhen
Member

Hi @acere, the fix has been merged in v2.100.0. Please check if it works for you.
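
One way to check whether an installed SDK version already contains the fix is to compare version numbers as integer tuples. This is a sketch only: the helper names parse_version and has_register_fix are hypothetical, and it assumes plain numeric versions such as 2.99.0 (a suffix like "rc1" would raise ValueError).

```python
from typing import Tuple

# Version reported in this thread as containing the fix.
FIXED_IN: Tuple[int, int, int] = (2, 100, 0)


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn 'major.minor.patch' into a tuple of ints, padding missing parts with 0."""
    numbers = [int(part) for part in version.split(".")[:3]]
    while len(numbers) < 3:
        numbers.append(0)
    return tuple(numbers)


def has_register_fix(version: str) -> bool:
    """Return True if the given SDK version is at or above the fixed release."""
    return parse_version(version) >= FIXED_IN
```

In a notebook this could be called as has_register_fix(sagemaker.__version__). Comparing the raw strings would give the wrong answer here, because "2.99.0" sorts after "2.100.0" lexicographically.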
