Skip to content

Commit 1595bf0

Browse files
committed
Fix List Associations integ tests
1 parent c6effe5 commit 1595bf0

File tree

4 files changed

+40
-10
lines changed

4 files changed

+40
-10
lines changed

tests/integ/sagemaker/lineage/test_artifact.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def test_downstream_trials(trial_associated_artifact, trial_obj, sagemaker_sessi
8181
# allow trial components to index, 30 seconds max
8282
for i in range(3):
8383
time.sleep(10)
84-
trials = trial_associated_artifact.downstream_trials(sagemaker_session=sagemaker_session)
84+
trials = trial_associated_artifact.downstream_trials(
85+
sagemaker_session=sagemaker_session
86+
)
8587
if len(trials) > 0:
8688
break
8789

tests/integ/sagemaker/lineage/test_association.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,47 +15,69 @@
1515

1616
import datetime
1717

18-
import pytest
1918
from sagemaker.lineage import association
2019

2120

22-
@pytest.mark.skip(reason="Not in CMH yet")
2321
def test_create_delete(association_obj):
2422
# fixture does create and then delete, this test ensures it happens at least once
2523
assert association_obj.source_arn
2624

2725

28-
@pytest.mark.skip(reason="Not in CMH yet")
2926
def test_list(association_objs, sagemaker_session):
3027
slack = datetime.timedelta(minutes=1)
3128
now = datetime.datetime.now(datetime.timezone.utc)
3229
association_keys = [
33-
assoc_obj.source_arn + ":" + assoc_obj.destination_arn for assoc_obj in association_objs
30+
assoc_obj.source_arn + ":" + assoc_obj.destination_arn
31+
for assoc_obj in association_objs
3432
]
3533

3634
for sort_order in ["Ascending", "Descending"]:
3735
association_keys_listed = []
36+
source_arn = [assoc_obj.source_arn for assoc_obj in association_objs][0]
3837
listed = association.Association.list(
38+
source_arn=source_arn,
3939
created_after=now - slack,
4040
created_before=now + slack,
4141
sort_by="CreationTime",
4242
sort_order=sort_order,
4343
sagemaker_session=sagemaker_session,
4444
)
45+
4546
for assoc in listed:
4647
key = assoc.source_arn + ":" + assoc.destination_arn
4748
if key in association_keys:
4849
association_keys_listed.append(key)
4950

5051
if sort_order == "Descending":
5152
association_names_listed = association_keys_listed[::-1]
52-
assert association_keys == association_names_listed
53+
assert association_keys[::-1] == association_names_listed
5354
# sanity check
5455
assert association_keys_listed
5556

5657

57-
@pytest.mark.skip(reason="Not in CMH yet")
5858
def test_set_tag(association_obj, sagemaker_session):
5959
tag = {"Key": "foo", "Value": "bar"}
6060
association_obj.set_tag(tag)
61-
assert association_obj.get_tag() == tag
61+
62+
while True:
63+
actual_tags = sagemaker_session.sagemaker_client.list_tags(
64+
ResourceArn=association_obj.source_arn
65+
)["Tags"]
66+
if actual_tags:
67+
break
68+
assert len(actual_tags) == 1
69+
assert actual_tags[0] == tag
70+
71+
72+
def test_tags(association_obj, sagemaker_session):
73+
tags = [{"Key": "foo1", "Value": "bar1"}]
74+
association_obj.set_tags(tags)
75+
76+
while True:
77+
actual_tags = sagemaker_session.sagemaker_client.list_tags(
78+
ResourceArn=association_obj.source_arn
79+
)["Tags"]
80+
if actual_tags:
81+
break
82+
assert len(actual_tags) == 1
83+
assert actual_tags == tags

tests/integ/sagemaker/lineage/test_dataset_artifact.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515

1616

1717
def test_trained_models(
18-
sagemaker_session, dataset_artifact_associated_models, trial_component_obj, model_artifact_obj1
18+
sagemaker_session,
19+
dataset_artifact_associated_models,
20+
trial_component_obj,
21+
model_artifact_obj1,
1922
):
2023

2124
model_list = dataset_artifact_associated_models.trained_models()

tests/integ/sagemaker/lineage/test_endpoint_context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515

1616

1717
def test_model(
18-
endpoint_context_associate_with_model, model_obj, endpoint_action_obj, sagemaker_session
18+
endpoint_context_associate_with_model,
19+
model_obj,
20+
endpoint_action_obj,
21+
sagemaker_session,
1922
):
2023
model_list = endpoint_context_associate_with_model.models()
2124
for model in model_list:

0 commit comments

Comments
 (0)