diff --git a/tests/integ/sagemaker/lineage/test_association.py b/tests/integ/sagemaker/lineage/test_association.py index 09a84f0434..369a79bdd2 100644 --- a/tests/integ/sagemaker/lineage/test_association.py +++ b/tests/integ/sagemaker/lineage/test_association.py @@ -14,18 +14,16 @@ from __future__ import absolute_import import datetime +import time -import pytest from sagemaker.lineage import association -@pytest.mark.skip(reason="Not in CMH yet") def test_create_delete(association_obj): # fixture does create and then delete, this test ensures it happens at least once assert association_obj.source_arn -@pytest.mark.skip(reason="Not in CMH yet") def test_list(association_objs, sagemaker_session): slack = datetime.timedelta(minutes=1) now = datetime.datetime.now(datetime.timezone.utc) @@ -35,13 +33,16 @@ def test_list(association_objs, sagemaker_session): for sort_order in ["Ascending", "Descending"]: association_keys_listed = [] + source_arn = [assoc_obj.source_arn for assoc_obj in association_objs][0] listed = association.Association.list( + source_arn=source_arn, created_after=now - slack, created_before=now + slack, sort_by="CreationTime", sort_order=sort_order, sagemaker_session=sagemaker_session, ) + for assoc in listed: key = assoc.source_arn + ":" + assoc.destination_arn if key in association_keys: @@ -49,13 +50,36 @@ def test_list(association_objs, sagemaker_session): if sort_order == "Descending": association_names_listed = association_keys_listed[::-1] - assert association_keys == association_names_listed + assert association_keys[::-1] == association_names_listed # sanity check assert association_keys_listed -@pytest.mark.skip(reason="Not in CMH yet") def test_set_tag(association_obj, sagemaker_session): tag = {"Key": "foo", "Value": "bar"} association_obj.set_tag(tag) - assert association_obj.get_tag() == tag + + while True: + actual_tags = sagemaker_session.sagemaker_client.list_tags( + ResourceArn=association_obj.source_arn + )["Tags"] + if actual_tags: + break + time.sleep(1) + assert len(actual_tags) == 1 + assert actual_tags[0] == tag + + +def test_tags(association_obj, sagemaker_session): + tags = [{"Key": "foo1", "Value": "bar1"}] + association_obj.set_tags(tags) + + while True: + actual_tags = sagemaker_session.sagemaker_client.list_tags( + ResourceArn=association_obj.source_arn + )["Tags"] + if actual_tags: + break + time.sleep(1) + assert len(actual_tags) == 1 + assert actual_tags == tags diff --git a/tests/integ/sagemaker/lineage/test_dataset_artifact.py b/tests/integ/sagemaker/lineage/test_dataset_artifact.py index 33b4d43502..db5b4fb097 100644 --- a/tests/integ/sagemaker/lineage/test_dataset_artifact.py +++ b/tests/integ/sagemaker/lineage/test_dataset_artifact.py @@ -15,7 +15,10 @@ def test_trained_models( - sagemaker_session, dataset_artifact_associated_models, trial_component_obj, model_artifact_obj1 + sagemaker_session, + dataset_artifact_associated_models, + trial_component_obj, + model_artifact_obj1, ): model_list = dataset_artifact_associated_models.trained_models() diff --git a/tests/integ/sagemaker/lineage/test_endpoint_context.py b/tests/integ/sagemaker/lineage/test_endpoint_context.py index 6eaa1e949b..e1c154b035 100644 --- a/tests/integ/sagemaker/lineage/test_endpoint_context.py +++ b/tests/integ/sagemaker/lineage/test_endpoint_context.py @@ -15,7 +15,10 @@ def test_model( - endpoint_context_associate_with_model, model_obj, endpoint_action_obj, sagemaker_session + endpoint_context_associate_with_model, + model_obj, + endpoint_action_obj, + sagemaker_session, ): model_list = endpoint_context_associate_with_model.models() for model in model_list: