Skip to content

Commit 56cd49c

Browse files
Merge branch 'master' into fix-spark-processor
2 parents a75c631 + b2d4744 commit 56cd49c

File tree

8 files changed

+109
-55
lines changed

8 files changed

+109
-55
lines changed

doc/frameworks/xgboost/using_xgboost.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ For information about the SageMaker Python SDK XGBoost classes, see the followin
465465
* :class:`sagemaker.xgboost.estimator.XGBoost`
466466
* :class:`sagemaker.xgboost.model.XGBoostModel`
467467
* :class:`sagemaker.xgboost.model.XGBoostPredictor`
468+
* :class:`sagemaker.xgboost.processing.XGBoostProcessor`
468469

469470
***********************************
470471
SageMaker XGBoost Docker Containers

doc/frameworks/xgboost/xgboost.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,8 @@ The Amazon SageMaker XGBoost open source framework algorithm.
1616
:members:
1717
:undoc-members:
1818
:show-inheritance:
19+
20+
.. autoclass:: sagemaker.xgboost.processing.XGBoostProcessor
21+
:members:
22+
:undoc-members:
23+
:show-inheritance:

src/sagemaker/model.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,9 @@ def register(
341341
"MACHINE_LEARNING" (default: None).
342342
343343
Returns:
344-
A `sagemaker.model.ModelPackage` instance.
344+
A `sagemaker.model.ModelPackage` instance or pipeline step arguments
345+
in case the Model instance is built with
346+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
345347
"""
346348
if self.model_data is None:
347349
raise ValueError("SageMaker Model Package cannot be created without model data.")
@@ -398,15 +400,22 @@ def create(
398400
attach to an endpoint for model loading and inference, for
399401
example, 'ml.eia1.medium'. If not specified, no Elastic
400402
Inference accelerator will be attached to the endpoint (default: None).
401-
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
403+
serverless_inference_config (ServerlessInferenceConfig):
402404
Specifies configuration related to serverless endpoint. Instance type is
403405
not provided in serverless inference. So this is used to find image URIs
404406
(default: None).
405407
tags (List[Dict[str, str]]): The list of tags to add to
406-
the model (default: None). Example: >>> tags = [{'Key': 'tagname', 'Value':
407-
'tagvalue'}] For more information about tags, see
408-
https://boto3.amazonaws.com/v1/documentation
409-
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
408+
the model (default: None). Example::
409+
410+
tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
411+
412+
For more information about tags, see
413+
`boto3 documentation <https://boto3.amazonaws.com/v1/documentation/\
414+
api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags>`_
415+
416+
Returns:
417+
None or pipeline step arguments in case the Model instance is built with
418+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
410419
"""
411420
# TODO: we should replace _create_sagemaker_model() with create()
412421
self._create_sagemaker_model(

src/sagemaker/session.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -412,29 +412,47 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
412412
bucket = s3.Bucket(name=bucket_name)
413413
if bucket.creation_date is None:
414414
try:
415-
if region == "us-east-1":
416-
# 'us-east-1' cannot be specified because it is the default region:
417-
# https://github.com/boto/boto3/issues/125
418-
s3.create_bucket(Bucket=bucket_name)
419-
else:
420-
s3.create_bucket(
421-
Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}
422-
)
423-
424-
LOGGER.info("Created S3 bucket: %s", bucket_name)
415+
# check whether the bucket already exists with a HeadBucket call
416+
s3.meta.client.head_bucket(Bucket=bucket.name)
425417
except ClientError as e:
418+
# bucket does not exist, or access to it is forbidden
426419
error_code = e.response["Error"]["Code"]
427420
message = e.response["Error"]["Message"]
428421

429-
if error_code == "BucketAlreadyOwnedByYou":
430-
pass
431-
elif (
432-
error_code == "OperationAborted"
433-
and "conflicting conditional operation" in message
434-
):
435-
# If this bucket is already being concurrently created, we don't need to create
436-
# it again.
437-
pass
422+
if error_code == "404" and message == "Not Found":
423+
# bucket does not exist, create one
424+
try:
425+
if region == "us-east-1":
426+
# 'us-east-1' cannot be specified because it is the default region:
427+
# https://github.com/boto/boto3/issues/125
428+
s3.create_bucket(Bucket=bucket_name)
429+
else:
430+
s3.create_bucket(
431+
Bucket=bucket_name,
432+
CreateBucketConfiguration={"LocationConstraint": region},
433+
)
434+
435+
LOGGER.info("Created S3 bucket: %s", bucket_name)
436+
except ClientError as e:
437+
error_code = e.response["Error"]["Code"]
438+
message = e.response["Error"]["Message"]
439+
440+
if (
441+
error_code == "OperationAborted"
442+
and "conflicting conditional operation" in message
443+
):
444+
# If this bucket is already being concurrently created,
445+
# we don't need to create it again.
446+
pass
447+
else:
448+
raise
449+
elif error_code == "403" and message == "Forbidden":
450+
LOGGER.error(
451+
"Bucket %s exists, but access is forbidden. Please try again after "
452+
"adding appropriate access.",
453+
bucket.name,
454+
)
455+
raise
438456
else:
439457
raise
440458

src/sagemaker/workflow/README.rst

Lines changed: 0 additions & 13 deletions
This file was deleted.

src/sagemaker/workflow/model_step.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,17 @@ def __init__(
5050
Args:
5151
name (str): The name of the `ModelStep`. A name is required and must be
5252
unique within a pipeline.
53-
step_args (_ModelStepArguments): The arguments for the `ModelStep` definition.
54-
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
55-
names or `Step` instances or `StepCollection` instances that the first step,
56-
in this `ModelStep` collection, depends on.
53+
step_args (_ModelStepArguments): The arguments for the `ModelStep` definition,
54+
generated by invoking the :func:`~sagemaker.model.Model.register` or
55+
:func:`~sagemaker.model.Model.create`
56+
under the :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. Example::
57+
58+
model = Model(sagemaker_session=PipelineSession())
59+
model_step = ModelStep(step_args=model.register())
60+
61+
depends_on (List[Union[str, Step, StepCollection]]):
62+
A list of `Step` or `StepCollection` names, or `Step` or `StepCollection`
63+
instances, that this `ModelStep` depends on.
5764
If a listed `Step` name does not exist, an error is returned (default: None).
5865
retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The list of retry
5966
policies for the `ModelStep` (default: None).

src/sagemaker/workflow/step_collections.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import attr
2020

21-
from sagemaker.deprecations import deprecated_class
2221
from sagemaker.estimator import EstimatorBase
2322
from sagemaker.model import Model
2423
from sagemaker import PipelineModel
@@ -262,18 +261,14 @@ def __init__(
262261
warnings.warn(
263262
(
264263
"We are deprecating the use of RegisterModel. "
265-
"Instead, please use the ModelStep, which simply takes in the step arguments "
266-
"generated by model.register(). For more, see: "
264+
"Please use the ModelStep instead. For more, see: "
267265
"https://sagemaker.readthedocs.io/en/stable/"
268266
"amazon_sagemaker_model_building_pipeline.html#model-step"
269267
),
270268
DeprecationWarning,
271269
)
272270

273271

274-
RegisterModel = deprecated_class(RegisterModel, "RegisterModel")
275-
276-
277272
class EstimatorTransformer(StepCollection):
278273
"""Creates a Transformer step collection for workflow."""
279274

tests/unit/test_default_bucket.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import pytest
1616
from botocore.exceptions import ClientError
17-
from mock import MagicMock
17+
from mock import MagicMock, patch
1818
import sagemaker
1919

2020
ACCOUNT_ID = "123"
@@ -32,6 +32,11 @@ def sagemaker_session():
3232

3333

3434
def test_default_bucket_s3_create_call(sagemaker_session):
35+
error = ClientError(
36+
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
37+
operation_name="foo",
38+
)
39+
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error
3540
bucket_name = sagemaker_session.default_bucket()
3641

3742
create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
@@ -45,6 +50,25 @@ def test_default_bucket_s3_create_call(sagemaker_session):
4550
assert sagemaker_session._default_bucket == bucket_name
4651

4752

53+
def test_default_bucket_s3_needs_access(sagemaker_session):
54+
with patch("logging.Logger.error") as mocked_error_log:
55+
with pytest.raises(ClientError):
56+
error = ClientError(
57+
error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
58+
operation_name="foo",
59+
)
60+
sagemaker_session.boto_session.resource(
61+
"s3"
62+
).meta.client.head_bucket.side_effect = error
63+
sagemaker_session.default_bucket()
64+
mocked_error_log.assert_called_once_with(
65+
"Bucket %s exists, but access is forbidden. Please try again after "
66+
"adding appropriate access.",
67+
DEFAULT_BUCKET_NAME,
68+
)
69+
assert sagemaker_session._default_bucket is None
70+
71+
4872
def test_default_already_cached(sagemaker_session):
4973
existing_default = "mydefaultbucket"
5074
sagemaker_session._default_bucket = existing_default
@@ -57,11 +81,9 @@ def test_default_already_cached(sagemaker_session):
5781

5882

5983
def test_default_bucket_exists(sagemaker_session):
60-
error = ClientError(
61-
error_response={"Error": {"Code": "BucketAlreadyOwnedByYou", "Message": "message"}},
62-
operation_name="foo",
63-
)
64-
sagemaker_session.boto_session.resource().create_bucket.side_effect = error
84+
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.return_value = {
85+
"ResponseMetadata": {"RequestId": "xxx", "HTTPStatusCode": 200, "RetryAttempts": 0}
86+
}
6587

6688
bucket_name = sagemaker_session.default_bucket()
6789
assert bucket_name == DEFAULT_BUCKET_NAME
@@ -70,7 +92,7 @@ def test_default_bucket_exists(sagemaker_session):
7092
def test_concurrent_bucket_modification(sagemaker_session):
7193
message = "A conflicting conditional operation is currently in progress against this resource. Please try again"
7294
error = ClientError(
73-
error_response={"Error": {"Code": "BucketAlreadyOwnedByYou", "Message": message}},
95+
error_response={"Error": {"Code": "OperationAborted", "Message": message}},
7496
operation_name="foo",
7597
)
7698
sagemaker_session.boto_session.resource().create_bucket.side_effect = error
@@ -80,6 +102,11 @@ def test_concurrent_bucket_modification(sagemaker_session):
80102

81103

82104
def test_bucket_creation_client_error(sagemaker_session):
105+
error = ClientError(
106+
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
107+
operation_name="foo",
108+
)
109+
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error
83110
with pytest.raises(ClientError):
84111
error = ClientError(
85112
error_response={"Error": {"Code": "SomethingWrong", "Message": "message"}},
@@ -92,6 +119,11 @@ def test_bucket_creation_client_error(sagemaker_session):
92119

93120

94121
def test_bucket_creation_other_error(sagemaker_session):
122+
error = ClientError(
123+
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
124+
operation_name="foo",
125+
)
126+
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error
95127
with pytest.raises(RuntimeError):
96128
error = RuntimeError()
97129
sagemaker_session.boto_session.resource().create_bucket.side_effect = error

0 commit comments

Comments
 (0)