Skip to content

Commit f98f1bc

Browse files
Merge branch 'master' into fix-processing-image-uri-param
2 parents f8e424f + b2d4744 commit f98f1bc

File tree

18 files changed

+328
-111
lines changed

18 files changed

+328
-111
lines changed

doc/frameworks/xgboost/using_xgboost.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ For information about the SageMaker Python SDK XGBoost classes, see the followin
465465
* :class:`sagemaker.xgboost.estimator.XGBoost`
466466
* :class:`sagemaker.xgboost.model.XGBoostModel`
467467
* :class:`sagemaker.xgboost.model.XGBoostPredictor`
468+
* :class:`sagemaker.xgboost.processing.XGBoostProcessor`
468469

469470
***********************************
470471
SageMaker XGBoost Docker Containers

doc/frameworks/xgboost/xgboost.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,8 @@ The Amazon SageMaker XGBoost open source framework algorithm.
1616
:members:
1717
:undoc-members:
1818
:show-inheritance:
19+
20+
.. autoclass:: sagemaker.xgboost.processing.XGBoostProcessor
21+
:members:
22+
:undoc-members:
23+
:show-inheritance:

src/sagemaker/estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,8 +1291,8 @@ def register(
12911291
self,
12921292
content_types,
12931293
response_types,
1294-
inference_instances,
1295-
transform_instances,
1294+
inference_instances=None,
1295+
transform_instances=None,
12961296
image_uri=None,
12971297
model_package_name=None,
12981298
model_package_group_name=None,
@@ -1314,9 +1314,9 @@ def register(
13141314
content_types (list): The supported MIME types for the input data.
13151315
response_types (list): The supported MIME types for the output data.
13161316
inference_instances (list): A list of the instance types that are used to
1317-
generate inferences in real-time.
1317+
generate inferences in real-time (default: None).
13181318
transform_instances (list): A list of the instance types on which a transformation
1319-
job can be run or on which an endpoint can be deployed.
1319+
job can be run or on which an endpoint can be deployed (default: None).
13201320
image_uri (str): The container image uri for Model Package, if not specified,
13211321
Estimator's training container image will be used (default: None).
13221322
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,

src/sagemaker/huggingface/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,8 @@ def register(
293293
self,
294294
content_types,
295295
response_types,
296-
inference_instances,
297-
transform_instances,
296+
inference_instances=None,
297+
transform_instances=None,
298298
model_package_name=None,
299299
model_package_group_name=None,
300300
image_uri=None,
@@ -313,9 +313,9 @@ def register(
313313
content_types (list): The supported MIME types for the input data.
314314
response_types (list): The supported MIME types for the output data.
315315
inference_instances (list): A list of the instance types that are used to
316-
generate inferences in real-time.
316+
generate inferences in real-time (default: None).
317317
transform_instances (list): A list of the instance types on which a transformation
318-
job can be run or on which an endpoint can be deployed.
318+
job can be run or on which an endpoint can be deployed (default: None).
319319
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
320320
using `model_package_name` makes the Model Package un-versioned.
321321
Defaults to ``None``.
@@ -341,7 +341,7 @@ def register(
341341
Returns:
342342
A `sagemaker.model.ModelPackage` instance.
343343
"""
344-
instance_type = inference_instances[0]
344+
instance_type = inference_instances[0] if inference_instances else None
345345
self._init_sagemaker_session_if_does_not_exist(instance_type)
346346

347347
if image_uri:

src/sagemaker/model.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ def register(
296296
self,
297297
content_types,
298298
response_types,
299-
inference_instances,
300-
transform_instances,
299+
inference_instances=None,
300+
transform_instances=None,
301301
model_package_name=None,
302302
model_package_group_name=None,
303303
image_uri=None,
@@ -317,9 +317,9 @@ def register(
317317
content_types (list): The supported MIME types for the input data.
318318
response_types (list): The supported MIME types for the output data.
319319
inference_instances (list): A list of the instance types that are used to
320-
generate inferences in real-time.
320+
generate inferences in real-time (default: None).
321321
transform_instances (list): A list of the instance types on which a transformation
322-
job can be run or on which an endpoint can be deployed.
322+
job can be run or on which an endpoint can be deployed (default: None).
323323
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
324324
using `model_package_name` makes the Model Package un-versioned (default: None).
325325
model_package_group_name (str): Model Package Group name, exclusive to
@@ -341,7 +341,9 @@ def register(
341341
"MACHINE_LEARNING" (default: None).
342342
343343
Returns:
344-
A `sagemaker.model.ModelPackage` instance.
344+
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
345+
in case the Model instance is built with
346+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
345347
"""
346348
if self.model_data is None:
347349
raise ValueError("SageMaker Model Package cannot be created without model data.")
@@ -351,12 +353,11 @@ def register(
351353
container_def = self.prepare_container_def()
352354
else:
353355
container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data}
354-
355356
model_pkg_args = sagemaker.get_model_package_args(
356357
content_types,
357358
response_types,
358-
inference_instances,
359-
transform_instances,
359+
inference_instances=inference_instances,
360+
transform_instances=transform_instances,
360361
model_package_name=model_package_name,
361362
model_package_group_name=model_package_group_name,
362363
model_metrics=model_metrics,
@@ -399,15 +400,22 @@ def create(
399400
attach to an endpoint for model loading and inference, for
400401
example, 'ml.eia1.medium'. If not specified, no Elastic
401402
Inference accelerator will be attached to the endpoint (default: None).
402-
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
403+
serverless_inference_config (ServerlessInferenceConfig):
403404
Specifies configuration related to serverless endpoint. Instance type is
404405
not provided in serverless inference. So this is used to find image URIs
405406
(default: None).
406407
tags (List[Dict[str, str]]): The list of tags to add to
407-
the model (default: None). Example: >>> tags = [{'Key': 'tagname', 'Value':
408-
'tagvalue'}] For more information about tags, see
409-
https://boto3.amazonaws.com/v1/documentation
410-
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
408+
the model (default: None). Example::
409+
410+
tags = [{'Key': 'tagname', 'Value':'tagvalue'}]
411+
412+
For more information about tags, see
413+
`boto3 documentation <https://boto3.amazonaws.com/v1/documentation/\
414+
api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags>`_
415+
416+
Returns:
417+
None or pipeline step arguments in case the Model instance is built with
418+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
411419
"""
412420
# TODO: we should replace _create_sagemaker_model() with create()
413421
self._create_sagemaker_model(

src/sagemaker/mxnet/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def register(
146146
self,
147147
content_types,
148148
response_types,
149-
inference_instances,
150-
transform_instances,
149+
inference_instances=None,
150+
transform_instances=None,
151151
model_package_name=None,
152152
model_package_group_name=None,
153153
image_uri=None,
@@ -166,9 +166,9 @@ def register(
166166
content_types (list): The supported MIME types for the input data.
167167
response_types (list): The supported MIME types for the output data.
168168
inference_instances (list): A list of the instance types that are used to
169-
generate inferences in real-time.
169+
generate inferences in real-time (default: None).
170170
transform_instances (list): A list of the instance types on which a transformation
171-
job can be run or on which an endpoint can be deployed.
171+
job can be run or on which an endpoint can be deployed (default: None).
172172
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
173173
using `model_package_name` makes the Model Package un-versioned (default: None).
174174
model_package_group_name (str): Model Package Group name, exclusive to
@@ -192,7 +192,7 @@ def register(
192192
Returns:
193193
A `sagemaker.model.ModelPackage` instance.
194194
"""
195-
instance_type = inference_instances[0]
195+
instance_type = inference_instances[0] if inference_instances else None
196196
self._init_sagemaker_session_if_does_not_exist(instance_type)
197197

198198
if image_uri:

src/sagemaker/pipeline.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(
8484
self.enable_network_isolation = enable_network_isolation
8585
self.endpoint_name = None
8686

87-
def pipeline_container_def(self, instance_type):
87+
def pipeline_container_def(self, instance_type=None):
8888
"""The pipeline definition for deploying this model.
8989
9090
This is the dict created by ``sagemaker.pipeline_container_def()``.
@@ -266,8 +266,8 @@ def register(
266266
self,
267267
content_types: list,
268268
response_types: list,
269-
inference_instances: list,
270-
transform_instances: list,
269+
inference_instances: Optional[list] = None,
270+
transform_instances: Optional[list] = None,
271271
model_package_name: Optional[str] = None,
272272
model_package_group_name: Optional[str] = None,
273273
image_uri: Optional[str] = None,
@@ -286,9 +286,9 @@ def register(
286286
content_types (list): The supported MIME types for the input data.
287287
response_types (list): The supported MIME types for the output data.
288288
inference_instances (list): A list of the instance types that are used to
289-
generate inferences in real-time.
289+
generate inferences in real-time (default: None).
290290
transform_instances (list): A list of the instance types on which a transformation
291-
job can be run or on which an endpoint can be deployed.
291+
job can be run or on which an endpoint can be deployed (default: None).
292292
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
293293
using `model_package_name` makes the Model Package un-versioned (default: None).
294294
model_package_group_name (str): Model Package Group name, exclusive to
@@ -316,18 +316,23 @@ def register(
316316
if model.model_data is None:
317317
raise ValueError("SageMaker Model Package cannot be created without model data.")
318318
if model_package_group_name is not None:
319-
container_def = self.pipeline_container_def(inference_instances[0])
319+
container_def = self.pipeline_container_def(
320+
inference_instances[0] if inference_instances else None
321+
)
320322
else:
321323
container_def = [
322-
{"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data}
324+
{
325+
"Image": image_uri or model.image_uri,
326+
"ModelDataUrl": model.model_data,
327+
}
323328
for model in self.models
324329
]
325330

326331
model_pkg_args = sagemaker.get_model_package_args(
327332
content_types,
328333
response_types,
329-
inference_instances,
330-
transform_instances,
334+
inference_instances=inference_instances,
335+
transform_instances=transform_instances,
331336
model_package_name=model_package_name,
332337
model_package_group_name=model_package_group_name,
333338
model_metrics=model_metrics,

src/sagemaker/pytorch/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def register(
147147
self,
148148
content_types,
149149
response_types,
150-
inference_instances,
151-
transform_instances,
150+
inference_instances=None,
151+
transform_instances=None,
152152
model_package_name=None,
153153
model_package_group_name=None,
154154
image_uri=None,
@@ -167,9 +167,9 @@ def register(
167167
content_types (list): The supported MIME types for the input data.
168168
response_types (list): The supported MIME types for the output data.
169169
inference_instances (list): A list of the instance types that are used to
170-
generate inferences in real-time.
170+
generate inferences in real-time (default: None).
171171
transform_instances (list): A list of the instance types on which a transformation
172-
job can be run or on which an endpoint can be deployed.
172+
job can be run or on which an endpoint can be deployed (default: None).
173173
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
174174
using `model_package_name` makes the Model Package un-versioned (default: None).
175175
model_package_group_name (str): Model Package Group name, exclusive to
@@ -193,7 +193,7 @@ def register(
193193
Returns:
194194
A `sagemaker.model.ModelPackage` instance.
195195
"""
196-
instance_type = inference_instances[0]
196+
instance_type = inference_instances[0] if inference_instances else None
197197
self._init_sagemaker_session_if_does_not_exist(instance_type)
198198

199199
if image_uri:

src/sagemaker/session.py

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -412,29 +412,47 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
412412
bucket = s3.Bucket(name=bucket_name)
413413
if bucket.creation_date is None:
414414
try:
415-
if region == "us-east-1":
416-
# 'us-east-1' cannot be specified because it is the default region:
417-
# https://github.com/boto/boto3/issues/125
418-
s3.create_bucket(Bucket=bucket_name)
419-
else:
420-
s3.create_bucket(
421-
Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}
422-
)
423-
424-
LOGGER.info("Created S3 bucket: %s", bucket_name)
415+
# trying head bucket call
416+
s3.meta.client.head_bucket(Bucket=bucket.name)
425417
except ClientError as e:
418+
# bucket does not exist or forbidden to access
426419
error_code = e.response["Error"]["Code"]
427420
message = e.response["Error"]["Message"]
428421

429-
if error_code == "BucketAlreadyOwnedByYou":
430-
pass
431-
elif (
432-
error_code == "OperationAborted"
433-
and "conflicting conditional operation" in message
434-
):
435-
# If this bucket is already being concurrently created, we don't need to create
436-
# it again.
437-
pass
422+
if error_code == "404" and message == "Not Found":
423+
# bucket does not exist, create one
424+
try:
425+
if region == "us-east-1":
426+
# 'us-east-1' cannot be specified because it is the default region:
427+
# https://github.com/boto/boto3/issues/125
428+
s3.create_bucket(Bucket=bucket_name)
429+
else:
430+
s3.create_bucket(
431+
Bucket=bucket_name,
432+
CreateBucketConfiguration={"LocationConstraint": region},
433+
)
434+
435+
LOGGER.info("Created S3 bucket: %s", bucket_name)
436+
except ClientError as e:
437+
error_code = e.response["Error"]["Code"]
438+
message = e.response["Error"]["Message"]
439+
440+
if (
441+
error_code == "OperationAborted"
442+
and "conflicting conditional operation" in message
443+
):
444+
# If this bucket is already being concurrently created,
445+
# we don't need to create it again.
446+
pass
447+
else:
448+
raise
449+
elif error_code == "403" and message == "Forbidden":
450+
LOGGER.error(
451+
"Bucket %s exists, but access is forbidden. Please try again after "
452+
"adding appropriate access.",
453+
bucket.name,
454+
)
455+
raise
438456
else:
439457
raise
440458

@@ -4206,8 +4224,8 @@ def _intercept_create_request(
42064224
def get_model_package_args(
42074225
content_types,
42084226
response_types,
4209-
inference_instances,
4210-
transform_instances,
4227+
inference_instances=None,
4228+
transform_instances=None,
42114229
model_package_name=None,
42124230
model_package_group_name=None,
42134231
model_data=None,
@@ -4230,9 +4248,9 @@ def get_model_package_args(
42304248
content_types (list): The supported MIME types for the input data.
42314249
response_types (list): The supported MIME types for the output data.
42324250
inference_instances (list): A list of the instance types that are used to
4233-
generate inferences in real-time.
4251+
generate inferences in real-time (default: None).
42344252
transform_instances (list): A list of the instance types on which a transformation
4235-
job can be run or on which an endpoint can be deployed.
4253+
job can be run or on which an endpoint can be deployed (default: None).
42364254
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
42374255
using `model_package_name` makes the Model Package un-versioned (default: None).
42384256
model_package_group_name (str): Model Package Group name, exclusive to
@@ -4377,10 +4395,9 @@ def get_create_model_package_request(
43774395
if domain is not None:
43784396
request_dict["Domain"] = domain
43794397
if containers is not None:
4380-
if not all([content_types, response_types, inference_instances, transform_instances]):
4398+
if not all([content_types, response_types]):
43814399
raise ValueError(
4382-
"content_types, response_types, inference_inferences and transform_instances "
4383-
"must be provided if containers is present."
4400+
"content_types and response_types " "must be provided if containers is present."
43844401
)
43854402
inference_specification = {
43864403
"Containers": containers,

0 commit comments

Comments
 (0)