Skip to content

feature: support TensorFlow training 2.2 #1521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 33 commits into from
May 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
1c9dd16
update with aws/master
chuyang-deng May 15, 2020
aa349fd
update with aws master
chuyang-deng May 18, 2020
c875c47
update with aws master
chuyang-deng May 19, 2020
3e23a39
feature: TensorFlow 2.2 support
May 19, 2020
e9a96ec
Merge branch 'master' into tf-2-2
chuyang-deng May 19, 2020
1d66433
prevent TFS test pulling 2.2.0 images
May 20, 2020
dc8e4f7
Merge branch 'tf-2-2' of github.com:ChuyangDeng/sagemaker-python-sdk …
May 20, 2020
98a0a5e
fix flake8 error
May 20, 2020
7f44c65
update tfs test versions
May 21, 2020
b33deb9
use tfs 2.1.0 in model monitor tests
May 21, 2020
8da6509
distinguish tf and tfs latest versions in test
May 21, 2020
4f85786
add latest serving version
May 21, 2020
cb2a489
let py_version pick up tf_training_version
May 21, 2020
c9dcdd0
Merge branch 'master' into tf-2-2
chuyang-deng May 21, 2020
ed177e8
fix typo
May 21, 2020
816030e
Merge branch 'tf-2-2' of github.com:ChuyangDeng/sagemaker-python-sdk …
May 21, 2020
0b77be7
no py37 for tfs 2.1
May 21, 2020
89bb6d6
fix black error
May 21, 2020
40da473
prevent tfs pulling 2.2 image
May 20, 2020
c251fcc
tf 2.2 using py37
May 21, 2020
e010da9
Merge branch 'tf-2-2' of github.com:ChuyangDeng/sagemaker-python-sdk …
May 21, 2020
edeb501
Merge branch 'master' into tf-2-2
chuyang-deng May 21, 2020
3255b04
test_data_capture_config use TFS latest version
May 21, 2020
403298e
Merge branch 'tf-2-2' of github.com:ChuyangDeng/sagemaker-python-sdk …
May 21, 2020
2a7f154
fix flake8 error
May 21, 2020
3b10a76
update model monitoring test
May 21, 2020
9c2bb6c
update python version for tuner integ test
May 21, 2020
fe14848
address comments
May 22, 2020
fa84753
Merge branch 'ChuyangDeng-tf-2-2'
May 22, 2020
ff376d8
fix import error
May 22, 2020
598f833
resolve conflicts
May 22, 2020
31110b4
hardcode py_version for tensorflow-training 2.2 tests
May 22, 2020
fa04c1a
import LATEST_SERVING_VERSION from TensorFlow.defaults
May 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/using_tf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ For general information about using the SageMaker Python SDK, see :ref:`overview

.. contents::

Supported versions of TensorFlow for Elastic Inference: ``1.11``, ``1.12``, ``1.13``, ``1.14``.
Supported versions of TensorFlow for Elastic Inference: ``1.11``, ``1.12``, ``1.13``, ``1.14``, ``1.15``, ``2.0``.


*****************************
Expand Down
5 changes: 4 additions & 1 deletion src/sagemaker/tensorflow/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
This is no longer updated so as to not break existing workflows.
"""

LATEST_VERSION = "2.1.0"
LATEST_VERSION = "2.2.0"
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""

LATEST_SERVING_VERSION = "2.1.0"
"""The latest version of TensorFlow Serving included in the SageMaker pre-built Docker images."""

LATEST_PY2_VERSION = "2.1.0"
12 changes: 10 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from sagemaker.pytorch import PyTorch
from sagemaker.rl import RLEstimator
from sagemaker.sklearn.defaults import SKLEARN_VERSION
from sagemaker.tensorflow.estimator import TensorFlow
from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow.defaults import LATEST_VERSION, LATEST_SERVING_VERSION

DEFAULT_REGION = "us-west-2"
CUSTOM_BUCKET_NAME_PREFIX = "sagemaker-custom-bucket"
Expand Down Expand Up @@ -259,7 +260,7 @@ def sklearn_full_version(request):
return request.config.getoption("--sklearn-full-version")


@pytest.fixture(scope="module", params=[TensorFlow._LATEST_1X_VERSION, TensorFlow.LATEST_VERSION])
@pytest.fixture(scope="module", params=[TensorFlow._LATEST_1X_VERSION, LATEST_VERSION])
def tf_full_version(request):
tf_version = request.config.getoption("--tf-full-version")
if tf_version is None:
Expand Down Expand Up @@ -335,3 +336,10 @@ def pytest_generate_tests(metafunc):
@pytest.fixture(scope="module")
def xgboost_full_version(request):
    """Provide the value of the ``--xgboost-full-version`` command-line option."""
    option_name = "--xgboost-full-version"
    return request.config.getoption(option_name)


@pytest.fixture(scope="module")
def tf_serving_version(tf_full_version):
    """Pick the TensorFlow Serving version that goes with ``tf_full_version``.

    The newest training version is mapped to ``LATEST_SERVING_VERSION``
    (the two constants can differ, e.g. when a serving image for the newest
    training release is not available); every other version is returned
    unchanged.
    """
    if tf_full_version != LATEST_VERSION:
        return tf_full_version
    return LATEST_SERVING_VERSION
12 changes: 6 additions & 6 deletions tests/integ/test_data_capture_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@


def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
sagemaker_session, tf_full_version
sagemaker_session, tf_serving_version
):
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
model_data = sagemaker_session.upload_data(
Expand All @@ -52,7 +52,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
model = Model(
model_data=model_data,
role=ROLE,
framework_version=tf_full_version,
framework_version=tf_serving_version,
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(


def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
sagemaker_session, tf_full_version
sagemaker_session, tf_serving_version
):
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
model_data = sagemaker_session.upload_data(
Expand All @@ -109,7 +109,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
model = Model(
model_data=model_data,
role=ROLE,
framework_version=tf_full_version,
framework_version=tf_serving_version,
sagemaker_session=sagemaker_session,
)
destination_s3_uri = os.path.join(
Expand Down Expand Up @@ -184,7 +184,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(


def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
sagemaker_session, tf_full_version
sagemaker_session, tf_serving_version
):
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
model_data = sagemaker_session.upload_data(
Expand All @@ -195,7 +195,7 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
model = Model(
model_data=model_data,
role=ROLE,
framework_version=tf_full_version,
framework_version=tf_serving_version,
sagemaker_session=sagemaker_session,
)
destination_s3_uri = os.path.join(
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_model_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@


@pytest.fixture(scope="module")
def predictor(sagemaker_session, tf_full_version):
def predictor(sagemaker_session, tf_serving_version):
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
model_data = sagemaker_session.upload_data(
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
Expand All @@ -100,7 +100,7 @@ def predictor(sagemaker_session, tf_full_version):
model = Model(
model_data=model_data,
role=ROLE,
framework_version=tf_full_version,
framework_version=tf_serving_version,
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(
Expand Down
23 changes: 12 additions & 11 deletions tests/integ/test_tf_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest

from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow.defaults import LATEST_SERVING_VERSION
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp

import tests.integ
Expand All @@ -40,10 +41,8 @@


@pytest.fixture(scope="module")
def py_version(tf_full_version):
return (
"py37" if tf_full_version == TensorFlow._LATEST_1X_VERSION else tests.integ.PYTHON_VERSION
)
def py_version(tf_full_version, tf_serving_version):
return "py37" if tf_full_version == tf_serving_version else tests.integ.PYTHON_VERSION


def test_mnist_with_checkpoint_config(
Expand All @@ -61,7 +60,7 @@ def test_mnist_with_checkpoint_config(
sagemaker_session=sagemaker_session,
script_mode=True,
framework_version=tf_full_version,
py_version=py_version,
py_version="py37",
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
checkpoint_s3_uri=checkpoint_s3_uri,
checkpoint_local_path=checkpoint_local_path,
Expand Down Expand Up @@ -91,7 +90,7 @@ def test_mnist_with_checkpoint_config(
assert actual_training_checkpoint_config == expected_training_checkpoint_config


def test_server_side_encryption(sagemaker_session, tf_full_version, py_version):
def test_server_side_encryption(sagemaker_session, tf_serving_version, py_version):
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
output_path = os.path.join(
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
Expand All @@ -105,7 +104,7 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, py_version):
train_instance_type="ml.c5.xlarge",
sagemaker_session=sagemaker_session,
script_mode=True,
framework_version=tf_full_version,
framework_version=tf_serving_version,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why switch to using a serving version for training?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because this test is doing a deploy later and TFS does not support 2.2 yet.

py_version=py_version,
code_location=output_path,
output_path=output_path,
Expand Down Expand Up @@ -140,7 +139,7 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py
train_instance_count=2,
train_instance_type=instance_type,
sagemaker_session=sagemaker_session,
py_version=py_version,
py_version="py37",
script_mode=True,
framework_version=tf_full_version,
distributions=PARAMETER_SERVER_DISTRIBUTION,
Expand Down Expand Up @@ -168,7 +167,7 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
sagemaker_session=sagemaker_session,
script_mode=True,
# testing py-sdk functionality, no need to run against all TF versions
framework_version=TensorFlow.LATEST_VERSION,
framework_version=LATEST_SERVING_VERSION,
tags=TAGS,
)
inputs = estimator.sagemaker_session.upload_data(
Expand Down Expand Up @@ -200,7 +199,9 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)


def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_full_version, py_version):
def test_deploy_with_input_handlers(
sagemaker_session, instance_type, tf_serving_version, py_version
):
estimator = TensorFlow(
entry_point="training.py",
source_dir=TFS_RESOURCE_PATH,
Expand All @@ -210,7 +211,7 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_full_ve
py_version=py_version,
sagemaker_session=sagemaker_session,
script_mode=True,
framework_version=tf_full_version,
framework_version=tf_serving_version,
tags=TAGS,
)

Expand Down
12 changes: 6 additions & 6 deletions tests/integ/test_tfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@pytest.fixture(scope="module")
def tfs_predictor(sagemaker_session, tf_full_version):
def tfs_predictor(sagemaker_session, tf_serving_version):
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
model_data = sagemaker_session.upload_data(
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
Expand All @@ -37,7 +37,7 @@ def tfs_predictor(sagemaker_session, tf_full_version):
model = Model(
model_data=model_data,
role="SageMakerRole",
framework_version=tf_full_version,
framework_version=tf_serving_version,
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(1, "ml.c5.xlarge", endpoint_name=endpoint_name)
Expand All @@ -54,7 +54,7 @@ def tar_dir(directory, tmpdir):

@pytest.fixture
def tfs_predictor_with_model_and_entry_point_same_tar(
sagemaker_local_session, tf_full_version, tmpdir
sagemaker_local_session, tf_serving_version, tmpdir
):
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")

Expand All @@ -65,7 +65,7 @@ def tfs_predictor_with_model_and_entry_point_same_tar(
model = Model(
model_data="file://" + model_tar,
role="SageMakerRole",
framework_version=tf_full_version,
framework_version=tf_serving_version,
sagemaker_session=sagemaker_local_session,
)
predictor = model.deploy(1, "local", endpoint_name=endpoint_name)
Expand All @@ -78,7 +78,7 @@ def tfs_predictor_with_model_and_entry_point_same_tar(

@pytest.fixture(scope="module")
def tfs_predictor_with_model_and_entry_point_and_dependencies(
sagemaker_local_session, tf_full_version
sagemaker_local_session, tf_serving_version
):
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")

Expand All @@ -98,7 +98,7 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies(
model_data=model_data,
role="SageMakerRole",
dependencies=dependencies,
framework_version=tf_full_version,
framework_version=tf_serving_version,
sagemaker_session=sagemaker_local_session,
)

Expand Down
3 changes: 2 additions & 1 deletion tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sagemaker.mxnet import MXNet
from sagemaker.pytorch import PyTorchModel
from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow.defaults import LATEST_SERVING_VERSION
from sagemaker.transformer import Transformer
from sagemaker.estimator import Estimator
from sagemaker.utils import unique_name_from_base
Expand Down Expand Up @@ -351,7 +352,7 @@ def test_transform_tf_kms_network_isolation(sagemaker_session, cpu_instance_type
role="SageMakerRole",
train_instance_count=1,
train_instance_type=cpu_instance_type,
framework_version=TensorFlow.LATEST_VERSION,
framework_version=LATEST_SERVING_VERSION,
script_mode=True,
py_version=PYTHON_VERSION,
sagemaker_session=sagemaker_session,
Expand Down
12 changes: 10 additions & 2 deletions tests/integ/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from sagemaker.predictor import json_deserializer
from sagemaker.pytorch import PyTorch
from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow.defaults import LATEST_VERSION
from sagemaker.tuner import (
IntegerParameter,
ContinuousParameter,
Expand All @@ -51,6 +52,13 @@

DATA_PATH = os.path.join(DATA_DIR, "iris", "data")

PY37_SUPPORTED_FRAMEWORK_VERSION = [TensorFlow._LATEST_1X_VERSION, LATEST_VERSION]


@pytest.fixture(scope="module")
def py_version(tf_full_version):
    """Choose the Python version string for the TensorFlow version under test.

    Returns ``"py37"`` when ``tf_full_version`` is listed in
    ``PY37_SUPPORTED_FRAMEWORK_VERSION``; otherwise falls back to the
    module-wide ``PYTHON_VERSION``.
    """
    if tf_full_version in PY37_SUPPORTED_FRAMEWORK_VERSION:
        return "py37"
    return PYTHON_VERSION


@pytest.fixture(scope="module")
def kmeans_train_set(sagemaker_session):
Expand Down Expand Up @@ -590,7 +598,7 @@ def test_tuning_mxnet(sagemaker_session, mxnet_full_version, cpu_instance_type):


@pytest.mark.canary_quick
def test_tuning_tf_script_mode(sagemaker_session, cpu_instance_type, tf_full_version):
def test_tuning_tf_script_mode(sagemaker_session, cpu_instance_type, tf_full_version, py_version):
resource_path = os.path.join(DATA_DIR, "tensorflow_mnist")
script_path = os.path.join(resource_path, "mnist.py")

Expand All @@ -601,7 +609,7 @@ def test_tuning_tf_script_mode(sagemaker_session, cpu_instance_type, tf_full_ver
train_instance_type=cpu_instance_type,
script_mode=True,
sagemaker_session=sagemaker_session,
py_version=PYTHON_VERSION,
py_version=py_version,
framework_version=tf_full_version,
)

Expand Down