Skip to content

Commit cec05a5

Browse files
Merge branch 'dev' into inference-url
2 parents d28b89f + 5ec2ff4 commit cec05a5

File tree

4 files changed

+88
-20
lines changed

4 files changed

+88
-20
lines changed

src/sagemaker/model.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -178,21 +178,26 @@ def register(
178178
"""
179179
if self.model_data is None:
180180
raise ValueError("SageMaker Model Package cannot be created without model data.")
181+
if image_uri is not None:
182+
self.image_uri = image_uri
183+
if model_package_group_name is not None:
184+
container_def = self.prepare_container_def()
185+
else:
186+
container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data}
181187

182188
model_pkg_args = sagemaker.get_model_package_args(
183189
content_types,
184190
response_types,
185191
inference_instances,
186192
transform_instances,
187-
model_package_name,
188-
model_package_group_name,
189-
self.model_data,
190-
image_uri or self.image_uri,
191-
model_metrics,
192-
metadata_properties,
193-
marketplace_cert,
194-
approval_status,
195-
description,
193+
model_package_name=model_package_name,
194+
model_package_group_name=model_package_group_name,
195+
model_metrics=model_metrics,
196+
metadata_properties=metadata_properties,
197+
marketplace_cert=marketplace_cert,
198+
approval_status=approval_status,
199+
description=description,
200+
container_def_list=[container_def],
196201
drift_check_baselines=drift_check_baselines,
197202
)
198203
model_package = self.sagemaker_session.create_model_package_from_containers(

tests/integ/sagemaker/lineage/conftest.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
from tests.integ.sagemaker.lineage.helpers import name, names
3737

3838
SLEEP_TIME_SECONDS = 1
39-
STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline14"
40-
STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint14"
39+
STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline15"
40+
STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint15"
4141

4242

4343
@pytest.fixture
@@ -518,6 +518,13 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
518518
def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
519519
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
520520

521+
if endpoint_arn is None:
522+
_deploy_static_endpoint(
523+
execution_arn=static_pipeline_execution_arn,
524+
sagemaker_session=sagemaker_session,
525+
)
526+
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
527+
521528
contexts = sagemaker_session.sagemaker_client.list_contexts(SourceUri=endpoint_arn)[
522529
"ContextSummaries"
523530
]
@@ -584,11 +591,17 @@ def static_dataset_artifact(static_model_artifact, sagemaker_session):
584591

585592

586593
def get_endpoint_arn_from_static_pipeline(sagemaker_session):
587-
endpoint_arn = sagemaker_session.sagemaker_client.describe_endpoint(
588-
EndpointName=STATIC_ENDPOINT_NAME
589-
)["EndpointArn"]
594+
try:
595+
endpoint_arn = sagemaker_session.sagemaker_client.describe_endpoint(
596+
EndpointName=STATIC_ENDPOINT_NAME
597+
)["EndpointArn"]
590598

591-
return endpoint_arn
599+
return endpoint_arn
600+
except ClientError as e:
601+
error = e.response["Error"]
602+
if error["Code"] == "ValidationException":
603+
return None
604+
raise e
592605

593606

594607
def get_model_package_arn_from_static_pipeline(pipeline_execution_arn, sagemaker_session):
@@ -654,7 +667,7 @@ def _deploy_static_endpoint(execution_arn, sagemaker_session):
654667
sagemaker_session=sagemaker_session,
655668
)
656669
model_package.deploy(1, "ml.t2.medium", endpoint_name=STATIC_ENDPOINT_NAME)
657-
time.sleep(60)
670+
time.sleep(120)
658671
except ClientError as e:
659672
if e.response["Error"]["Code"] == "ValidationException":
660673
print(f"Endpoint {STATIC_ENDPOINT_NAME} already exists. Continuing.")

tests/integ/test_auto_ml.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
import os
1616

1717
import pytest
18-
import tests.integ
19-
from sagemaker import AutoML, CandidateEstimator, AutoMLInput
20-
2118
from botocore.exceptions import ClientError
19+
20+
import tests.integ
21+
from sagemaker import AutoML, AutoMLInput, CandidateEstimator
2222
from sagemaker.utils import unique_name_from_base
23-
from tests.integ import DATA_DIR, AUTO_ML_DEFAULT_TIMEMOUT_MINUTES, auto_ml_utils
23+
from tests.integ import AUTO_ML_DEFAULT_TIMEMOUT_MINUTES, DATA_DIR, auto_ml_utils
2424
from tests.integ.timeout import timeout
2525

2626
ROLE = "SageMakerRole"
@@ -169,6 +169,7 @@ def test_auto_ml_describe_auto_ml_job(sagemaker_session):
169169
}
170170
},
171171
"TargetAttributeName": TARGET_ATTRIBUTE_NAME,
172+
"ContentType": "text/csv;header=present",
172173
}
173174
]
174175
expected_default_output_config = {
@@ -205,6 +206,7 @@ def test_auto_ml_attach(sagemaker_session):
205206
}
206207
},
207208
"TargetAttributeName": TARGET_ATTRIBUTE_NAME,
209+
"ContentType": "text/csv;header=present",
208210
}
209211
]
210212
expected_default_output_config = {

tests/integ/test_mxnet.py

+48
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,54 @@ def test_register_model_package(
231231
sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package_name)
232232

233233

234+
def test_register_model_package_versioned(
235+
mxnet_training_job,
236+
sagemaker_session,
237+
mxnet_inference_latest_version,
238+
mxnet_inference_latest_py_version,
239+
cpu_instance_type,
240+
):
241+
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
242+
243+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
244+
desc = sagemaker_session.sagemaker_client.describe_training_job(
245+
TrainingJobName=mxnet_training_job
246+
)
247+
model_package_group_name = "register-model-package-{}".format(sagemaker_timestamp())
248+
sagemaker_session.sagemaker_client.create_model_package_group(
249+
ModelPackageGroupName=model_package_group_name
250+
)
251+
model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
252+
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
253+
model = MXNetModel(
254+
model_data,
255+
"SageMakerRole",
256+
entry_point=script_path,
257+
py_version=mxnet_inference_latest_py_version,
258+
sagemaker_session=sagemaker_session,
259+
framework_version=mxnet_inference_latest_version,
260+
)
261+
model_pkg = model.register(
262+
content_types=["application/json"],
263+
response_types=["application/json"],
264+
inference_instances=["ml.m5.large"],
265+
transform_instances=["ml.m5.large"],
266+
model_package_group_name=model_package_group_name,
267+
approval_status="Approved",
268+
)
269+
assert isinstance(model_pkg, ModelPackage)
270+
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
271+
data = numpy.zeros(shape=(1, 1, 28, 28))
272+
result = predictor.predict(data)
273+
assert result is not None
274+
sagemaker_session.sagemaker_client.delete_model_package(
275+
ModelPackageName=model_pkg.model_package_arn
276+
)
277+
sagemaker_session.sagemaker_client.delete_model_package_group(
278+
ModelPackageGroupName=model_package_group_name
279+
)
280+
281+
234282
def test_deploy_model_with_tags_and_kms(
235283
mxnet_training_job,
236284
sagemaker_session,

0 commit comments

Comments
 (0)