Skip to content

Commit e9f800c

Browse files
committed
change: waiting for training tags to propagate in the test
1 parent 9b9bc5b commit e9f800c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/integ/test_tf_script_mode.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,14 @@ def test_mnist_async(sagemaker_session):
146146
training_job_name = estimator.latest_training_job.name
147147
time.sleep(20)
148148
endpoint_name = training_job_name
149-
model_name = "model-name-1"
150149
_assert_training_job_tags_match(
151150
sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS
152151
)
153152
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
154153
estimator = TensorFlow.attach(
155154
training_job_name=training_job_name, sagemaker_session=sagemaker_session
156155
)
156+
model_name = "model-mnist-async"
157157
predictor = estimator.deploy(
158158
initial_instance_count=1,
159159
instance_type="ml.c4.xlarge",
@@ -215,14 +215,14 @@ def _assert_s3_files_exist(s3_url, files):
215215
raise ValueError("File {} is not found under {}".format(f, s3_url))
216216

217217

218-
def _assert_tags_match(sagemaker_client, resource_arn, tags, retries=1):
218+
def _assert_tags_match(sagemaker_client, resource_arn, tags, retries=15):
219219
actual_tags = None
220220
for _ in range(retries):
221221
actual_tags = sagemaker_client.list_tags(ResourceArn=resource_arn)["Tags"]
222222
if actual_tags:
223223
break
224224
else:
225-
# endpoint tags might take minutes to propagate. Sleeping.
225+
# endpoint and training tags might take minutes to propagate. Sleeping.
226226
time.sleep(30)
227227
assert actual_tags == tags
228228

@@ -235,7 +235,7 @@ def _assert_model_tags_match(sagemaker_client, model_name, tags):
235235
def _assert_endpoint_tags_match(sagemaker_client, endpoint_name, tags):
236236
endpoint_description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
237237

238-
_assert_tags_match(sagemaker_client, endpoint_description["EndpointArn"], tags, retries=10)
238+
_assert_tags_match(sagemaker_client, endpoint_description["EndpointArn"], tags)
239239

240240

241241
def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):

0 commit comments

Comments
 (0)