Skip to content

Commit 6db5192

Browse files
committed
fix test
1 parent e3f2fb0 commit 6db5192

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

tests/integ/test_tf.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import pytest
2020

21+
from packaging.version import Version
22+
2123
from sagemaker.tensorflow import TensorFlow, TensorFlowProcessor
2224
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp
2325

@@ -206,15 +208,31 @@ def test_mnist_distributed(
206208

207209

208210
@pytest.mark.slow_test
209-
def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version):
211+
def test_mnist_async(
212+
sagemaker_session,
213+
cpu_instance_type,
214+
tf_full_version,
215+
tensorflow_training_latest_version,
216+
tf_full_py_version
217+
):
218+
219+
# Use the latest patch version for training, if available
220+
tf_full_v = Version(tf_full_version)
221+
tf_training_latest_v = Version(tensorflow_training_latest_version)
222+
223+
if (tf_full_v.major, tf_full_v.minor) == (tf_training_latest_v.major, tf_training_latest_v.minor):
224+
tf_fw_version = tensorflow_training_latest_version
225+
else:
226+
tf_fw_version = tf_full_version
227+
210228
estimator = TensorFlow(
211229
entry_point=SCRIPT,
212230
source_dir=MNIST_RESOURCE_PATH,
213231
role=ROLE,
214232
instance_count=1,
215233
instance_type="ml.c5.4xlarge",
216234
sagemaker_session=sagemaker_session,
217-
framework_version=tf_full_version,
235+
framework_version=tf_fw_version,
218236
py_version=tf_full_py_version,
219237
tags=TAGS,
220238
)

0 commit comments

Comments
 (0)