File tree 1 file changed +20
-2
lines changed
1 file changed +20
-2
lines changed Original file line number Diff line number Diff line change 18
18
19
19
import pytest
20
20
21
+ from packaging .version import Version
22
+
21
23
from sagemaker .tensorflow import TensorFlow , TensorFlowProcessor
22
24
from sagemaker .utils import unique_name_from_base , sagemaker_timestamp
23
25
@@ -206,15 +208,31 @@ def test_mnist_distributed(
206
208
207
209
208
210
@pytest .mark .slow_test
209
- def test_mnist_async (sagemaker_session , cpu_instance_type , tf_full_version , tf_full_py_version ):
211
+ def test_mnist_async (
212
+ sagemaker_session ,
213
+ cpu_instance_type ,
214
+ tf_full_version ,
215
+ tensorflow_training_latest_version ,
216
+ tf_full_py_version
217
+ ):
218
+
219
+ # Use the latest patch version for training, if available
220
+ tf_full_v = Version (tf_full_version )
221
+ tf_training_latest_v = Version (tensorflow_training_latest_version )
222
+
223
+ if (tf_full_v .major , tf_full_v .minor ) == (tf_training_latest_v .major , tf_training_latest_v .minor ):
224
+ tf_fw_version = tensorflow_training_latest_version
225
+ else :
226
+ tf_fw_version = tf_full_version
227
+
210
228
estimator = TensorFlow (
211
229
entry_point = SCRIPT ,
212
230
source_dir = MNIST_RESOURCE_PATH ,
213
231
role = ROLE ,
214
232
instance_count = 1 ,
215
233
instance_type = "ml.c5.4xlarge" ,
216
234
sagemaker_session = sagemaker_session ,
217
- framework_version = tf_full_version ,
235
+ framework_version = tf_fw_version ,
218
236
py_version = tf_full_py_version ,
219
237
tags = TAGS ,
220
238
)
You can’t perform that action at this time.
0 commit comments