10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
- """An entry point for runtime environment. This must be kept independent of SageMaker PySDK"""
13
+ """Utility functions for the runtime environment. This must be kept independent of the SageMaker PySDK."""
14
14
from __future__ import absolute_import
15
15
16
16
import argparse
21
21
import time
22
22
from typing import List
23
23
24
+ import paramiko
25
+
24
26
if __package__ is None or __package__ == "" :
25
27
from runtime_environment_manager import (
26
28
get_logger ,
43
45
logger = get_logger ()
44
46
45
47
48
class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
    """Host key policy for SageMaker distributed training SSH connections.

    Unknown host keys are accepted only for hosts whose name starts with
    ``algo-``; every other unknown host is rejected with an ``SSHException``.

    Example:
        >>> client = paramiko.SSHClient()
        >>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
        >>> # Will succeed for SageMaker algorithm containers
        >>> client.connect('algo-1234.internal')
        >>> # Will raise SSHException for other unknown hosts
        >>> client.connect('unknown-host')  # raises SSHException
    """

    # Hostname prefix of the hosts whose unknown keys are trusted.
    _TRUSTED_HOST_PREFIX = "algo-"

    def missing_host_key(self, client, hostname, key):
        """Accept host keys for algo-* hostnames, reject others.

        Args:
            client: The SSHClient instance
            hostname: The host key name of the server being connected to. This
                is the plain hostname on the default SSH port, or the
                known_hosts-style ``[hostname]:port`` on any other port.
            key: The host key
        Raises:
            paramiko.SSHException: If hostname doesn't match algo-* pattern
        """
        # When connecting on a non-default port, paramiko hands the policy the
        # known_hosts form "[host]:port". Unwrap it so that algo-* hosts are
        # still recognised instead of being rejected because of the bracket.
        bare_host = hostname
        if bare_host.startswith("["):
            bare_host = bare_host[1:].partition("]")[0]

        if bare_host.startswith(self._TRUSTED_HOST_PREFIX):
            # Store the key under the name paramiko gave us, unmodified.
            client.get_host_keys().add(hostname, key.get_name(), key)
            return
        raise paramiko.SSHException(f"Unknown host key for {hostname} ")
74
+
75
+
46
76
def _parse_args (sys_args ):
47
77
"""Parses CLI arguments."""
48
78
parser = argparse .ArgumentParser ()
@@ -54,16 +84,12 @@ def _parse_args(sys_args):
54
84
def _can_connect (host : str , port : int = DEFAULT_SSH_PORT ) -> bool :
55
85
"""Check if the connection to the provided host and port is possible."""
56
86
try :
57
- import paramiko
58
-
59
- logger .debug ("Testing connection to host %s" , host )
60
- client = paramiko .SSHClient ()
61
- client .load_system_host_keys ()
62
- client .set_missing_host_key_policy (paramiko .AutoAddPolicy ())
63
- client .connect (host , port = port )
64
- client .close ()
65
- logger .info ("Can connect to host %s" , host )
66
- return True
87
+ with paramiko .SSHClient () as client :
88
+ client .load_system_host_keys ()
89
+ client .set_missing_host_key_policy (CustomHostKeyPolicy ())
90
+ client .connect (host , port = port )
91
+ logger .info ("Can connect to host %s" , host )
92
+ return True
67
93
except Exception as e : # pylint: disable=W0703
68
94
logger .info ("Cannot connect to host %s" , host )
69
95
logger .debug ("Connection failed with exception: %s" , e )
0 commit comments