Skip to content

feature: Support for remote docker host #2864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
6670e30
fix: jumpstart model table (#2954)
bencrabtree Feb 24, 2022
4ce6623
change: update code to get commit_id in codepipeline (#2961)
navinsoni Feb 26, 2022
5c64e6c
feature: Data Serializer (#2956)
jeniyat Feb 28, 2022
3e73133
feature: Support remote docker host
samadwar Jan 22, 2022
9f4f949
feature: Add support for remote docker host
samadwar Jan 24, 2022
f9d1e2b
Added unit tests
samadwar Feb 22, 2022
6f6f399
fixed unnecessary
Feb 22, 2022
f166b60
change: update code to get commit_id in codepipeline (#2961)
navinsoni Feb 26, 2022
086258d
feature: Data Serializer (#2956)
jeniyat Feb 28, 2022
a39b750
change: reorganize test files for workflow (#2960)
qidewenwhen Mar 3, 2022
28fd737
feature: TensorFlow 2.4 for Neo (#2861)
Qingzi-Lan Mar 3, 2022
20df3d7
fix: Remove sagemaker_job_name from hyperparameters in TrainingStep (…
staubhp Mar 3, 2022
b9f90dc
fix: Style update in DataSerializer (#2962)
jeniyat Mar 3, 2022
6db3774
documentation: smddp doc update (#2968)
mchoi8739 Mar 4, 2022
d610bfb
fix: container env generation for S3 URI and add test for the same (#…
shreyapandit Mar 7, 2022
90fd4a4
Merge branch 'dev' into remote_docker_host
shreyapandit Mar 7, 2022
169dffd
documentation: update sagemaker training compiler docstring (#2969)
mchoi8739 Mar 7, 2022
4325fcd
feat: Python 3.9 for readthedocs (#2973)
ahsan-z-khan Mar 8, 2022
e8ac39f
Merge branch 'dev' into remote_docker_host
ahsan-z-khan Mar 8, 2022
059c9c7
Merge branch 'master' into remote_docker_host
ahsan-z-khan Mar 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/doc_utils/jumpstart_doc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def create_jumpstart_model_table():
file_content.append(" - Latest Version\n")
file_content.append(" - Min SDK Version\n")

for model in sorted(sdk_manifest, key=lambda elt: elt["model_id"]):
for model in sdk_manifest_top_versions_for_models.values():
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
file_content.append(" * - {}\n".format(model["model_id"]))
file_content.append(" - {}\n".format(model_spec["training_supported"]))
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/local/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import sagemaker.local.data
from sagemaker.local.image import _SageMakerContainer
from sagemaker.local.utils import copy_directory_structure, move_to_destination
from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host
from sagemaker.utils import DeferredError, get_config_value

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -295,7 +295,7 @@ def start(self, input_data, output_data, transform_resources, **kwargs):
_wait_for_serving_container(serving_port)

# Get capabilities from Container if needed
endpoint_url = "http://localhost:%s/execution-parameters" % serving_port
endpoint_url = "http://%s:%d/execution-parameters" % (get_docker_host(), serving_port)
response, code = _perform_request(endpoint_url)
if code == 200:
execution_parameters = json.loads(response.read())
Expand Down Expand Up @@ -607,7 +607,7 @@ def _wait_for_serving_container(serving_port):
i = 0
http = urllib3.PoolManager()

endpoint_url = "http://localhost:%s/ping" % serving_port
endpoint_url = "http://%s:%d/ping" % (get_docker_host(), serving_port)
while True:
i += 5
if i >= HEALTH_CHECK_TIMEOUT_LIMIT:
Expand Down
3 changes: 2 additions & 1 deletion src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from botocore.exceptions import ClientError

from sagemaker.local.image import _SageMakerContainer
from sagemaker.local.utils import get_docker_host
from sagemaker.local.entities import (
_LocalEndpointConfig,
_LocalEndpoint,
Expand Down Expand Up @@ -448,7 +449,7 @@ def invoke_endpoint(
Returns:
object: Inference for the given input.
"""
url = "http://localhost:%s/invocations" % self.serving_port
url = "http://%s:%d/invocations" % (get_docker_host(), self.serving_port)
headers = {}

if ContentType is not None:
Expand Down
25 changes: 25 additions & 0 deletions src/sagemaker/local/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import shutil
import subprocess
import json

from distutils.dir_util import copy_tree
from six.moves.urllib.parse import urlparse
Expand Down Expand Up @@ -127,3 +128,27 @@ def get_child_process_ids(pid):
return pids + get_child_process_ids(child_pid)
else:
return []


def get_docker_host():
    """Discover remote docker host address (if applicable) or use "localhost".

    Runs ``docker context inspect`` and reads the docker endpoint URL of the
    current context. Only a ``tcp://`` endpoint names a remote host; every
    other scheme falls back to "localhost".

    This is a best-effort lookup: if the docker CLI cannot be started, exits
    with a non-zero status, writes to stderr, or prints output that does not
    have the expected JSON structure, "localhost" is returned instead of
    raising.

    Returns:
        str: Docker host DNS name or IP address, or "localhost".
    """
    default_host = "localhost"
    cmd = "docker context inspect".split()
    try:
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, err = process.communicate()
    except OSError:
        # The docker CLI is not installed or cannot be executed.
        return default_host
    if err or process.returncode:
        return default_host
    try:
        docker_context = json.loads(output.decode("utf-8"))
        docker_context_host_url = docker_context[0]["Endpoints"]["docker"]["Host"]
    except (ValueError, LookupError, TypeError):
        # Output was not valid UTF-8/JSON, or lacked the expected keys.
        return default_host
    parsed_url = urlparse(docker_context_host_url)
    if parsed_url.hostname and parsed_url.scheme == "tcp":
        return parsed_url.hostname
    return default_host
36 changes: 35 additions & 1 deletion src/sagemaker/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import csv
import io
import json

import numpy as np
from six import with_metaclass

Expand Down Expand Up @@ -357,3 +356,38 @@ def serialize(self, data):
return data.read()

raise ValueError("Unable to handle input format: %s" % type(data))


class DataSerializer(SimpleBaseSerializer):
    """Serialize data in any file by extracting raw bytes from the file."""

    def __init__(self, content_type="file-path/raw-bytes"):
        """Initialize a ``DataSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "file-path/raw-bytes").
        """
        super(DataSerializer, self).__init__(content_type=content_type)

    def serialize(self, data):
        """Serialize file data to raw bytes.

        Args:
            data (object): Data to be serialized. The data can be a string
                representing a file path or the raw bytes from a file.

        Returns:
            bytes: The content of the file at ``data`` when ``data`` is a string,
                or ``data`` itself when it is already ``bytes``.

        Raises:
            ValueError: If the file cannot be opened or read, or if ``data`` is
                neither ``str`` nor ``bytes``.
        """
        if isinstance(data, str):
            try:
                # The context manager closes the file even if read() fails.
                with open(data, "rb") as data_file:
                    return data_file.read()
            except (OSError, ValueError) as e:
                raise ValueError(f"Could not open/read file: {data}. {e}") from e
        if isinstance(data, bytes):
            return data

        raise ValueError(f"Object of type {type(data)} is not Data serializable.")
Binary file added tests/data/cuteCat.raw
Binary file not shown.
9 changes: 8 additions & 1 deletion tests/scripts/run-notebook-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,15 @@ echo "$LIFECYCLE_CONFIG_CONTENT"

set -euo pipefail

# git doesn't work in codepipeline, use CODEBUILD_RESOLVED_SOURCE_VERSION to get commit id
codebuild_initiator="${CODEBUILD_INITIATOR:-0}"
if [ "${codebuild_initiator:0:12}" == "codepipeline" ]; then
COMMIT_ID="${CODEBUILD_RESOLVED_SOURCE_VERSION}"
else
COMMIT_ID=$(git rev-parse --short HEAD)
fi

ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text)
COMMIT_ID=$(git rev-parse --short HEAD)
LIFECYCLE_CONFIG_NAME="install-python-sdk-$COMMIT_ID"

python setup.py sdist
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/sagemaker/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
SparseMatrixSerializer,
JSONLinesSerializer,
LibSVMSerializer,
DataSerializer,
)
from tests.unit import DATA_DIR

Expand Down Expand Up @@ -331,3 +332,26 @@ def test_libsvm_serializer_file_like(libsvm_serializer):
libsvm_file.seek(0)
result = libsvm_serializer.serialize(libsvm_file)
assert result == validation_data


@pytest.fixture
def data_serializer():
    """Provide a ``DataSerializer`` built with its default content type."""
    serializer = DataSerializer()
    return serializer


def test_data_serializer_raw(data_serializer):
    """Raw bytes passed to the serializer are returned unchanged."""
    input_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.jpg")
    with open(input_image_file_path, "rb") as image:
        input_image = image.read()
    input_image_data = data_serializer.serialize(input_image)
    validation_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.raw")
    # Use a context manager so the validation file handle is always closed.
    with open(validation_image_file_path, "rb") as validation_image:
        validation_image_data = validation_image.read()
    assert input_image_data == validation_image_data


def test_data_serializer_file_like(data_serializer):
    """A file path passed to the serializer yields that file's raw bytes."""
    input_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.jpg")
    validation_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.raw")
    input_image_data = data_serializer.serialize(input_image_file_path)
    # Use a context manager so the validation file handle is always closed.
    with open(validation_image_file_path, "rb") as validation_image:
        validation_image_data = validation_image.read()
    assert input_image_data == validation_image_data
25 changes: 25 additions & 0 deletions tests/unit/test_local_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,28 @@ def test_local_transform_job_perform_batch_inference(
assert len(output_files) == 2
assert "file1.out" in output_files
assert "file2.out" in output_files


@patch("sagemaker.local.entities._SageMakerContainer", Mock())
@patch("sagemaker.local.entities.get_docker_host")
@patch("sagemaker.local.entities._perform_request")
@patch("sagemaker.local.entities._LocalTransformJob._perform_batch_inference")
def test_start_local_transform_job_from_remote_docker_host(
    m_perform_batch_inference, m_perform_request, m_get_docker_host, local_transform_job
):
    """Starting a transform job sends its requests to the discovered docker host."""
    input_data = {}
    output_data = {}
    transform_resources = {"InstanceType": "local"}
    m_get_docker_host.return_value = "some_host"
    perform_request_mock = Mock()
    m_perform_request.return_value = (perform_request_mock, 200)
    perform_request_mock.read.return_value = '{"BatchStrategy": "SingleRecord"}'
    local_transform_job.primary_container["ModelDataUrl"] = "file:///some/model"
    local_transform_job.start(input_data, output_data, transform_resources, Environment={})
    endpoints = [
        "http://%s:%d/ping" % ("some_host", 8080),
        "http://%s:%d/execution-parameters" % ("some_host", 8080),
    ]
    # Compare as lists rather than zip(): zip() silently truncates, so the test
    # would pass vacuously if fewer requests than expected were made.
    leading_calls = m_perform_request.call_args_list[: len(endpoints)]
    requested_urls = [call[0][0] for call in leading_calls]
    assert requested_urls == endpoints
15 changes: 15 additions & 0 deletions tests/unit/test_local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,3 +857,18 @@ def test_local_session_download_with_custom_s3_endpoint_url(sagemaker_session_cu
Filename="{}/{}".format(DOWNLOAD_DATA_TESTS_FILES_DIR, "test.csv"),
ExtraArgs=None,
)


@patch("sagemaker.local.local_session.get_docker_host")
@patch("urllib3.PoolManager.request")
def test_invoke_local_endpoint_with_remote_docker_host(
    m_request,
    m_get_docker_host,
):
    """Invoking a local endpoint posts to the host reported by get_docker_host."""
    remote_host = "some_host"
    m_get_docker_host.return_value = remote_host
    payload = "Body".encode("utf-8")
    expected_url = "http://%s:%d/invocations" % (remote_host, 8080)

    runtime_client = sagemaker.local.local_session.LocalSagemakerRuntimeClient()
    runtime_client.invoke_endpoint(payload, "local_endpoint")

    m_request.assert_called_with(
        "POST", expected_url, body=payload, preload_content=False, headers={}
    )
24 changes: 24 additions & 0 deletions tests/unit/test_local_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,27 @@ def test_get_child_process_ids(m_subprocess):
m_subprocess.Popen.return_value = process_mock
sagemaker.local.utils.get_child_process_ids("pid")
m_subprocess.Popen.assert_called_with(cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE)


@patch("sagemaker.local.utils.subprocess")
def test_get_docker_host(m_subprocess):
    """Only tcp:// docker endpoints resolve to a remote host; others give localhost."""
    expected_cmd = "docker context inspect".split()
    fake_process = Mock()
    # The same process mock is returned for every case; only its output changes.
    m_subprocess.Popen.return_value = fake_process
    cases = [
        {"test": "tcp://host:port", "result": "host"},
        {"test": "fd://something", "result": "localhost"},
        {"test": "unix://path/to/socket", "result": "localhost"},
        {"test": "npipe:////./pipe/foo", "result": "localhost"},
    ]
    for case in cases:
        context_json = '[\n{\n"Endpoints":{\n"docker":{\n"Host": "%s"}\n}\n}\n]\n' % case["test"]
        fake_process.configure_mock(
            **{
                "communicate.return_value": (context_json.encode("utf-8"), None),
                "returncode": 0,
            }
        )
        resolved_host = sagemaker.local.utils.get_docker_host()
        m_subprocess.Popen.assert_called_with(
            expected_cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE
        )
        assert resolved_host == case["result"]