Skip to content

Commit 4570aa6

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
authored
fix: Pop out ModelPackageName from pipeline definition (#3472)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 3f6ea88 commit 4570aa6

File tree

5 files changed

+150
-139
lines changed

5 files changed

+150
-139
lines changed

src/sagemaker/workflow/_utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Scrapper utilities to support repacking of models."""
1414
from __future__ import absolute_import
1515

16+
import logging
1617
import os
1718
import shutil
1819
import tarfile
@@ -37,6 +38,8 @@
3738
if TYPE_CHECKING:
3839
from sagemaker.workflow.step_collections import StepCollection
3940

41+
logger = logging.getLogger(__name__)
42+
4043
FRAMEWORK_VERSION = "0.23-1"
4144
INSTANCE_TYPE = "ml.m5.large"
4245
REPACK_SCRIPT = "_repack_model.py"
@@ -479,10 +482,19 @@ def arguments(self) -> RequestType:
479482

480483
request_dict = get_create_model_package_request(**model_package_args)
481484
# these are not available in the workflow service and will cause rejection
485+
warn_msg_template = (
486+
"Popping out '%s' from the pipeline definition "
487+
"since it will be overridden in pipeline execution time."
488+
)
482489
if "CertifyForMarketplace" in request_dict:
483490
request_dict.pop("CertifyForMarketplace")
491+
logger.warning(warn_msg_template, "CertifyForMarketplace")
484492
if "Description" in request_dict:
485493
request_dict.pop("Description")
494+
logger.warning(warn_msg_template, "Description")
495+
if "ModelPackageName" in request_dict:
496+
request_dict.pop("ModelPackageName")
497+
logger.warning(warn_msg_template, "ModelPackageName")
486498

487499
return request_dict
488500

tests/integ/sagemaker/workflow/test_model_steps.py

+1
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen
112112
inference_instances=["ml.m5.xlarge"],
113113
transform_instances=["ml.m5.xlarge"],
114114
description="test-description",
115+
model_package_name="model-pkg-name-will-be-popped-out",
115116
)
116117
step_model_regis = ModelStep(
117118
name="pytorch-register-model",
+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from unittest.mock import Mock, PropertyMock
16+
17+
import pytest
18+
19+
from sagemaker import Session
20+
from sagemaker.workflow.pipeline_context import PipelineSession
21+
22+
REGION = "us-west-2"
23+
BUCKET = "my-bucket"
24+
ROLE = "DummyRole"
25+
IMAGE_URI = "fakeimage"
26+
27+
28+
@pytest.fixture(scope="module")
29+
def client():
30+
"""Mock client.
31+
32+
Considerations when appropriate:
33+
34+
* utilize botocore.stub.Stubber
35+
* separate runtime client from client
36+
"""
37+
client_mock = Mock()
38+
client_mock._client_config.user_agent = (
39+
"Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
40+
)
41+
return client_mock
42+
43+
44+
@pytest.fixture(scope="module")
45+
def boto_session(client):
46+
role_mock = Mock()
47+
type(role_mock).arn = PropertyMock(return_value=ROLE)
48+
49+
resource_mock = Mock()
50+
resource_mock.Role.return_value = role_mock
51+
52+
session_mock = Mock(region_name=REGION)
53+
session_mock.resource.return_value = resource_mock
54+
session_mock.client.return_value = client
55+
56+
return session_mock
57+
58+
59+
@pytest.fixture(scope="module")
60+
def pipeline_session(boto_session, client):
61+
return PipelineSession(
62+
boto_session=boto_session,
63+
sagemaker_client=client,
64+
default_bucket=BUCKET,
65+
)
66+
67+
68+
@pytest.fixture(scope="module")
69+
def sagemaker_session(boto_session, client):
70+
return Session(
71+
boto_session=boto_session,
72+
sagemaker_client=client,
73+
sagemaker_runtime_client=client,
74+
default_bucket=BUCKET,
75+
)

0 commit comments

Comments
 (0)