Skip to content

Commit 86c9f7d

Browse files
committed
aligned mpi_utils_remote.py to mpi_utils.py for estimator
1 parent 8701782 commit 86c9f7d

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Sagemaker modules container_drivers directory."""
14+
from __future__ import absolute_import

src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py

+37-11
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""An entry point for runtime environment. This must be kept independent of SageMaker PySDK"""
13+
"""An utils function for runtime environment. This must be kept independent of SageMaker PySDK"""
1414
from __future__ import absolute_import
1515

1616
import argparse
@@ -21,6 +21,8 @@
2121
import time
2222
from typing import List
2323

24+
import paramiko
25+
2426
if __package__ is None or __package__ == "":
2527
from runtime_environment_manager import (
2628
get_logger,
@@ -43,6 +45,34 @@
4345
logger = get_logger()
4446

4547

48+
class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
49+
"""Class to handle host key policy for SageMaker distributed training SSH connections.
50+
51+
Example:
52+
>>> client = paramiko.SSHClient()
53+
>>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
54+
>>> # Will succeed for SageMaker algorithm containers
55+
>>> client.connect('algo-1234.internal')
56+
>>> # Will raise SSHException for other unknown hosts
57+
>>> client.connect('unknown-host') # raises SSHException
58+
"""
59+
60+
def missing_host_key(self, client, hostname, key):
61+
"""Accept host keys for algo-* hostnames, reject others.
62+
63+
Args:
64+
client: The SSHClient instance
65+
hostname: The hostname attempting to connect
66+
key: The host key
67+
Raises:
68+
paramiko.SSHException: If hostname doesn't match algo-* pattern
69+
"""
70+
if hostname.startswith("algo-"):
71+
client.get_host_keys().add(hostname, key.get_name(), key)
72+
return
73+
raise paramiko.SSHException(f"Unknown host key for {hostname}")
74+
75+
4676
def _parse_args(sys_args):
4777
"""Parses CLI arguments."""
4878
parser = argparse.ArgumentParser()
@@ -54,16 +84,12 @@ def _parse_args(sys_args):
5484
def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
5585
"""Check if the connection to the provided host and port is possible."""
5686
try:
57-
import paramiko
58-
59-
logger.debug("Testing connection to host %s", host)
60-
client = paramiko.SSHClient()
61-
client.load_system_host_keys()
62-
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
63-
client.connect(host, port=port)
64-
client.close()
65-
logger.info("Can connect to host %s", host)
66-
return True
87+
with paramiko.SSHClient() as client:
88+
client.load_system_host_keys()
89+
client.set_missing_host_key_policy(CustomHostKeyPolicy())
90+
client.connect(host, port=port)
91+
logger.info("Can connect to host %s", host)
92+
return True
6793
except Exception as e: # pylint: disable=W0703
6894
logger.info("Cannot connect to host %s", host)
6995
logger.debug("Connection failed with exception: %s", e)

0 commit comments

Comments
 (0)