Skip to content

Commit c24e0b5

Browse files
authored
infra: use fixture for Python version in TF integ tests (#1617)
1 parent acbe02b commit c24e0b5

10 files changed

+155
-107
lines changed

tests/conftest.py

+28-18
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pytest
2020
import tests.integ
2121
from botocore.config import Config
22+
from packaging.version import Version
2223

2324
from sagemaker import Session, utils
2425
from sagemaker.local import LocalSession
@@ -57,7 +58,6 @@ def pytest_addoption(parser):
5758
"--rl-ray-full-version", action="store", default=RLEstimator.RAY_LATEST_VERSION
5859
)
5960
parser.addoption("--sklearn-full-version", action="store", default="0.20.0")
60-
parser.addoption("--tf-full-version", action="store", default="2.2.0")
6161
parser.addoption("--ei-tf-full-version", action="store")
6262
parser.addoption("--xgboost-full-version", action="store", default="1.0-1")
6363

@@ -304,35 +304,45 @@ def sklearn_full_version(request):
304304

305305

306306
@pytest.fixture(scope="module")
307-
def tf_full_version(request):
308-
return request.config.getoption("--tf-full-version")
307+
def tf_training_latest_version():
308+
return "2.2.0"
309+
310+
311+
@pytest.fixture(scope="module")
312+
def tf_training_latest_py_version():
313+
return "py37"
314+
315+
316+
@pytest.fixture(scope="module")
317+
def tf_serving_latest_version():
318+
return "2.1.0"
319+
320+
321+
@pytest.fixture(scope="module")
322+
def tf_full_version(tf_training_latest_version, tf_serving_latest_version):
323+
"""Fixture for TF tests that test both training and inference.
324+
325+
Returns the lower of the two, since TF training and TFS have different latest versions.
326+
Otherwise, this would simply be a single latest version.
327+
"""
328+
return str(min(Version(tf_training_latest_version), Version(tf_serving_latest_version)))
309329

310330

311331
@pytest.fixture(scope="module")
312332
def tf_full_py_version(tf_full_version):
313-
"""fixture to match tf_full_version
333+
"""Fixture to match tf_full_version.
314334
315-
Fixture exists as such, since tf_full_version may be overridden --tf-full-version.
335+
Derived from tf_full_version, since TF training and TFS have different latest versions.
316336
Otherwise, this would simply be py37 to match the latest version support.
317-
318-
TODO: Evaluate use of --tf-full-version with possible eye to remove and simplify code.
319337
"""
320-
version = [int(val) for val in tf_full_version.split(".")]
321-
if version < [1, 11]:
338+
version = Version(tf_full_version)
339+
if version < Version("1.11"):
322340
return "py2"
323-
if version < [2, 2]:
341+
if version < Version("2.2"):
324342
return "py3"
325343
return "py37"
326344

327345

328-
@pytest.fixture(scope="module")
329-
def tf_serving_version(tf_full_version):
330-
full_version = [int(val) for val in tf_full_version.split(".")]
331-
if full_version < [2, 2]:
332-
return tf_full_version
333-
return "2.1.0"
334-
335-
336346
@pytest.fixture(scope="module", params=["1.15.0", "2.0.0"])
337347
def ei_tf_full_version(request):
338348
tf_ei_version = request.config.getoption("--ei-tf-full-version")

tests/integ/test_airflow_config.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def test_sklearn_airflow_config_uploads_data_source_to_s3(
512512

513513
@pytest.mark.canary_quick
514514
def test_tf_airflow_config_uploads_data_source_to_s3(
515-
sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version
515+
sagemaker_session, cpu_instance_type, tf_training_latest_version, tf_training_latest_py_version
516516
):
517517
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
518518
tf = TensorFlow(
@@ -524,8 +524,8 @@ def test_tf_airflow_config_uploads_data_source_to_s3(
524524
train_instance_count=SINGLE_INSTANCE_COUNT,
525525
train_instance_type=cpu_instance_type,
526526
sagemaker_session=sagemaker_session,
527-
framework_version=tf_full_version,
528-
py_version=tf_full_py_version,
527+
framework_version=tf_training_latest_version,
528+
py_version=tf_training_latest_py_version,
529529
metric_definitions=[
530530
{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}
531531
],

tests/integ/test_data_capture_config.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242

4343
def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
44-
sagemaker_session, tf_serving_version
44+
sagemaker_session, tf_serving_latest_version
4545
):
4646
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
4747
model_data = sagemaker_session.upload_data(
@@ -52,7 +52,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
5252
model = TensorFlowModel(
5353
model_data=model_data,
5454
role=ROLE,
55-
framework_version=tf_serving_version,
55+
framework_version=tf_serving_latest_version,
5656
sagemaker_session=sagemaker_session,
5757
)
5858
predictor = model.deploy(
@@ -98,7 +98,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
9898

9999

100100
def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
101-
sagemaker_session, tf_serving_version
101+
sagemaker_session, tf_serving_latest_version
102102
):
103103
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
104104
model_data = sagemaker_session.upload_data(
@@ -109,7 +109,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
109109
model = TensorFlowModel(
110110
model_data=model_data,
111111
role=ROLE,
112-
framework_version=tf_serving_version,
112+
framework_version=tf_serving_latest_version,
113113
sagemaker_session=sagemaker_session,
114114
)
115115
destination_s3_uri = os.path.join(
@@ -184,7 +184,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
184184

185185

186186
def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
187-
sagemaker_session, tf_serving_version
187+
sagemaker_session, tf_serving_latest_version
188188
):
189189
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
190190
model_data = sagemaker_session.upload_data(
@@ -195,7 +195,7 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
195195
model = TensorFlowModel(
196196
model_data=model_data,
197197
role=ROLE,
198-
framework_version=tf_serving_version,
198+
framework_version=tf_serving_latest_version,
199199
sagemaker_session=sagemaker_session,
200200
)
201201
destination_s3_uri = os.path.join(

tests/integ/test_horovod.py

+38-16
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
import json
1616
import os
1717
import tarfile
18-
from six.moves.urllib.parse import urlparse
1918

2019
import boto3
2120
import pytest
21+
from six.moves.urllib.parse import urlparse
2222

2323
import sagemaker.utils
2424
import tests.integ as integ
@@ -28,27 +28,49 @@
2828
horovod_dir = os.path.join(os.path.dirname(__file__), "..", "data", "horovod")
2929

3030

31-
@pytest.fixture(scope="module")
32-
def gpu_instance_type(request):
33-
return "ml.p2.xlarge"
34-
35-
3631
@pytest.mark.canary_quick
37-
def test_hvd_cpu(sagemaker_session, cpu_instance_type, tmpdir):
38-
_create_and_fit_estimator(sagemaker_session, cpu_instance_type, tmpdir)
32+
def test_hvd_cpu(
33+
sagemaker_session,
34+
tf_training_latest_version,
35+
tf_training_latest_py_version,
36+
cpu_instance_type,
37+
tmpdir,
38+
):
39+
_create_and_fit_estimator(
40+
sagemaker_session,
41+
tf_training_latest_version,
42+
tf_training_latest_py_version,
43+
cpu_instance_type,
44+
tmpdir,
45+
)
3946

4047

4148
@pytest.mark.canary_quick
4249
@pytest.mark.skipif(
4350
integ.test_region() in integ.TRAINING_NO_P2_REGIONS, reason="no ml.p2 instances in this region"
4451
)
45-
def test_hvd_gpu(sagemaker_session, gpu_instance_type, tmpdir):
46-
_create_and_fit_estimator(sagemaker_session, gpu_instance_type, tmpdir)
52+
def test_hvd_gpu(
53+
sagemaker_session, tf_training_latest_version, tf_training_latest_py_version, tmpdir
54+
):
55+
_create_and_fit_estimator(
56+
sagemaker_session,
57+
tf_training_latest_version,
58+
tf_training_latest_py_version,
59+
"ml.p2.xlarge",
60+
tmpdir,
61+
)
4762

4863

4964
@pytest.mark.local_mode
5065
@pytest.mark.parametrize("instances, processes", [[1, 2], (2, 1), (2, 2)])
51-
def test_horovod_local_mode(sagemaker_local_session, instances, processes, tmpdir):
66+
def test_horovod_local_mode(
67+
sagemaker_local_session,
68+
tf_training_latest_version,
69+
tf_training_latest_py_version,
70+
instances,
71+
processes,
72+
tmpdir,
73+
):
5274
output_path = "file://%s" % tmpdir
5375
job_name = sagemaker.utils.unique_name_from_base("tf-horovod")
5476
estimator = TensorFlow(
@@ -57,9 +79,9 @@ def test_horovod_local_mode(sagemaker_local_session, instances, processes, tmpdi
5779
train_instance_count=2,
5880
train_instance_type="local",
5981
sagemaker_session=sagemaker_local_session,
60-
py_version=integ.PYTHON_VERSION,
6182
output_path=output_path,
62-
framework_version="1.12",
83+
framework_version=tf_training_latest_version,
84+
py_version=tf_training_latest_py_version,
6385
distributions={"mpi": {"enabled": True, "processes_per_host": processes}},
6486
)
6587

@@ -96,16 +118,16 @@ def extract_files_from_s3(s3_url, tmpdir, sagemaker_session):
96118
tar_file.extractall(tmpdir)
97119

98120

99-
def _create_and_fit_estimator(sagemaker_session, instance_type, tmpdir):
121+
def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instance_type, tmpdir):
100122
job_name = sagemaker.utils.unique_name_from_base("tf-horovod")
101123
estimator = TensorFlow(
102124
entry_point=os.path.join(horovod_dir, "hvd_basic.py"),
103125
role="SageMakerRole",
104126
train_instance_count=2,
105127
train_instance_type=instance_type,
106128
sagemaker_session=sagemaker_session,
107-
py_version=integ.PYTHON_VERSION,
108-
framework_version="1.12",
129+
py_version=py_version,
130+
framework_version=tf_version,
109131
distributions={"mpi": {"enabled": True}},
110132
)
111133

tests/integ/test_model_monitor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888

8989

9090
@pytest.fixture(scope="module")
91-
def predictor(sagemaker_session, tf_serving_version):
91+
def predictor(sagemaker_session, tf_serving_latest_version):
9292
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
9393
model_data = sagemaker_session.upload_data(
9494
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
@@ -100,7 +100,7 @@ def predictor(sagemaker_session, tf_serving_version):
100100
model = TensorFlowModel(
101101
model_data=model_data,
102102
role=ROLE,
103-
framework_version=tf_serving_version,
103+
framework_version=tf_serving_latest_version,
104104
sagemaker_session=sagemaker_session,
105105
)
106106
predictor = model.deploy(

tests/integ/test_tf.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp
2323

2424
import tests.integ
25-
from tests.integ import kms_utils, timeout, PYTHON_VERSION
25+
from tests.integ import kms_utils, timeout
2626
from tests.integ.retry import retries
2727
from tests.integ.s3_utils import assert_s3_files_exist
2828

@@ -39,7 +39,7 @@
3939

4040

4141
def test_mnist_with_checkpoint_config(
42-
sagemaker_session, instance_type, tf_full_version, tf_full_py_version
42+
sagemaker_session, instance_type, tf_training_latest_version, tf_training_latest_py_version
4343
):
4444
checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}".format(
4545
sagemaker_session.default_bucket(), sagemaker_timestamp()
@@ -51,8 +51,8 @@ def test_mnist_with_checkpoint_config(
5151
train_instance_count=1,
5252
train_instance_type=instance_type,
5353
sagemaker_session=sagemaker_session,
54-
framework_version=tf_full_version,
55-
py_version=tf_full_py_version,
54+
framework_version=tf_training_latest_version,
55+
py_version=tf_training_latest_py_version,
5656
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
5757
checkpoint_s3_uri=checkpoint_s3_uri,
5858
checkpoint_local_path=checkpoint_local_path,
@@ -82,7 +82,7 @@ def test_mnist_with_checkpoint_config(
8282
assert actual_training_checkpoint_config == expected_training_checkpoint_config
8383

8484

85-
def test_server_side_encryption(sagemaker_session, tf_serving_version):
85+
def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_version):
8686
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
8787
output_path = os.path.join(
8888
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
@@ -95,8 +95,8 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version):
9595
train_instance_count=1,
9696
train_instance_type="ml.c5.xlarge",
9797
sagemaker_session=sagemaker_session,
98-
framework_version=tf_serving_version,
99-
py_version=PYTHON_VERSION,
98+
framework_version=tf_full_version,
99+
py_version=tf_full_py_version,
100100
code_location=output_path,
101101
output_path=output_path,
102102
model_dir="/opt/ml/model",
@@ -123,15 +123,17 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version):
123123

124124

125125
@pytest.mark.canary_quick
126-
def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, tf_full_py_version):
126+
def test_mnist_distributed(
127+
sagemaker_session, instance_type, tf_training_latest_version, tf_training_latest_py_version
128+
):
127129
estimator = TensorFlow(
128130
entry_point=SCRIPT,
129131
role=ROLE,
130132
train_instance_count=2,
131133
train_instance_type=instance_type,
132134
sagemaker_session=sagemaker_session,
133-
framework_version=tf_full_version,
134-
py_version=tf_full_py_version,
135+
framework_version=tf_training_latest_version,
136+
py_version=tf_training_latest_py_version,
135137
distributions=PARAMETER_SERVER_DISTRIBUTION,
136138
)
137139
inputs = estimator.sagemaker_session.upload_data(
@@ -147,16 +149,15 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, tf
147149
)
148150

149151

150-
def test_mnist_async(sagemaker_session, cpu_instance_type, tf_serving_version):
152+
def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version):
151153
estimator = TensorFlow(
152154
entry_point=SCRIPT,
153155
role=ROLE,
154156
train_instance_count=1,
155157
train_instance_type="ml.c5.4xlarge",
156-
py_version=PYTHON_VERSION,
157158
sagemaker_session=sagemaker_session,
158-
# testing py-sdk functionality, no need to run against all TF versions
159-
framework_version=tf_serving_version,
159+
framework_version=tf_full_version,
160+
py_version=tf_full_py_version,
160161
tags=TAGS,
161162
)
162163
inputs = estimator.sagemaker_session.upload_data(
@@ -188,15 +189,17 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_serving_version):
188189
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
189190

190191

191-
def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_serving_version):
192+
def test_deploy_with_input_handlers(
193+
sagemaker_session, instance_type, tf_full_version, tf_full_py_version
194+
):
192195
estimator = TensorFlow(
193196
entry_point="training.py",
194197
source_dir=TFS_RESOURCE_PATH,
195198
role=ROLE,
196199
train_instance_count=1,
197200
train_instance_type=instance_type,
198-
framework_version=tf_serving_version,
199-
py_version=PYTHON_VERSION,
201+
framework_version=tf_full_version,
202+
py_version=tf_full_py_version,
200203
sagemaker_session=sagemaker_session,
201204
tags=TAGS,
202205
)

0 commit comments

Comments
 (0)