diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 45ff28c967..2636ecae1b 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -22,7 +22,7 @@ import sagemaker.local.data from sagemaker.local.image import _SageMakerContainer -from sagemaker.local.utils import copy_directory_structure, move_to_destination +from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host from sagemaker.utils import DeferredError, get_config_value logger = logging.getLogger(__name__) @@ -295,7 +295,7 @@ def start(self, input_data, output_data, transform_resources, **kwargs): _wait_for_serving_container(serving_port) # Get capabilities from Container if needed - endpoint_url = "http://localhost:%s/execution-parameters" % serving_port + endpoint_url = "http://%s:%d/execution-parameters" % (get_docker_host(), serving_port) response, code = _perform_request(endpoint_url) if code == 200: execution_parameters = json.loads(response.read()) @@ -607,7 +607,7 @@ def _wait_for_serving_container(serving_port): i = 0 http = urllib3.PoolManager() - endpoint_url = "http://localhost:%s/ping" % serving_port + endpoint_url = "http://%s:%d/ping" % (get_docker_host(), serving_port) while True: i += 5 if i >= HEALTH_CHECK_TIMEOUT_LIMIT: diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 4c2f02cad1..1406c1cdf7 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -21,6 +21,7 @@ from botocore.exceptions import ClientError from sagemaker.local.image import _SageMakerContainer +from sagemaker.local.utils import get_docker_host from sagemaker.local.entities import ( _LocalEndpointConfig, _LocalEndpoint, @@ -448,7 +449,7 @@ def invoke_endpoint( Returns: object: Inference for the given input. """ - url = "http://localhost:%s/invocations" % self.serving_port + url = "http://%s:%d/invocations" % (get_docker_host(), self.serving_port) headers = {} if ContentType is not None: diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index 352b7ec387..1b3ea155e1 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -16,6 +16,7 @@ import os import shutil import subprocess +import json from distutils.dir_util import copy_tree from six.moves.urllib.parse import urlparse @@ -127,3 +128,27 @@ def get_child_process_ids(pid): return pids + get_child_process_ids(child_pid) else: return [] + + +def get_docker_host(): + """Discover remote docker host address (if applicable) or use "localhost" + + Use "docker context inspect" to read current docker host endpoint url, + url must start with "tcp://" + + Args: + + Returns: + docker_host (str): Docker host DNS or IP address + """ + cmd = "docker context inspect".split() + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, err = process.communicate() + if err: + return "localhost" + docker_context_string = output.decode("utf-8") + docker_context_host_url = json.loads(docker_context_string)[0]["Endpoints"]["docker"]["Host"] + parsed_url = urlparse(docker_context_host_url) + if parsed_url.hostname and parsed_url.scheme == "tcp": + return parsed_url.hostname + return "localhost" diff --git a/tests/unit/test_local_entities.py b/tests/unit/test_local_entities.py index 3f34cd851d..f7a56959db 100644 --- a/tests/unit/test_local_entities.py +++ b/tests/unit/test_local_entities.py @@ -163,3 +163,28 @@ def test_local_transform_job_perform_batch_inference( assert len(output_files) == 2 assert "file1.out" in output_files assert "file2.out" in output_files + + +@patch("sagemaker.local.entities._SageMakerContainer", Mock()) +@patch("sagemaker.local.entities.get_docker_host") +@patch("sagemaker.local.entities._perform_request") +@patch("sagemaker.local.entities._LocalTransformJob._perform_batch_inference") +def test_start_local_transform_job_from_remote_docker_host( + m_perform_batch_inference, m_perform_request, m_get_docker_host, local_transform_job +): + input_data = {} + output_data = {} + transform_resources = {"InstanceType": "local"} + m_get_docker_host.return_value = "some_host" + perform_request_mock = Mock() + m_perform_request.return_value = (perform_request_mock, 200) + perform_request_mock.read.return_value = '{"BatchStrategy": "SingleRecord"}' + local_transform_job.primary_container["ModelDataUrl"] = "file:///some/model" + local_transform_job.start(input_data, output_data, transform_resources, Environment={}) + endpoints = [ + "http://%s:%d/ping" % ("some_host", 8080), + "http://%s:%d/execution-parameters" % ("some_host", 8080), + ] + calls = m_perform_request.call_args_list + for call, endpoint in zip(calls, endpoints): + assert call[0][0] == endpoint diff --git a/tests/unit/test_local_session.py b/tests/unit/test_local_session.py index 6d64c73849..4b5801d971 100644 --- a/tests/unit/test_local_session.py +++ b/tests/unit/test_local_session.py @@ -857,3 +857,18 @@ def test_local_session_download_with_custom_s3_endpoint_url(sagemaker_session_cu Filename="{}/{}".format(DOWNLOAD_DATA_TESTS_FILES_DIR, "test.csv"), ExtraArgs=None, ) + + +@patch("sagemaker.local.local_session.get_docker_host") +@patch("urllib3.PoolManager.request") +def test_invoke_local_endpoint_with_remote_docker_host( + m_request, + m_get_docker_host, +): + m_get_docker_host.return_value = "some_host" + Body = "Body".encode("utf-8") + url = "http://%s:%d/invocations" % ("some_host", 8080) + sagemaker.local.local_session.LocalSagemakerRuntimeClient().invoke_endpoint( + Body, "local_endpoint" + ) + m_request.assert_called_with("POST", url, body=Body, preload_content=False, headers={}) diff --git a/tests/unit/test_local_utils.py b/tests/unit/test_local_utils.py index be54d00a19..4bce43704e 100644 --- a/tests/unit/test_local_utils.py +++ b/tests/unit/test_local_utils.py @@ -92,3 +92,27 @@ def test_get_child_process_ids(m_subprocess): m_subprocess.Popen.return_value = process_mock sagemaker.local.utils.get_child_process_ids("pid") m_subprocess.Popen.assert_called_with(cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE) + + +@patch("sagemaker.local.utils.subprocess") +def test_get_docker_host(m_subprocess): + cmd = "docker context inspect".split() + process_mock = Mock() + endpoints = [ + {"test": "tcp://host:port", "result": "host"}, + {"test": "fd://something", "result": "localhost"}, + {"test": "unix://path/to/socket", "result": "localhost"}, + {"test": "npipe:////./pipe/foo", "result": "localhost"}, + ] + for endpoint in endpoints: + return_value = ( + '[\n{\n"Endpoints":{\n"docker":{\n"Host": "%s"}\n}\n}\n]\n' % endpoint["test"] + ) + attrs = {"communicate.return_value": (return_value.encode("utf-8"), None), "returncode": 0} + process_mock.configure_mock(**attrs) + m_subprocess.Popen.return_value = process_mock + host = sagemaker.local.utils.get_docker_host() + m_subprocess.Popen.assert_called_with( + cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE + ) + assert host == endpoint["result"]