diff --git a/tests/integ/test_tf_script_mode.py b/tests/integ/test_tf_script_mode.py index 204f47ebde..85bdbdbed3 100644 --- a/tests/integ/test_tf_script_mode.py +++ b/tests/integ/test_tf_script_mode.py @@ -146,7 +146,6 @@ def test_mnist_async(sagemaker_session): training_job_name = estimator.latest_training_job.name time.sleep(20) endpoint_name = training_job_name - model_name = "model-name-1" _assert_training_job_tags_match( sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS ) @@ -154,6 +153,7 @@ def test_mnist_async(sagemaker_session): estimator = TensorFlow.attach( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) + model_name = "model-mnist-async" predictor = estimator.deploy( initial_instance_count=1, instance_type="ml.c4.xlarge", @@ -215,14 +215,14 @@ def _assert_s3_files_exist(s3_url, files): raise ValueError("File {} is not found under {}".format(f, s3_url)) -def _assert_tags_match(sagemaker_client, resource_arn, tags, retries=1): +def _assert_tags_match(sagemaker_client, resource_arn, tags, retries=15): actual_tags = None for _ in range(retries): actual_tags = sagemaker_client.list_tags(ResourceArn=resource_arn)["Tags"] if actual_tags: break else: - # endpoint tags might take minutes to propagate. Sleeping. + # endpoint and training tags might take minutes to propagate. Sleeping. time.sleep(30) assert actual_tags == tags @@ -235,7 +235,7 @@ def _assert_model_tags_match(sagemaker_client, model_name, tags): def _assert_endpoint_tags_match(sagemaker_client, endpoint_name, tags): endpoint_description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name) - _assert_tags_match(sagemaker_client, endpoint_description["EndpointArn"], tags, retries=10) + _assert_tags_match(sagemaker_client, endpoint_description["EndpointArn"], tags) def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):