Skip to content

Commit 1944f4e

Browse files
committed
fix: updates based on PR feedback, fix fixture
1 parent 3db0ab2 commit 1944f4e

File tree

5 files changed

+24
-38
lines changed

5 files changed

+24
-38
lines changed

tests/conftest.py

+5-19
Original file line numberDiff line numberDiff line change
@@ -285,21 +285,6 @@ def sklearn_full_version(request):
285285
return request.config.getoption("--sklearn-full-version")
286286

287287

288-
@pytest.fixture(scope="module")
289-
def tf_latest_version(request, tf_full_version):
290-
return request.param
291-
292-
293-
@pytest.fixture(scope="module")
294-
def tf_latest_py_version():
295-
return "py37"
296-
297-
298-
@pytest.fixture(scope="module")
299-
def tf_latest_serving_version():
300-
return "2.1.0"
301-
302-
303288
@pytest.fixture(scope="module")
304289
def tf_full_version(request):
305290
return request.config.getoption("--tf-full-version")
@@ -323,10 +308,11 @@ def tf_full_py_version(tf_full_version):
323308

324309

325310
@pytest.fixture(scope="module")
326-
def tf_serving_version(tf_full_version, tf_latest_version, tf_latest_serving_version):
327-
if tf_full_version == tf_latest_version:
328-
return tf_latest_serving_version
329-
return tf_full_version
311+
def tf_serving_version(tf_full_version):
312+
full_version = [int(val) for val in tf_full_version.split(".")]
313+
if full_version < [2, 2]:
314+
return tf_full_version
315+
return "2.1.0"
330316

331317

332318
@pytest.fixture(scope="module", params=["1.15.0", "2.0.0"])

tests/integ/test_airflow_config.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def test_sklearn_airflow_config_uploads_data_source_to_s3(
550550

551551
@pytest.mark.canary_quick
552552
def test_tf_airflow_config_uploads_data_source_to_s3(
553-
sagemaker_session, cpu_instance_type, tf_latest_version, tf_latest_py_version
553+
sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version
554554
):
555555
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
556556
tf = TensorFlow(
@@ -562,8 +562,8 @@ def test_tf_airflow_config_uploads_data_source_to_s3(
562562
train_instance_count=SINGLE_INSTANCE_COUNT,
563563
train_instance_type=cpu_instance_type,
564564
sagemaker_session=sagemaker_session,
565-
framework_version=tf_latest_version,
566-
py_version=tf_latest_py_version,
565+
framework_version=tf_full_version,
566+
py_version=tf_full_py_version,
567567
metric_definitions=[
568568
{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}
569569
],

tests/integ/test_tf.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, tf
147147
)
148148

149149

150-
def test_mnist_async(sagemaker_session, cpu_instance_type, tf_latest_serving_version):
150+
def test_mnist_async(sagemaker_session, cpu_instance_type, tf_serving_version):
151151
estimator = TensorFlow(
152152
entry_point=SCRIPT,
153153
role=ROLE,
@@ -156,7 +156,7 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_latest_serving_ver
156156
py_version=PYTHON_VERSION,
157157
sagemaker_session=sagemaker_session,
158158
# testing py-sdk functionality, no need to run against all TF versions
159-
framework_version=tf_latest_serving_version,
159+
framework_version=tf_serving_version,
160160
tags=TAGS,
161161
)
162162
inputs = estimator.sagemaker_session.upload_data(

tests/integ/test_tf_efs_fsx.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def efs_fsx_setup(sagemaker_session, ec2_instance_type):
5555
reason="EFS integration tests need to be fixed before running in all regions.",
5656
)
5757
def test_mnist_efs(
58-
efs_fsx_setup, sagemaker_session, cpu_instance_type, tf_latest_version, tf_latest_py_version
58+
efs_fsx_setup, sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version
5959
):
6060
role = efs_fsx_setup["role_name"]
6161
subnets = [efs_fsx_setup["subnet_id"]]
@@ -67,8 +67,8 @@ def test_mnist_efs(
6767
train_instance_count=1,
6868
train_instance_type=cpu_instance_type,
6969
sagemaker_session=sagemaker_session,
70-
framework_version=tf_latest_version,
71-
py_version=tf_latest_py_version,
70+
framework_version=tf_full_version,
71+
py_version=tf_full_py_version,
7272
subnets=subnets,
7373
security_group_ids=security_group_ids,
7474
)
@@ -96,7 +96,7 @@ def test_mnist_efs(
9696
reason="EFS integration tests need to be fixed before running in all regions.",
9797
)
9898
def test_mnist_lustre(
99-
efs_fsx_setup, sagemaker_session, cpu_instance_type, tf_latest_version, tf_latest_py_version
99+
efs_fsx_setup, sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version
100100
):
101101
role = efs_fsx_setup["role_name"]
102102
subnets = [efs_fsx_setup["subnet_id"]]
@@ -108,8 +108,8 @@ def test_mnist_lustre(
108108
train_instance_count=1,
109109
train_instance_type=cpu_instance_type,
110110
sagemaker_session=sagemaker_session,
111-
framework_version=tf_latest_version,
112-
py_version=tf_latest_py_version,
111+
framework_version=tf_full_version,
112+
py_version=tf_full_py_version,
113113
subnets=subnets,
114114
security_group_ids=security_group_ids,
115115
)
@@ -133,7 +133,7 @@ def test_mnist_lustre(
133133
reason="EFS integration tests need to be fixed before running in all regions.",
134134
)
135135
def test_tuning_tf_efs(
136-
efs_fsx_setup, sagemaker_session, cpu_instance_type, tf_latest_version, tf_latest_py_version
136+
efs_fsx_setup, sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version
137137
):
138138
role = efs_fsx_setup["role_name"]
139139
subnets = [efs_fsx_setup["subnet_id"]]
@@ -145,8 +145,8 @@ def test_tuning_tf_efs(
145145
train_instance_count=1,
146146
train_instance_type=cpu_instance_type,
147147
sagemaker_session=sagemaker_session,
148-
framework_version=tf_latest_version,
149-
py_version=tf_latest_py_version,
148+
framework_version=tf_full_version,
149+
py_version=tf_full_py_version,
150150
subnets=subnets,
151151
security_group_ids=security_group_ids,
152152
)
@@ -182,7 +182,7 @@ def test_tuning_tf_efs(
182182
reason="EFS integration tests need to be fixed before running in all regions.",
183183
)
184184
def test_tuning_tf_lustre(
185-
efs_fsx_setup, sagemaker_session, cpu_instance_type, tf_latest_version, tf_latest_py_version
185+
efs_fsx_setup, sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version
186186
):
187187
role = efs_fsx_setup["role_name"]
188188
subnets = [efs_fsx_setup["subnet_id"]]
@@ -194,8 +194,8 @@ def test_tuning_tf_lustre(
194194
train_instance_count=1,
195195
train_instance_type=cpu_instance_type,
196196
sagemaker_session=sagemaker_session,
197-
framework_version=tf_latest_version,
198-
py_version=tf_latest_py_version,
197+
framework_version=tf_full_version,
198+
py_version=tf_full_py_version,
199199
subnets=subnets,
200200
security_group_ids=security_group_ids,
201201
)

tests/integ/test_transformer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def test_transform_mxnet_logs(
345345

346346

347347
def test_transform_tf_kms_network_isolation(
348-
sagemaker_session, cpu_instance_type, tmpdir, tf_latest_serving_version
348+
sagemaker_session, cpu_instance_type, tmpdir, tf_serving_version
349349
):
350350
data_path = os.path.join(DATA_DIR, "tensorflow_mnist")
351351

@@ -354,7 +354,7 @@ def test_transform_tf_kms_network_isolation(
354354
role="SageMakerRole",
355355
train_instance_count=1,
356356
train_instance_type=cpu_instance_type,
357-
framework_version=tf_latest_serving_version,
357+
framework_version=tf_serving_version,
358358
py_version=PYTHON_VERSION,
359359
sagemaker_session=sagemaker_session,
360360
)

0 commit comments

Comments
 (0)