18
18
import time
19
19
from typing import List
20
20
21
+ import paramiko
21
22
from utils import SM_EFA_NCCL_INSTANCES , SM_EFA_RDMA_INSTANCES , get_python_executable , logger
22
23
23
24
FINISHED_STATUS_FILE = "/tmp/done.algo-1"
@@ -74,6 +75,24 @@ def start_sshd_daemon():
74
75
logger .info ("Started SSH daemon." )
75
76
76
77
78
class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
    """Missing-host-key policy that trusts only ``algo-*`` hosts.

    When the SSH client meets a server whose key is not already known, the
    key is accepted and recorded if the host name starts with ``algo-``;
    any other host is rejected with ``paramiko.SSHException``.
    """

    def missing_host_key(self, client, hostname, key):
        """Accept host keys for algo-* hostnames, reject others.

        Args:
            client: The SSHClient instance
            hostname: The host-key name of the server being connected to.
                Paramiko's ``SSHClient.connect`` passes the bare host name
                for port 22 and ``"[host]:port"`` for any other port.
            key: The host key

        Raises:
            paramiko.SSHException: If hostname doesn't match algo-* pattern
        """
        # Connections on a non-default port arrive as "[algo-1]:2222".
        # Strip the bracket/port decoration so the prefix check sees the
        # real host name instead of rejecting every such connection.
        bare_host = hostname
        if hostname.startswith("["):
            bare_host = hostname[1:].partition("]")[0]

        if bare_host.startswith("algo-"):
            # Store under the name paramiko supplied (brackets included) so
            # later lookups for the same host/port pair find this entry.
            client.get_host_keys().add(hostname, key.get_name(), key)
            return
        raise paramiko.SSHException(f"Unknown host key for {hostname} ")
94
+
95
+
77
96
def _can_connect (host : str , port : int = DEFAULT_SSH_PORT ) -> bool :
78
97
"""Check if the connection to the provided host and port is possible."""
79
98
try :
@@ -82,7 +101,7 @@ def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
82
101
logger .debug ("Testing connection to host %s" , host )
83
102
client = paramiko .SSHClient ()
84
103
client .load_system_host_keys ()
85
- client .set_missing_host_key_policy (paramiko . RejectPolicy ())
104
+ client .set_missing_host_key_policy (CustomHostKeyPolicy ())
86
105
client .connect (host , port = port )
87
106
client .close ()
88
107
logger .info ("Can connect to host %s" , host )
0 commit comments