 from __future__ import absolute_import

 import logging
+import re

 from typing import List, Dict, Optional
-
 import sagemaker
-
 from sagemaker.parameter import CategoricalParameter

 INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = {
@@ -101,13 +100,15 @@ def right_size(
                         'OMP_NUM_THREADS': CategoricalParameter(['1', '2', '3', '4'])
                     }]

-            phases (list[Phase]): Specifies the criteria for increasing load
-                during endpoint load tests. (default: None).
-            traffic_type (str): Specifies the traffic type that matches the phases. (default: None).
-            max_invocations (str): defines invocation limit for endpoint load tests (default: None).
-            model_latency_thresholds (list[ModelLatencyThreshold]): defines the response latency
-                thresholds for endpoint load tests (default: None).
-            max_tests (int): restricts how many endpoints are allowed to be
+            phases (list[Phase]): Shape of the traffic pattern to use in the load test
+                (default: None).
+            traffic_type (str): Specifies the traffic pattern type. Currently only supports
+                one type, 'PHASES' (default: None).
+            max_invocations (int): defines the minimum invocations per minute for the endpoint
+                to support (default: None).
+            model_latency_thresholds (list[ModelLatencyThreshold]): defines the maximum response
+                latency for endpoints to support (default: None).
+            max_tests (int): restricts how many endpoints in total are allowed to be
                 spun up for this job (default: None).
             max_parallel_tests (int): restricts how many concurrent endpoints
                 this job is allowed to spin up (default: None).
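To make these load-test arguments concrete, here is a minimal sketch of a `right_size()` call that exercises them. The import path for `Phase` and `ModelLatencyThreshold`, their constructor arguments, and every ARN/URI below are illustrative assumptions rather than values taken from this diff:

```python
from sagemaker.inference_recommender import Phase, ModelLatencyThreshold  # assumed import path
from sagemaker.parameter import CategoricalParameter

# model_package is assumed to be a registered sagemaker.model.ModelPackage;
# right_size() raises a ValueError for anything else (see the next hunk).
model_package.right_size(
    sample_payload_url="s3://my-bucket/payload.tar.gz",  # placeholder S3 URI
    supported_content_types=["application/json"],
    supported_instance_types=["ml.c5.xlarge", "ml.c5.2xlarge"],
    framework="PYTORCH",
    job_duration_in_seconds=7200,
    hyperparameter_ranges=[{
        "instance_types": CategoricalParameter(["ml.c5.xlarge", "ml.c5.2xlarge"]),
        "OMP_NUM_THREADS": CategoricalParameter(["1", "2", "4"]),
    }],
    # Advanced load-test controls documented above:
    phases=[Phase(duration_in_seconds=300, initial_number_of_users=2, spawn_rate=2)],
    traffic_type="PHASES",
    max_invocations=100,
    model_latency_thresholds=[ModelLatencyThreshold(percentile="P95", value_in_milliseconds=300)],
    max_tests=5,
    max_parallel_tests=2,
)
```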
@@ -122,7 +123,7 @@ def right_size(
             raise ValueError("right_size() is currently only supported with a registered model")

         if not framework and self._framework():
-            framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING.get(self._framework, framework)
+            framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING.get(self._framework(), framework)

         framework_version = self._get_framework_version()

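The one-character fix in this hunk repairs a silent lookup failure: without the call parentheses, the bound method object itself was used as the dictionary key, so `.get()` always fell back to the user-supplied `framework`. A self-contained sketch of the failure mode (all names illustrative):

```python
FRAMEWORK_MAPPING = {"pytorch": "PYTORCH"}

class Model:
    def _framework(self):
        return "pytorch"

model = Model()

# Buggy: the bound method object is the key, so .get() always returns the default.
assert FRAMEWORK_MAPPING.get(model._framework, None) is None

# Fixed: calling the method looks up its return value instead.
assert FRAMEWORK_MAPPING.get(model._framework(), None) == "PYTORCH"
```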
@@ -176,7 +177,39 @@ def right_size(

         return self

-    def _check_inference_recommender_args(
+    def _update_params(
+        self,
+        **kwargs,
+    ):
+        """Check and update params based on inference recommendation id or right size case."""
+        instance_type = kwargs["instance_type"]
+        initial_instance_count = kwargs["initial_instance_count"]
+        accelerator_type = kwargs["accelerator_type"]
+        async_inference_config = kwargs["async_inference_config"]
+        serverless_inference_config = kwargs["serverless_inference_config"]
+        inference_recommendation_id = kwargs["inference_recommendation_id"]
+        inference_recommender_job_results = kwargs["inference_recommender_job_results"]
+        inference_recommendation = None
+        if inference_recommendation_id is not None:
+            inference_recommendation = self._update_params_for_recommendation_id(
+                instance_type=instance_type,
+                initial_instance_count=initial_instance_count,
+                accelerator_type=accelerator_type,
+                async_inference_config=async_inference_config,
+                serverless_inference_config=serverless_inference_config,
+                inference_recommendation_id=inference_recommendation_id,
+            )
+        elif inference_recommender_job_results is not None:
+            inference_recommendation = self._update_params_for_right_size(
+                instance_type,
+                initial_instance_count,
+                accelerator_type,
+                serverless_inference_config,
+                async_inference_config,
+            )
+        return inference_recommendation or (instance_type, initial_instance_count)
+
+    def _update_params_for_right_size(
         self,
         instance_type=None,
         initial_instance_count=None,
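A rough sketch of the dispatch behavior `_update_params` implements; the call below is hypothetical (in the SDK, `deploy()` is presumably the caller that forwards these keyword arguments), and the id is a placeholder:

```python
# With a recommendation id, the instance settings come from that recommendation;
# with prior right_size() results, from those; otherwise the user values pass through.
instance_type, initial_instance_count = model._update_params(
    instance_type=None,
    initial_instance_count=None,
    accelerator_type=None,
    async_inference_config=None,
    serverless_inference_config=None,
    inference_recommendation_id="my-recommendation-job/a1b2c3d4",  # illustrative id
    inference_recommender_job_results=None,
)
```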
@@ -232,6 +265,165 @@ def _check_inference_recommender_args(
         ]
         return (instance_type, initial_instance_count)

+    def _update_params_for_recommendation_id(
+        self,
+        instance_type,
+        initial_instance_count,
+        accelerator_type,
+        async_inference_config,
+        serverless_inference_config,
+        inference_recommendation_id,
+    ):
+        """Update parameters with inference recommendation results.
+
+        Args:
+            instance_type (str): The EC2 instance type to deploy this Model to.
+                For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
+                serverless inference, then it is required to deploy a model.
+            initial_instance_count (int): The initial number of instances to run
+                in the ``Endpoint`` created from this ``Model``. If not using
+                serverless inference, then it needs to be a number larger than
+                or equal to 1.
+            accelerator_type (str): Type of Elastic Inference accelerator to
+                deploy this model for model loading and inference, for example,
+                'ml.eia1.medium'. If not specified, no Elastic Inference
+                accelerator will be attached to the endpoint. For more
+                information:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
+            async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig): Specifies
+                configuration related to async endpoint. Use this configuration when trying
+                to create an async endpoint and make async inference. If an empty config
+                object is passed through, a default config will be used to deploy an async
+                endpoint. A real-time endpoint is deployed if it's None.
+            serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
+                Specifies configuration related to serverless endpoint. Use this configuration
+                when trying to create a serverless endpoint and make serverless inference. If
+                an empty object is passed through, the pre-defined values in the
+                ``ServerlessInferenceConfig`` class will be used to deploy a serverless
+                endpoint. An instance-based endpoint is deployed if it's None.
+            inference_recommendation_id (str): The recommendation id which specifies
+                the recommendation you picked from the inference recommendation job
+                results; the model and endpoint will be deployed with the
+                recommended parameters.
+        Raises:
+            ValueError: If the argument combination check fails in these circumstances:
+                - If only one of instance type or instance count is specified, or
+                - If the recommendation id does not follow the required format, or
+                - If the recommendation id is not valid, or
+                - If the inference recommendation id is specified along with incompatible parameters
+        Returns:
+            (string, int): instance type and associated instance count from the selected
+                inference recommendation id if the argument combination check passes.
+        """
+
+        if instance_type is not None and initial_instance_count is not None:
+            LOGGER.warning(
+                "Both instance_type and initial_instance_count are specified, "
+                "overriding the recommendation result."
+            )
+            return (instance_type, initial_instance_count)
+
+        # Validate non-compatible parameters with recommendation id
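+        # bool(x) != bool(y) acts as an exclusive-or: exactly one of the two was supplied.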
+        if bool(instance_type) != bool(initial_instance_count):
+            raise ValueError(
+                "Please either do not specify instance_type and initial_instance_count "
+                "since they are in the recommendation, or specify both of them if you want "
+                "to override the recommendation."
+            )
+        if accelerator_type is not None:
+            raise ValueError("accelerator_type is not compatible with inference_recommendation_id.")
+        if async_inference_config is not None:
+            raise ValueError(
+                "async_inference_config is not compatible with inference_recommendation_id."
+            )
+        if serverless_inference_config is not None:
+            raise ValueError(
+                "serverless_inference_config is not compatible with inference_recommendation_id."
+            )
+
+        # Validate recommendation id
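+        # Expected format: <recommendation-job-name>/<8-character-id>, e.g. "my-job/a1b2c3d4".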
+        if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
+            raise ValueError("Inference Recommendation id is not valid")
+        recommendation_job_name = inference_recommendation_id.split("/")[0]
+
+        sage_client = self.sagemaker_session.sagemaker_client
+        recommendation_res = sage_client.describe_inference_recommendations_job(
+            JobName=recommendation_job_name
+        )
+        input_config = recommendation_res["InputConfig"]
+
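+        # Find the entry in the job's InferenceRecommendations list that matches the supplied id.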
+        recommendation = next(
+            (
+                rec
+                for rec in recommendation_res["InferenceRecommendations"]
+                if rec["RecommendationId"] == inference_recommendation_id
+            ),
+            None,
+        )
+
+        if not recommendation:
+            raise ValueError(
+                "inference_recommendation_id does not exist in InferenceRecommendations list"
+            )
+
+        model_config = recommendation["ModelConfiguration"]
+        envs = (
+            model_config["EnvironmentParameters"]
+            if "EnvironmentParameters" in model_config
+            else None
+        )
+        # Update envs
+        recommend_envs = {}
+        if envs is not None:
+            for env in envs:
+                recommend_envs[env["Key"]] = env["Value"]
+        self.env.update(recommend_envs)

+        # Update params with non-compilation recommendation results
+        if (
+            "InferenceSpecificationName" not in model_config
+            and "CompilationJobName" not in model_config
+        ):
+
+            if "ModelPackageVersionArn" in input_config:
+                modelpkg_res = sage_client.describe_model_package(
+                    ModelPackageName=input_config["ModelPackageVersionArn"]
+                )
+                self.model_data = modelpkg_res["InferenceSpecification"]["Containers"][0][
+                    "ModelDataUrl"
+                ]
+                self.image_uri = modelpkg_res["InferenceSpecification"]["Containers"][0]["Image"]
+            elif "ModelName" in input_config:
+                model_res = sage_client.describe_model(ModelName=input_config["ModelName"])
+                self.model_data = model_res["PrimaryContainer"]["ModelDataUrl"]
+                self.image_uri = model_res["PrimaryContainer"]["Image"]
+        else:
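+            # Recommendation is tied to an inference specification or compilation job.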
+            if "InferenceSpecificationName" in model_config:
+                modelpkg_res = sage_client.describe_model_package(
+                    ModelPackageName=input_config["ModelPackageVersionArn"]
+                )
+                self.model_data = modelpkg_res["AdditionalInferenceSpecificationDefinition"][
+                    "Containers"
+                ][0]["ModelDataUrl"]
+                self.image_uri = modelpkg_res["AdditionalInferenceSpecificationDefinition"][
+                    "Containers"
+                ][0]["Image"]
+            elif "CompilationJobName" in model_config:
+                compilation_res = sage_client.describe_compilation_job(
+                    CompilationJobName=model_config["CompilationJobName"]
+                )
+                self.model_data = compilation_res["ModelArtifacts"]["S3ModelArtifacts"]
+                self.image_uri = compilation_res["InferenceImage"]
+
+        instance_type = recommendation["EndpointConfiguration"]["InstanceType"]
+        initial_instance_count = recommendation["EndpointConfiguration"]["InitialInstanceCount"]
+
+        return (instance_type, initial_instance_count)
+
     def _convert_to_endpoint_configurations_json(
         self, hyperparameter_ranges: List[Dict[str, CategoricalParameter]]
     ):
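Taken together, the new helper lets callers deploy straight from a recommendation. A hedged end-to-end sketch, assuming `deploy()` accepts `inference_recommendation_id` after this change (job name and id are placeholders):

```python
# Deploy using a recommendation picked from a completed Inference Recommender job.
# The id must match <job-name>/<8-char-id>, per the validation regex above.
predictor = model_package.deploy(
    inference_recommendation_id="my-recommendation-job/a1b2c3d4",  # placeholder
)

# Or override the recommended instance settings by supplying both values explicitly,
# which logs a warning and skips the recommendation lookup:
predictor = model_package.deploy(
    instance_type="ml.c5.xlarge",
    initial_instance_count=1,
    inference_recommendation_id="my-recommendation-job/a1b2c3d4",  # placeholder
)
```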