Skip to content

feature: use deep learning images #883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 68 additions & 4 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,56 @@
VALID_EIA_FRAMEWORKS = ["tensorflow", "tensorflow-serving", "mxnet", "mxnet-serving"]
VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"}

# Maps an SDK framework name to the repository name of its merged image.
# The value replaces the "sagemaker-<framework>" repository name when the
# image URI is built for a merged image.
MERGED_FRAMEWORKS_REPO_MAP = {
    "tensorflow-scriptmode": "tensorflow-training",
    "mxnet": "mxnet-training",
    "tensorflow-serving": "tensorflow-inference",
    "mxnet-serving": "mxnet-inference",
}

# Lowest framework version, as an integer list [major, minor, patch], from
# which each framework is considered to use the merged images. Frameworks
# that are not listed here never use the merged images.
MERGED_FRAMEWORKS_LOWEST_VERSIONS = {
    "tensorflow-scriptmode": [1, 13, 1],
    "mxnet": [1, 4, 1],
    "tensorflow-serving": [1, 13, 0],
    "mxnet-serving": [1, 4, 1],
}


def is_version_equal_or_higher(lowest_version, framework_version):
    """Determine whether ``framework_version`` is equal to or higher than ``lowest_version``.

    Only as many components as ``framework_version`` provides are compared, so a
    partial version such as ``"1.13"`` is treated as equal to a ``lowest_version``
    of ``[1, 13, 1]``.

    Args:
        lowest_version (List[int]): lowest version represented in an integer list
        framework_version (str): framework version string

    Returns:
        bool: Whether or not framework_version is equal to or higher than lowest_version.
            ``False`` is returned when ``framework_version`` contains a component
            that is not an integer (for example ``"1.0rc"``).
    """
    try:
        version_list = [int(s) for s in framework_version.split(".")]
    except ValueError:
        # Pre-release or free-form version strings cannot be ordered numerically.
        # Report them as not meeting the minimum instead of raising, so callers
        # keep working for version strings they accepted before this check existed.
        return False
    return version_list >= lowest_version[: len(version_list)]


def _is_merged_versions(framework, framework_version):
    """Return whether ``framework_version`` of ``framework`` reaches its merged-image minimum.

    Frameworks without an entry in ``MERGED_FRAMEWORKS_LOWEST_VERSIONS`` always
    yield ``False``.
    """
    minimum_version = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework)
    if not minimum_version:
        return False
    return is_version_equal_or_higher(minimum_version, framework_version)


def _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
    """Return whether the merged images apply to this combination of settings.

    They apply only when all of the following hold: the region has no entry in
    ``VALID_ACCOUNTS_BY_REGION``, the framework version reaches the merged-image
    minimum, the Python version is ``"py3"`` or unspecified, and no accelerator
    type is requested.
    """
    # Evaluated up front so the version check always runs, whatever the other
    # conditions turn out to be.
    version_is_merged = _is_merged_versions(framework, framework_version)

    if region in VALID_ACCOUNTS_BY_REGION:
        return False
    if not version_is_merged:
        return False
    if not (py_version == "py3" or py_version is None):
        return False
    return accelerator_type is None


def _registry_id(region, framework, py_version, account, accelerator_type, framework_version):
    """Return the registry account ID to use for the image.

    Merged images are served from account ``763104351884``. Otherwise the
    region-specific account from ``VALID_ACCOUNTS_BY_REGION`` is used, falling
    back to the ``account`` passed in.
    """
    use_merged = _using_merged_images(
        region, framework, py_version, accelerator_type, framework_version
    )
    return "763104351884" if use_merged else VALID_ACCOUNTS_BY_REGION.get(region, account)


def create_image_uri(
region,
Expand Down Expand Up @@ -86,8 +136,15 @@ def create_image_uri(
if py_version and py_version not in VALID_PY_VERSIONS:
raise ValueError("invalid py_version argument: {}".format(py_version))

# Handle Account Number for Gov Cloud
account = VALID_ACCOUNTS_BY_REGION.get(region, account)
# Handle Account Number for Gov Cloud and frameworks with DLC merged images
account = _registry_id(
region=region,
framework=framework,
py_version=py_version,
account=account,
accelerator_type=accelerator_type,
framework_version=framework_version,
)

# Handle Local Mode
if instance_type.startswith("local"):
Expand Down Expand Up @@ -121,7 +178,14 @@ def create_image_uri(
):
framework += "-eia"

return "{}/sagemaker-{}:{}".format(get_ecr_image_uri_prefix(account, region), framework, tag)
if _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
return "{}/{}:{}".format(
get_ecr_image_uri_prefix(account, region), MERGED_FRAMEWORKS_REPO_MAP[framework], tag
)
else:
return "{}/sagemaker-{}:{}".format(
get_ecr_image_uri_prefix(account, region), framework, tag
)


def _accelerator_type_valid_for_framework(
Expand Down Expand Up @@ -264,7 +328,7 @@ def framework_name_from_image(image_name):
# extract framework, python version and image tag
# We must support both the legacy and current image name format.
name_pattern = re.compile(
r"^sagemaker(?:-rl)?-(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501
r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501
)
legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")

Expand Down
4 changes: 4 additions & 0 deletions tests/integ/test_tf_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_mnist(sagemaker_session, instance_type):
sagemaker_session=sagemaker_session,
script_mode=True,
framework_version=TensorFlow.LATEST_VERSION,
py_version=tests.integ.PYTHON_VERSION,
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
)
inputs = estimator.sagemaker_session.upload_data(
Expand Down Expand Up @@ -98,6 +99,7 @@ def test_server_side_encryption(sagemaker_session):
sagemaker_session=sagemaker_session,
script_mode=True,
framework_version=TensorFlow.LATEST_VERSION,
py_version=tests.integ.PYTHON_VERSION,
code_location=output_path,
output_path=output_path,
model_dir="/opt/ml/model",
Expand Down Expand Up @@ -144,6 +146,7 @@ def test_mnist_async(sagemaker_session):
role=ROLE,
train_instance_count=1,
train_instance_type="ml.c5.4xlarge",
py_version=tests.integ.PYTHON_VERSION,
sagemaker_session=sagemaker_session,
script_mode=True,
framework_version=TensorFlow.LATEST_VERSION,
Expand Down Expand Up @@ -182,6 +185,7 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type):
role=ROLE,
train_instance_count=1,
train_instance_type=instance_type,
py_version=tests.integ.PYTHON_VERSION,
sagemaker_session=sagemaker_session,
script_mode=True,
framework_version=TensorFlow.LATEST_VERSION,
Expand Down
68 changes: 57 additions & 11 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,59 @@ def test_create_image_uri_gov_cloud():
)


def test_create_image_uri_merged():
    """Check that py3 (or unspecified) merged-version frameworks resolve to merged image URIs."""
    # Each entry pairs the positional arguments for create_image_uri with the
    # exact URI expected back.
    cases = (
        (
            ("us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.13.1", "py3"),
            "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.13.1-gpu-py3",
        ),
        (
            ("us-west-2", "tensorflow-serving", "ml.c4.2xlarge", "1.13.1"),
            "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.13.1-cpu",
        ),
        (
            ("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py3"),
            "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:1.4.1-gpu-py3",
        ),
        (
            ("us-west-2", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py3"),
            "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.4.1-cpu-py3",
        ),
    )
    for call_args, expected_uri in cases:
        assert fw_utils.create_image_uri(*call_args) == expected_uri


def test_create_image_uri_merged_py2():
    """Check that py2 requests keep resolving to the legacy ``sagemaker-*`` image URIs."""
    # Each entry pairs the positional arguments for create_image_uri with the
    # exact URI expected back.
    cases = (
        (
            ("us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.13.1", "py2"),
            "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.13.1-gpu-py2",  # noqa: E501
        ),
        (
            ("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py2"),
            "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.1-gpu-py2",
        ),
        (
            ("us-west-2", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py2"),
            "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-serving:1.4.1-cpu-py2",
        ),
    )
    for call_args, expected_uri in cases:
        assert fw_utils.create_image_uri(*call_args) == expected_uri


def test_create_image_uri_accelerator_tf():
image_uri = fw_utils.create_image_uri(
MOCK_REGION,
"tensorflow",
"ml.p3.2xlarge",
"1.0rc",
"py3",
accelerator_type="ml.eia1.medium",
MOCK_REGION, "tensorflow", "ml.p3.2xlarge", "1.0", "py3", accelerator_type="ml.eia1.medium"
)
assert (
image_uri
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-eia:1.0rc-gpu-py3"
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-eia:1.0-gpu-py3"
)


Expand All @@ -156,13 +197,13 @@ def test_create_image_uri_accelerator_mxnet_serving():
MOCK_REGION,
"mxnet-serving",
"ml.p3.2xlarge",
"1.0rc",
"1.0",
"py3",
accelerator_type="ml.eia1.medium",
)
assert (
image_uri
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-serving-eia:1.0rc-gpu-py3"
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-serving-eia:1.0-gpu-py3"
)


Expand All @@ -171,13 +212,13 @@ def test_create_image_uri_local_sagemaker_notebook_accelerator():
MOCK_REGION,
"mxnet",
"ml.p3.2xlarge",
"1.0rc",
"1.0",
"py3",
accelerator_type="local_sagemaker_notebook",
)
assert (
image_uri
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-eia:1.0rc-gpu-py3"
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-eia:1.0-gpu-py3"
)


Expand Down Expand Up @@ -555,6 +596,11 @@ def test_framework_name_from_image_tf_scriptmode():
"scriptmode",
) == fw_utils.framework_name_from_image(image_name)

image_name = "123.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.13-cpu-py3"
assert ("tensorflow", "py3", "1.13-cpu-py3", "training") == fw_utils.framework_name_from_image(
image_name
)


def test_framework_name_from_image_rl():
image_name = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-mxnet:toolkit1.1-gpu-py3"
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ def test_script_mode_tensorboard(
sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
framework_version="some_version",
framework_version="1.0",
script_mode=True,
)
popen().poll.return_value = None
Expand Down