Skip to content

Commit 9d25fcf

Browse files
authored
feature: Add support for SageMaker lineage queries context (#2830)
1 parent b796a00 commit 9d25fcf

File tree

4 files changed

+100
-0
lines changed

4 files changed

+100
-0
lines changed

src/sagemaker/lineage/context.py

+12
Original file line numberDiff line numberDiff line change
@@ -490,3 +490,15 @@ def pipeline_execution_arn(
490490
return tag["Value"]
491491

492492
return None
493+
494+
495+
class ModelPackageGroup(Context):
496+
"""An Amazon SageMaker model package group context, which is part of a SageMaker lineage."""
497+
498+
def pipeline_execution_arn(self) -> str:
499+
"""Get the ARN for the pipeline execution associated with this model package group (if any).
500+
501+
Returns:
502+
str: A pipeline execution ARN.
503+
"""
504+
return self.properties.get("PipelineExecutionArn")

tests/integ/sagemaker/lineage/conftest.py

+32
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,29 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
667667
)
668668

669669

670+
@pytest.fixture
671+
def static_model_package_group_context(sagemaker_session, static_pipeline_execution_arn):
672+
673+
model_package_group_arn = get_model_package_group_arn_from_static_pipeline(sagemaker_session)
674+
675+
contexts = sagemaker_session.sagemaker_client.list_contexts(SourceUri=model_package_group_arn)[
676+
"ContextSummaries"
677+
]
678+
if len(contexts) != 1:
679+
raise (
680+
Exception(
681+
f"Got an unexpected number of Contexts for \
682+
model package group {STATIC_MODEL_PACKAGE_GROUP_NAME} from pipeline \
683+
execution {static_pipeline_execution_arn}. \
684+
Expected 1 but got {len(contexts)}"
685+
)
686+
)
687+
688+
yield context.ModelPackageGroup.load(
689+
contexts[0]["ContextName"], sagemaker_session=sagemaker_session
690+
)
691+
692+
670693
@pytest.fixture
671694
def static_model_artifact(sagemaker_session, static_pipeline_execution_arn):
672695
model_package_arn = get_model_package_arn_from_static_pipeline(
@@ -745,6 +768,15 @@ def get_endpoint_arn_from_static_pipeline(sagemaker_session):
745768
raise e
746769

747770

771+
def get_model_package_group_arn_from_static_pipeline(sagemaker_session):
772+
static_model_package_group_arn = (
773+
sagemaker_session.sagemaker_client.describe_model_package_group(
774+
ModelPackageGroupName=STATIC_MODEL_PACKAGE_GROUP_NAME
775+
)["ModelPackageGroupArn"]
776+
)
777+
return static_model_package_group_arn
778+
779+
748780
def get_model_package_arn_from_static_pipeline(pipeline_execution_arn, sagemaker_session):
749781
# get the model package ARN from the pipeline
750782
pipeline_execution_steps = sagemaker_session.sagemaker_client.list_pipeline_execution_steps(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code to test SageMaker ``ModelPackageGroup``"""
14+
from __future__ import absolute_import
15+
16+
17+
def test_pipeline_execution_arn(static_model_package_group_context, static_pipeline_execution_arn):
18+
pipeline_execution_arn = static_model_package_group_context.pipeline_execution_arn()
19+
20+
assert pipeline_execution_arn == static_pipeline_execution_arn
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code to test SageMaker ``ModelPackageGroup``"""
14+
from __future__ import absolute_import
15+
16+
import unittest.mock
17+
import pytest
18+
from sagemaker.lineage import context
19+
20+
21+
@pytest.fixture
22+
def sagemaker_session():
23+
return unittest.mock.Mock()
24+
25+
26+
def test_pipeline_execution_arn(sagemaker_session):
27+
obj = context.ModelPackageGroup(
28+
sagemaker_session,
29+
context_name="foo",
30+
description="test-description",
31+
properties={"PipelineExecutionArn": "abcd", "k2": "v2"},
32+
properties_to_remove=["E"],
33+
)
34+
actual_result = obj.pipeline_execution_arn()
35+
expected_result = "abcd"
36+
assert expected_result == actual_result

0 commit comments

Comments
 (0)