Skip to content

fix: Fix lineage query integ tests #2823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions tests/integ/sagemaker/lineage/test_dataset_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
# language governing permissions and limitations under the License.
"""This module contains code to test SageMaker ``DatasetArtifact``"""
from __future__ import absolute_import
from tests.integ.sagemaker.lineage.helpers import traverse_graph_forward


def test_trained_models(
sagemaker_session,
dataset_artifact_associated_models,
trial_component_obj,
model_artifact_obj1,
Expand All @@ -31,20 +29,9 @@ def test_trained_models(

def test_endpoint_contexts(
static_dataset_artifact,
sagemaker_session,
):
contexts_from_query = static_dataset_artifact.endpoint_contexts()

associations_from_api = traverse_graph_forward(
static_dataset_artifact.artifact_arn, sagemaker_session=sagemaker_session
)

assert len(contexts_from_query) > 0
for context in contexts_from_query:
# assert that the contexts from the query
# appear in the association list from the lineage API
assert any(
x
for x in associations_from_api
if x["DestinationArn"] == context.context_arn and x["DestinationType"] == "Endpoint"
)
assert context.context_type == "Endpoint"
25 changes: 3 additions & 22 deletions tests/integ/sagemaker/lineage/test_endpoint_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,9 @@
# language governing permissions and limitations under the License.
"""This module contains code to test SageMaker ``Contexts``"""
from __future__ import absolute_import
from tests.integ.sagemaker.lineage.helpers import traverse_graph_back


def test_model(
endpoint_context_associate_with_model,
model_obj,
endpoint_action_obj,
sagemaker_session,
):
def test_model(endpoint_context_associate_with_model, model_obj, endpoint_action_obj):
model_list = endpoint_context_associate_with_model.models()
for model in model_list:
assert model.source_arn == endpoint_action_obj.action_arn
Expand All @@ -29,25 +23,12 @@ def test_model(
assert model.destination_type == "Model"


def test_dataset_artifacts(
static_endpoint_context,
sagemaker_session,
):
def test_dataset_artifacts(static_endpoint_context):
artifacts_from_query = static_endpoint_context.dataset_artifacts()

associations_from_api = traverse_graph_back(
static_endpoint_context.context_arn, sagemaker_session=sagemaker_session
)

assert len(artifacts_from_query) > 0
for artifact in artifacts_from_query:
# assert that the artifacts from the query
# appear in the association list from the lineage API
assert any(
x
for x in associations_from_api
if x["SourceArn"] == artifact.artifact_arn and x["SourceType"] == "DataSet"
)
assert artifact.artifact_type == "DataSet"


def test_training_job_arns(
Expand Down
28 changes: 2 additions & 26 deletions tests/integ/sagemaker/lineage/test_model_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
# language governing permissions and limitations under the License.
"""This module contains code to test SageMaker ``DatasetArtifact``"""
from __future__ import absolute_import
from tests.integ.sagemaker.lineage.helpers import traverse_graph_forward, traverse_graph_back


def test_endpoints(
sagemaker_session,
model_artifact_associated_endpoints,
endpoint_deployment_action_obj,
endpoint_context_obj,
Expand All @@ -32,44 +30,22 @@ def test_endpoints(

def test_endpoint_contexts(
static_model_artifact,
sagemaker_session,
):
contexts_from_query = static_model_artifact.endpoint_contexts()

associations_from_api = traverse_graph_forward(
static_model_artifact.artifact_arn, sagemaker_session=sagemaker_session
)

assert len(contexts_from_query) > 0
for context in contexts_from_query:
# assert that the contexts from the query
# appear in the association list from the lineage API
assert any(
x
for x in associations_from_api
if x["DestinationArn"] == context.context_arn and x["DestinationType"] == "Endpoint"
)
assert context.context_type == "Endpoint"


def test_dataset_artifacts(
static_model_artifact,
sagemaker_session,
):
artifacts_from_query = static_model_artifact.dataset_artifacts()

associations_from_api = traverse_graph_back(
static_model_artifact.artifact_arn, sagemaker_session=sagemaker_session
)

assert len(artifacts_from_query) > 0
for artifact in artifacts_from_query:
# assert that the artifacts from the query
# appear in the association list from the lineage API
assert any(
x
for x in associations_from_api
if x["SourceArn"] == artifact.artifact_arn and x["SourceType"] == "DataSet"
)
assert artifact.artifact_type == "DataSet"


def test_training_job_arns(
Expand Down