Skip to content

Commit c9d89da

Browse files
chuyang-dengChuyang Deng
and
Chuyang Deng
authored
feature: support TFS 2.2 (#1705)
* feature: support TFS 2.2 * update transformer test * add 2.2.0 to fw_utils unit test * resolve black-format error Co-authored-by: Chuyang Deng <[email protected]>
1 parent 003a149 commit c9d89da

File tree

8 files changed

+40
-41
lines changed

8 files changed

+40
-41
lines changed

src/sagemaker/tensorflow/defaults.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
LATEST_VERSION = "2.2.0"
2222
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
2323

24-
LATEST_SERVING_VERSION = "2.1.0"
24+
LATEST_SERVING_VERSION = "2.2.0"
2525
"""The latest version of TensorFlow Serving included in the SageMaker pre-built Docker images."""
2626

2727
LATEST_PY2_VERSION = "2.1.0"

tests/conftest.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from sagemaker.rl import RLEstimator
2929
from sagemaker.sklearn.defaults import SKLEARN_VERSION
3030
from sagemaker.tensorflow import TensorFlow
31-
from sagemaker.tensorflow.defaults import LATEST_VERSION, LATEST_SERVING_VERSION
31+
from sagemaker.tensorflow.defaults import LATEST_VERSION
3232

3333
DEFAULT_REGION = "us-west-2"
3434
CUSTOM_BUCKET_NAME_PREFIX = "sagemaker-custom-bucket"
@@ -336,10 +336,3 @@ def pytest_generate_tests(metafunc):
336336
@pytest.fixture(scope="module")
337337
def xgboost_full_version(request):
338338
return request.config.getoption("--xgboost-full-version")
339-
340-
341-
@pytest.fixture(scope="module")
342-
def tf_serving_version(tf_full_version):
343-
if tf_full_version == LATEST_VERSION:
344-
return LATEST_SERVING_VERSION
345-
return tf_full_version

tests/integ/test_data_capture_config.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242

4343
def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
44-
sagemaker_session, tf_serving_version
44+
sagemaker_session, tf_full_version
4545
):
4646
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
4747
model_data = sagemaker_session.upload_data(
@@ -52,7 +52,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
5252
model = Model(
5353
model_data=model_data,
5454
role=ROLE,
55-
framework_version=tf_serving_version,
55+
framework_version=tf_full_version,
5656
sagemaker_session=sagemaker_session,
5757
)
5858
predictor = model.deploy(
@@ -98,7 +98,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
9898

9999

100100
def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
101-
sagemaker_session, tf_serving_version
101+
sagemaker_session, tf_full_version
102102
):
103103
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
104104
model_data = sagemaker_session.upload_data(
@@ -109,7 +109,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
109109
model = Model(
110110
model_data=model_data,
111111
role=ROLE,
112-
framework_version=tf_serving_version,
112+
framework_version=tf_full_version,
113113
sagemaker_session=sagemaker_session,
114114
)
115115
destination_s3_uri = os.path.join(
@@ -184,7 +184,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
184184

185185

186186
def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
187-
sagemaker_session, tf_serving_version
187+
sagemaker_session, tf_full_version
188188
):
189189
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
190190
model_data = sagemaker_session.upload_data(
@@ -195,7 +195,7 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
195195
model = Model(
196196
model_data=model_data,
197197
role=ROLE,
198-
framework_version=tf_serving_version,
198+
framework_version=tf_full_version,
199199
sagemaker_session=sagemaker_session,
200200
)
201201
destination_s3_uri = os.path.join(

tests/integ/test_model_monitor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888

8989

9090
@pytest.fixture(scope="module")
91-
def predictor(sagemaker_session, tf_serving_version):
91+
def predictor(sagemaker_session, tf_full_version):
9292
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
9393
model_data = sagemaker_session.upload_data(
9494
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
@@ -100,7 +100,7 @@ def predictor(sagemaker_session, tf_serving_version):
100100
model = Model(
101101
model_data=model_data,
102102
role=ROLE,
103-
framework_version=tf_serving_version,
103+
framework_version=tf_full_version,
104104
sagemaker_session=sagemaker_session,
105105
)
106106
predictor = model.deploy(

tests/integ/test_tf_script_mode.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pytest
2020

2121
from sagemaker.tensorflow import TensorFlow
22-
from sagemaker.tensorflow.defaults import LATEST_SERVING_VERSION
22+
from sagemaker.tensorflow.defaults import LATEST_VERSION
2323
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp
2424

2525
import tests.integ
@@ -41,8 +41,8 @@
4141

4242

4343
@pytest.fixture(scope="module")
44-
def py_version(tf_full_version, tf_serving_version):
45-
return "py37" if tf_full_version == tf_serving_version else tests.integ.PYTHON_VERSION
44+
def py_version(tf_full_version):
45+
return "py37" if tf_full_version == LATEST_VERSION else tests.integ.PYTHON_VERSION
4646

4747

4848
def test_mnist_with_checkpoint_config(
@@ -60,7 +60,7 @@ def test_mnist_with_checkpoint_config(
6060
sagemaker_session=sagemaker_session,
6161
script_mode=True,
6262
framework_version=tf_full_version,
63-
py_version="py37",
63+
py_version=py_version,
6464
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
6565
checkpoint_s3_uri=checkpoint_s3_uri,
6666
checkpoint_local_path=checkpoint_local_path,
@@ -90,7 +90,7 @@ def test_mnist_with_checkpoint_config(
9090
assert actual_training_checkpoint_config == expected_training_checkpoint_config
9191

9292

93-
def test_server_side_encryption(sagemaker_session, tf_serving_version, py_version):
93+
def test_server_side_encryption(sagemaker_session, tf_full_version, py_version):
9494
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
9595
output_path = os.path.join(
9696
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
@@ -104,7 +104,7 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version, py_versio
104104
train_instance_type="ml.c5.xlarge",
105105
sagemaker_session=sagemaker_session,
106106
script_mode=True,
107-
framework_version=tf_serving_version,
107+
framework_version=tf_full_version,
108108
py_version=py_version,
109109
code_location=output_path,
110110
output_path=output_path,
@@ -139,7 +139,7 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py
139139
train_instance_count=2,
140140
train_instance_type=instance_type,
141141
sagemaker_session=sagemaker_session,
142-
py_version="py37",
142+
py_version=py_version,
143143
script_mode=True,
144144
framework_version=tf_full_version,
145145
distributions=PARAMETER_SERVER_DISTRIBUTION,
@@ -163,11 +163,11 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
163163
role=ROLE,
164164
train_instance_count=1,
165165
train_instance_type="ml.c5.4xlarge",
166-
py_version=tests.integ.PYTHON_VERSION,
166+
py_version=py_version,
167167
sagemaker_session=sagemaker_session,
168168
script_mode=True,
169169
# testing py-sdk functionality, no need to run against all TF versions
170-
framework_version=LATEST_SERVING_VERSION,
170+
framework_version=tf_full_version,
171171
tags=TAGS,
172172
)
173173
inputs = estimator.sagemaker_session.upload_data(
@@ -199,9 +199,7 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
199199
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
200200

201201

202-
def test_deploy_with_input_handlers(
203-
sagemaker_session, instance_type, tf_serving_version, py_version
204-
):
202+
def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_full_version, py_version):
205203
estimator = TensorFlow(
206204
entry_point="training.py",
207205
source_dir=TFS_RESOURCE_PATH,
@@ -211,7 +209,7 @@ def test_deploy_with_input_handlers(
211209
py_version=py_version,
212210
sagemaker_session=sagemaker_session,
213211
script_mode=True,
214-
framework_version=tf_serving_version,
212+
framework_version=tf_full_version,
215213
tags=TAGS,
216214
)
217215

tests/integ/test_tfs.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
@pytest.fixture(scope="module")
30-
def tfs_predictor(sagemaker_session, tf_serving_version):
30+
def tfs_predictor(sagemaker_session, tf_full_version):
3131
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
3232
model_data = sagemaker_session.upload_data(
3333
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
@@ -37,7 +37,7 @@ def tfs_predictor(sagemaker_session, tf_serving_version):
3737
model = Model(
3838
model_data=model_data,
3939
role="SageMakerRole",
40-
framework_version=tf_serving_version,
40+
framework_version=tf_full_version,
4141
sagemaker_session=sagemaker_session,
4242
)
4343
predictor = model.deploy(1, "ml.c5.xlarge", endpoint_name=endpoint_name)
@@ -54,7 +54,7 @@ def tar_dir(directory, tmpdir):
5454

5555
@pytest.fixture
5656
def tfs_predictor_with_model_and_entry_point_same_tar(
57-
sagemaker_local_session, tf_serving_version, tmpdir
57+
sagemaker_local_session, tf_full_version, tmpdir
5858
):
5959
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
6060

@@ -65,7 +65,7 @@ def tfs_predictor_with_model_and_entry_point_same_tar(
6565
model = Model(
6666
model_data="file://" + model_tar,
6767
role="SageMakerRole",
68-
framework_version=tf_serving_version,
68+
framework_version=tf_full_version,
6969
sagemaker_session=sagemaker_local_session,
7070
)
7171
predictor = model.deploy(1, "local", endpoint_name=endpoint_name)
@@ -78,7 +78,7 @@ def tfs_predictor_with_model_and_entry_point_same_tar(
7878

7979
@pytest.fixture(scope="module")
8080
def tfs_predictor_with_model_and_entry_point_and_dependencies(
81-
sagemaker_local_session, tf_serving_version
81+
sagemaker_local_session, tf_full_version
8282
):
8383
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
8484

@@ -98,7 +98,7 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies(
9898
model_data=model_data,
9999
role="SageMakerRole",
100100
dependencies=dependencies,
101-
framework_version=tf_serving_version,
101+
framework_version=tf_full_version,
102102
sagemaker_session=sagemaker_local_session,
103103
)
104104

tests/integ/test_transformer.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from sagemaker.mxnet import MXNet
2626
from sagemaker.pytorch import PyTorchModel
2727
from sagemaker.tensorflow import TensorFlow
28-
from sagemaker.tensorflow.defaults import LATEST_SERVING_VERSION
28+
from sagemaker.tensorflow.defaults import LATEST_VERSION
2929
from sagemaker.transformer import Transformer
3030
from sagemaker.estimator import Estimator
3131
from sagemaker.utils import unique_name_from_base
@@ -42,6 +42,11 @@
4242
MXNET_MNIST_PATH = os.path.join(DATA_DIR, "mxnet_mnist")
4343

4444

45+
@pytest.fixture(scope="module")
46+
def py_version(tf_full_version):
47+
return "py37" if tf_full_version == LATEST_VERSION else PYTHON_VERSION
48+
49+
4550
@pytest.fixture(scope="module")
4651
def mxnet_estimator(sagemaker_session, mxnet_full_version, cpu_instance_type):
4752
mx = MXNet(
@@ -364,17 +369,19 @@ def test_transform_mxnet_logs(
364369
transformer.wait()
365370

366371

367-
def test_transform_tf_kms_network_isolation(sagemaker_session, cpu_instance_type, tmpdir):
372+
def test_transform_tf_kms_network_isolation(
373+
sagemaker_session, cpu_instance_type, tmpdir, tf_full_version, py_version
374+
):
368375
data_path = os.path.join(DATA_DIR, "tensorflow_mnist")
369376

370377
tf = TensorFlow(
371378
entry_point=os.path.join(data_path, "mnist.py"),
372379
role="SageMakerRole",
373380
train_instance_count=1,
374381
train_instance_type=cpu_instance_type,
375-
framework_version=LATEST_SERVING_VERSION,
382+
framework_version=tf_full_version,
376383
script_mode=True,
377-
py_version=PYTHON_VERSION,
384+
py_version=py_version,
378385
sagemaker_session=sagemaker_session,
379386
)
380387

tests/unit/test_fw_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def is_mxnet_1_4_py2(framework, framework_version, py_version):
109109

110110

111111
@pytest.fixture(
112-
scope="module", params=["1.11", "1.11.0", "1.12", "1.12.0", "1.14", "1.14.0", "1.15", "1.15.0"]
112+
scope="module",
113+
params=["1.11", "1.11.0", "1.12", "1.12.0", "1.14", "1.14.0", "1.15", "1.15.0", "2.0", "2.2.0"],
113114
)
114115
def tf_version(request):
115116
return request.param

0 commit comments

Comments
 (0)