Skip to content

Commit 414a77c

Browse files
committed
feature: Add create inf rec api to session (aws#761)
* feature: Add create inf rec api to session * save * fix error handling in submit. update docstring * add in job_name param Co-authored-by: Gary Wang <[email protected]>
1 parent 37c0d3a commit 414a77c

File tree

2 files changed

+364
-0
lines changed

2 files changed

+364
-0
lines changed

src/sagemaker/session.py

+189
Original file line numberDiff line numberDiff line change
@@ -4655,6 +4655,195 @@ def _intercept_create_request(
46554655
"""
46564656
return create(request)
46574657

4658+
def _create_inference_recommendations_job_request(
    self,
    role: str,
    job_name: str,
    job_description: str,
    framework: str,
    sample_payload_url: str,
    supported_content_types: List[str],
    model_package_version_arn: str = None,
    job_duration_in_seconds: int = None,
    job_type: str = "Default",
    framework_version: str = None,
    nearest_model_name: str = None,
    supported_instance_types: List[str] = None,
    endpoint_configurations: List[Dict[str, Any]] = None,
    traffic_pattern: Dict[str, Any] = None,
    stopping_conditions: Dict[str, Any] = None,
    resource_limit: Dict[str, Any] = None,
) -> Dict[str, Any]:
    """Get request dictionary for CreateInferenceRecommendationsJob API.

    Optional values are only added to the request when they are truthy, so
    the returned dictionary never carries ``None`` placeholders.

    Args:
        role (str): The IAM role to run the job with. It is sent unchanged as
            the request's ``RoleArn``.
        job_name (str): The name of the Inference Recommendations Job.
        job_description (str): A description of the Inference Recommendations Job.
            Omitted from the request when empty.
        framework (str): The machine learning framework of the Image URI.
            Omitted from the request when empty or ``None``.
        sample_payload_url (str): The S3 path where the sample payload is stored.
        supported_content_types (List[str]): The supported MIME types for the input data.
        model_package_version_arn (str): The Amazon Resource Name (ARN) of a
            versioned model package. Omitted from the request when not provided.
        job_duration_in_seconds (int): The maximum job duration that a job
            can run for. Added to ``InputConfig`` whenever it is provided,
            regardless of ``job_type``.
        job_type (str): The type of job being run. Must either be `Default` or `Advanced`.
        framework_version (str): The framework version of the Image URI.
        nearest_model_name (str): The name of a pre-trained machine learning model
            benchmarked by Amazon SageMaker Inference Recommender that matches your model.
        supported_instance_types (List[str]): A list of the instance types that are used
            to generate inferences in real-time.
        endpoint_configurations (List[Dict[str, Any]]): Specifies the endpoint configurations
            to use for a job. Only used for `Advanced` jobs.
        traffic_pattern (Dict[str, Any]): Specifies the traffic pattern for the job.
            Only used for `Advanced` jobs.
        stopping_conditions (Dict[str, Any]): A set of conditions for stopping a
            recommendation job.
            If any of the conditions are met, the job is automatically stopped.
            Only used for `Advanced` jobs.
        resource_limit (Dict[str, Any]): Defines the resource limit for the job.
            Only used for `Advanced` jobs.
    Returns:
        Dict[str, Any]: request dictionary for the CreateInferenceRecommendationsJob API
    """

    container_config = {
        "Domain": "MACHINE_LEARNING",
        "Task": "OTHER",
        "PayloadConfig": {
            "SamplePayloadUrl": sample_payload_url,
            "SupportedContentTypes": supported_content_types,
        },
    }

    # The public entry point defaults ``framework`` to None; sending a None
    # value would fail client-side parameter validation, so leave it out.
    if framework:
        container_config["Framework"] = framework
    if framework_version:
        container_config["FrameworkVersion"] = framework_version
    if nearest_model_name:
        container_config["NearestModelName"] = nearest_model_name
    if supported_instance_types:
        container_config["SupportedInstanceTypes"] = supported_instance_types

    input_config = {"ContainerConfig": container_config}

    # Same reasoning as for ``framework``: never emit a None-valued ARN.
    if model_package_version_arn:
        input_config["ModelPackageVersionArn"] = model_package_version_arn
    if job_duration_in_seconds:
        input_config["JobDurationInSeconds"] = job_duration_in_seconds

    request = {
        "JobName": job_name,
        "JobType": job_type,
        "RoleArn": role,
        "InputConfig": input_config,
    }

    if job_description:
        request["JobDescription"] = job_description

    # Load-test tuning knobs are only forwarded for Advanced jobs.
    if job_type == "Advanced":
        if stopping_conditions:
            request["StoppingConditions"] = stopping_conditions
        if resource_limit:
            input_config["ResourceLimit"] = resource_limit
        if traffic_pattern:
            input_config["TrafficPattern"] = traffic_pattern
        if endpoint_configurations:
            input_config["EndpointConfigurations"] = endpoint_configurations

    return request
4756+
4757+
def create_inference_recommendations_job(
    self,
    role: str,
    sample_payload_url: str,
    supported_content_types: List[str],
    job_name: str = None,
    job_type: str = "Default",
    model_package_version_arn: str = None,
    job_duration_in_seconds: int = None,
    nearest_model_name: str = None,
    supported_instance_types: List[str] = None,
    framework: str = None,
    framework_version: str = None,
    endpoint_configurations: List[Dict[str, Any]] = None,
    traffic_pattern: Dict[str, Any] = None,
    stopping_conditions: Dict[str, Any] = None,
    resource_limit: Dict[str, Any] = None,
) -> str:
    """Creates an Inference Recommendations Job

    Builds the CreateInferenceRecommendationsJob request, hands it to
    ``_intercept_create_request`` together with a submit callback that calls
    the SageMaker client, and returns the job name that was used.

    Args:
        role (str): The IAM role to run the job with. It is forwarded
            unchanged to the request builder.
        sample_payload_url (str): The S3 path where the sample payload is stored.
        supported_content_types (List[str]): The supported MIME types for the input data.
        job_name (str): The name of the job being run. When omitted (or empty),
            a name of the form `SMPYTHONSDK-<timestamp>` is generated.
        job_type (str): The type of job being run. Must either be `Default` or `Advanced`.
        model_package_version_arn (str): The Amazon Resource Name (ARN) of a
            versioned model package.
        job_duration_in_seconds (int): The maximum job duration that a job
            can run for.
        nearest_model_name (str): The name of a pre-trained machine learning model
            benchmarked by Amazon SageMaker Inference Recommender that matches your model.
        supported_instance_types (List[str]): A list of the instance types that are used
            to generate inferences in real-time.
        framework (str): The machine learning framework of the Image URI.
        framework_version (str): The framework version of the Image URI.
        endpoint_configurations (List[Dict[str, Any]]): Specifies the endpoint configurations
            to use for a job. Will be used for `Advanced` jobs.
        traffic_pattern (Dict[str, Any]): Specifies the traffic pattern for the job.
            Will be used for `Advanced` jobs.
        stopping_conditions (Dict[str, Any]): A set of conditions for stopping a
            recommendation job.
            If any of the conditions are met, the job is automatically stopped.
            Will be used for `Advanced` jobs.
        resource_limit (Dict[str, Any]): Defines the resource limit for the job.
            Will be used for `Advanced` jobs.
    Returns:
        str: The name of the job created: the supplied ``job_name``, or a
            generated `SMPYTHONSDK-<timestamp>` name when none was given.
    """

    if not job_name:
        # Whole-second timestamp keeps the generated name short.
        job_name = "SMPYTHONSDK-" + str(round(time.time()))
    # Fixed description attached to every job created through this method.
    job_description = "#python-sdk-create"

    create_inference_recommendations_job_request = (
        self._create_inference_recommendations_job_request(
            role=role,
            model_package_version_arn=model_package_version_arn,
            job_name=job_name,
            job_type=job_type,
            job_duration_in_seconds=job_duration_in_seconds,
            job_description=job_description,
            framework=framework,
            framework_version=framework_version,
            nearest_model_name=nearest_model_name,
            sample_payload_url=sample_payload_url,
            supported_content_types=supported_content_types,
            supported_instance_types=supported_instance_types,
            endpoint_configurations=endpoint_configurations,
            traffic_pattern=traffic_pattern,
            stopping_conditions=stopping_conditions,
            resource_limit=resource_limit,
        )
    )

    def submit(request):
        """Log the request and send it to the SageMaker client."""
        LOGGER.info("Creating Inference Recommendations job with name: %s", job_name)
        LOGGER.debug("process request: %s", json.dumps(request, indent=4))
        self.sagemaker_client.create_inference_recommendations_job(**request)

    self._intercept_create_request(
        create_inference_recommendations_job_request,
        submit,
        self.create_inference_recommendations_job.__name__,
    )
    return job_name
4846+
46584847

46594848
def get_model_package_args(
46604849
content_types,

tests/unit/test_session.py

+175
Original file line numberDiff line numberDiff line change
@@ -2937,3 +2937,178 @@ def test_wait_for_athena_query(query_execution, sagemaker_session):
29372937
query_execution.return_value = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}}
29382938
sagemaker_session.wait_for_athena_query(query_execution_id="query_id")
29392939
assert query_execution.called_with(query_execution_id="query_id")
2940+
2941+
2942+
# Shared fixtures for the create_inference_recommendations_job tests below.

# Job identity: a caller-supplied name, and the name the tests expect when
# none is supplied and time.time() is patched to return 1234567891.
IR_USER_JOB_NAME = "custom-job-name"
IR_JOB_NAME = "SMPYTHONSDK-1234567891"
IR_ADVANCED_JOB = "Advanced"
IR_ROLE_ARN = "arn:aws:iam::123456789123:role/service-role/AmazonSageMaker-ExecutionRole-UnitTest"

# Model and payload inputs.
IR_SAMPLE_PAYLOAD_URL = "s3://sagemaker-us-west-2-123456789123/payload/payload.tar.gz"
IR_SUPPORTED_CONTENT_TYPES = ["text/csv"]
IR_MODEL_PACKAGE_VERSION_ARN = (
    "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
)
IR_NEAREST_MODEL_NAME = "xgboost"
IR_SUPPORTED_INSTANCE_TYPES = ["ml.c5.xlarge", "ml.c5.2xlarge"]
IR_FRAMEWORK = "XGBOOST"
IR_FRAMEWORK_VERSION = "1.2.0"

# Inputs passed only by the Advanced-job test.
IR_JOB_DURATION_IN_SECONDS = 7200
IR_ENDPOINT_CONFIGURATIONS = [
    {
        "EnvironmentParameterRanges": {
            "CategoricalParameterRanges": [{"Name": "OMP_NUM_THREADS", "Value": ["2", "4", "10"]}]
        },
        "InferenceSpecificationName": "unit-test-specification",
        "InstanceType": "ml.c5.xlarge",
    }
]
IR_TRAFFIC_PATTERN = {
    "Phases": [{"DurationInSeconds": 120, "InitialNumberOfUsers": 1, "SpawnRate": 1}],
    "TrafficType": "PHASES",
}
IR_STOPPING_CONDITIONS = {
    "MaxInvocations": 300,
    "ModelLatencyThresholds": [{"Percentile": "P95", "ValueInMilliseconds": 100}],
}
IR_RESOURCE_LIMIT = {"MaxNumberOfTests": 10, "MaxParallelOfTests": 1}
2975+
2976+
2977+
def create_inference_recommendations_job_default_happy_response():
    """Return the keyword arguments the client is expected to receive for a Default job.

    A fresh dictionary is built on every call so callers may modify the result.
    """
    payload_config = {
        "SamplePayloadUrl": IR_SAMPLE_PAYLOAD_URL,
        "SupportedContentTypes": IR_SUPPORTED_CONTENT_TYPES,
    }
    container_config = {
        "Domain": "MACHINE_LEARNING",
        "Task": "OTHER",
        "Framework": IR_FRAMEWORK,
        "PayloadConfig": payload_config,
        "FrameworkVersion": IR_FRAMEWORK_VERSION,
        "NearestModelName": IR_NEAREST_MODEL_NAME,
        "SupportedInstanceTypes": IR_SUPPORTED_INSTANCE_TYPES,
    }
    input_config = {
        "ContainerConfig": container_config,
        "ModelPackageVersionArn": IR_MODEL_PACKAGE_VERSION_ARN,
    }
    expected_request = {
        "JobName": IR_USER_JOB_NAME,
        "JobType": "Default",
        "RoleArn": IR_ROLE_ARN,
        "InputConfig": input_config,
        "JobDescription": "#python-sdk-create",
    }
    return expected_request
2999+
3000+
3001+
def create_inference_recommendations_job_advanced_happy_response():
    """Return the keyword arguments the client is expected to receive for an Advanced job.

    Starts from the Default-job expectation and overlays the generated job
    name, the Advanced job type, and the Advanced-only settings.
    """
    expected_request = create_inference_recommendations_job_default_happy_response()

    # Top-level fields that differ from the Default expectation.
    expected_request.update(
        {
            "JobName": IR_JOB_NAME,
            "JobType": IR_ADVANCED_JOB,
            "StoppingConditions": IR_STOPPING_CONDITIONS,
        }
    )

    # Extra InputConfig entries supplied by the Advanced-job test.
    expected_request["InputConfig"].update(
        {
            "JobDurationInSeconds": IR_JOB_DURATION_IN_SECONDS,
            "EndpointConfigurations": IR_ENDPOINT_CONFIGURATIONS,
            "TrafficPattern": IR_TRAFFIC_PATTERN,
            "ResourceLimit": IR_RESOURCE_LIMIT,
        }
    )

    return expected_request
3013+
3014+
3015+
def test_create_inference_recommendations_job_default_happy(sagemaker_session):
    """A Default job with a caller-supplied name sends the expected request and returns that name."""
    call_kwargs = dict(
        role=IR_ROLE_ARN,
        sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
        supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
        model_package_version_arn=IR_MODEL_PACKAGE_VERSION_ARN,
        framework=IR_FRAMEWORK,
        framework_version=IR_FRAMEWORK_VERSION,
        nearest_model_name=IR_NEAREST_MODEL_NAME,
        supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
        job_name=IR_USER_JOB_NAME,
    )

    returned_name = sagemaker_session.create_inference_recommendations_job(**call_kwargs)

    expected_request = create_inference_recommendations_job_default_happy_response()
    client = sagemaker_session.sagemaker_client
    client.create_inference_recommendations_job.assert_called_with(**expected_request)

    assert IR_USER_JOB_NAME == returned_name
3033+
3034+
3035+
@patch("time.time", MagicMock(return_value=1234567891))
def test_create_inference_recommendations_job_advanced_happy(sagemaker_session):
    """An Advanced job without a name sends the expected request and returns the generated name.

    ``time.time`` is patched so the generated name is deterministic.
    """
    call_kwargs = dict(
        role=IR_ROLE_ARN,
        sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
        supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
        model_package_version_arn=IR_MODEL_PACKAGE_VERSION_ARN,
        framework=IR_FRAMEWORK,
        framework_version=IR_FRAMEWORK_VERSION,
        nearest_model_name=IR_NEAREST_MODEL_NAME,
        supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
        endpoint_configurations=IR_ENDPOINT_CONFIGURATIONS,
        traffic_pattern=IR_TRAFFIC_PATTERN,
        stopping_conditions=IR_STOPPING_CONDITIONS,
        resource_limit=IR_RESOURCE_LIMIT,
        job_type=IR_ADVANCED_JOB,
        job_duration_in_seconds=IR_JOB_DURATION_IN_SECONDS,
    )

    returned_name = sagemaker_session.create_inference_recommendations_job(**call_kwargs)

    expected_request = create_inference_recommendations_job_advanced_happy_response()
    client = sagemaker_session.sagemaker_client
    client.create_inference_recommendations_job.assert_called_with(**expected_request)

    assert IR_JOB_NAME == returned_name
3059+
3060+
3061+
def test_create_inference_recommendations_job_propogate_validation_exception(sagemaker_session):
    """A ValidationException raised by the client reaches the caller as a ClientError."""
    validation_exception_message = (
        "Failed to describe model due to validation failure with following error: test_error"
    )

    validation_exception = ClientError(
        {"Error": {"Code": "ValidationException", "Message": validation_exception_message}},
        "create_inference_recommendations_job",
    )

    sagemaker_session.sagemaker_client.create_inference_recommendations_job.side_effect = (
        validation_exception
    )

    with pytest.raises(ClientError) as error:
        sagemaker_session.create_inference_recommendations_job(
            role=IR_ROLE_ARN,
            sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
            supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
            model_package_version_arn=IR_MODEL_PACKAGE_VERSION_ARN,
            framework=IR_FRAMEWORK,
            framework_version=IR_FRAMEWORK_VERSION,
            nearest_model_name=IR_NEAREST_MODEL_NAME,
            supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
        )

    # Inspect the raised exception itself (error.value); `error` is pytest's
    # ExceptionInfo wrapper, whose str() is not the exception message.
    assert "ValidationException" in str(error.value)
3088+
3089+
3090+
def test_create_inference_recommendations_job_propogate_other_exception(sagemaker_session):
    """An AccessDeniedException raised by the client reaches the caller as a ClientError."""
    access_denied_exception_message = "Access is not allowed for the caller."

    access_denied_exception = ClientError(
        {"Error": {"Code": "AccessDeniedException", "Message": access_denied_exception_message}},
        "create_inference_recommendations_job",
    )

    sagemaker_session.sagemaker_client.create_inference_recommendations_job.side_effect = (
        access_denied_exception
    )

    with pytest.raises(ClientError) as error:
        sagemaker_session.create_inference_recommendations_job(
            role=IR_ROLE_ARN,
            sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
            supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
            model_package_version_arn=IR_MODEL_PACKAGE_VERSION_ARN,
            framework=IR_FRAMEWORK,
            framework_version=IR_FRAMEWORK_VERSION,
            nearest_model_name=IR_NEAREST_MODEL_NAME,
            supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
        )

    # Inspect the raised exception itself (error.value); `error` is pytest's
    # ExceptionInfo wrapper, whose str() is not the exception message.
    assert "AccessDeniedException" in str(error.value)

0 commit comments

Comments
 (0)