Skip to content

Commit 6181117

Browse files
committed
feature: Support remote docker host
Implemented get_docker_host() to enable sagemaker.local to work with remote docker host
1 parent 814dd3d commit 6181117

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

src/sagemaker/local/entities.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import sagemaker.local.data
2424
from sagemaker.local.image import _SageMakerContainer
25-
from sagemaker.local.utils import copy_directory_structure, move_to_destination
25+
from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host
2626
from sagemaker.utils import DeferredError, get_config_value
2727

2828
logger = logging.getLogger(__name__)
@@ -295,7 +295,7 @@ def start(self, input_data, output_data, transform_resources, **kwargs):
295295
_wait_for_serving_container(serving_port)
296296

297297
# Get capabilities from Container if needed
298-
endpoint_url = "http://localhost:%s/execution-parameters" % serving_port
298+
endpoint_url = "http://%s:%d/execution-parameters" % (get_docker_host(), serving_port)
299299
response, code = _perform_request(endpoint_url)
300300
if code == 200:
301301
execution_parameters = json.loads(response.read())
@@ -607,7 +607,7 @@ def _wait_for_serving_container(serving_port):
607607
i = 0
608608
http = urllib3.PoolManager()
609609

610-
endpoint_url = "http://localhost:%s/ping" % serving_port
610+
endpoint_url = "http://%s:%d/ping" % (get_docker_host(), serving_port)
611611
while True:
612612
i += 5
613613
if i >= HEALTH_CHECK_TIMEOUT_LIMIT:

src/sagemaker/local/local_session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from botocore.exceptions import ClientError
2222

2323
from sagemaker.local.image import _SageMakerContainer
24+
from sagemaker.local.utils import get_docker_host
2425
from sagemaker.local.entities import (
2526
_LocalEndpointConfig,
2627
_LocalEndpoint,
@@ -448,7 +449,7 @@ def invoke_endpoint(
448449
Returns:
449450
object: Inference for the given input.
450451
"""
451-
url = "http://localhost:%s/invocations" % self.serving_port
452+
url = "http://%s:%d/invocations" % (get_docker_host(), self.serving_port)
452453
headers = {}
453454

454455
if ContentType is not None:

src/sagemaker/local/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import shutil
1818
import subprocess
19+
import json
1920

2021
from distutils.dir_util import copy_tree
2122
from six.moves.urllib.parse import urlparse
@@ -127,3 +128,21 @@ def get_child_process_ids(pid):
127128
return pids + get_child_process_ids(child_pid)
128129
else:
129130
return []
131+
132+
def get_docker_host():
133+
"""Discover remote docker host address (if applicable) or use "localhost"
134+
135+
Use "docker context inspect" to read current docker host endpoint url
136+
137+
Args:
138+
139+
Returns:
140+
docker_host (str): Docker host DNS or IP
141+
"""
142+
try:
143+
docker_context_string = os.popen("docker context inspect").read()
144+
docker_context_host_url = json.loads(docker_context_string)[0]['Endpoints']['docker']['Host']
145+
docker_host = docker_context_host_url.split("//")[1].rsplit(":")[0] if not docker_context_host_url.startswith("unix") else "localhost"
146+
except:
147+
docker_host = "localhost"
148+
return docker_host

0 commit comments

Comments
 (0)