Skip to content

Commit c44cae6

Browse files
committed
Filter policy by algo-
1 parent dbf8759 commit c44cae6

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

src/sagemaker/modules/train/container_drivers/mpi_utils.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import time
1919
from typing import List
2020

21+
import paramiko
2122
from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger
2223

2324
FINISHED_STATUS_FILE = "/tmp/done.algo-1"
@@ -74,6 +75,24 @@ def start_sshd_daemon():
7475
logger.info("Started SSH daemon.")
7576

7677

78+
class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
    """Missing-host-key policy that trusts only ``algo-*`` hosts.

    Unknown keys presented by hosts whose name starts with ``algo-`` are
    added to the client's in-memory host keys; every other unknown host is
    rejected with ``paramiko.SSHException``.
    """

    def missing_host_key(self, client, hostname, key):
        """Accept host keys for algo-* hostnames, reject others.

        Args:
            client: The SSHClient instance
            hostname: The host key name being looked up. When the connection
                uses a non-default SSH port, paramiko passes this as
                ``[host]:port`` rather than the bare host name.
            key: The host key

        Raises:
            paramiko.SSHException: If hostname doesn't match algo-* pattern
        """
        # Unwrap the "[host]:port" form so that algo-* hosts reached on a
        # non-default port are matched the same way as on the default port.
        bare_host = hostname
        if hostname.startswith("[") and "]:" in hostname:
            bare_host = hostname[1 : hostname.rindex("]:")]

        if bare_host.startswith("algo-"):
            # Store the key under the exact name paramiko supplied, because
            # that is the name it uses for later host-key lookups.
            client.get_host_keys().add(hostname, key.get_name(), key)
            return

        raise paramiko.SSHException(f"Unknown host key for {hostname}")
94+
95+
7796
def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
7897
"""Check if the connection to the provided host and port is possible."""
7998
try:
@@ -82,7 +101,7 @@ def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
82101
logger.debug("Testing connection to host %s", host)
83102
client = paramiko.SSHClient()
84103
client.load_system_host_keys()
85-
client.set_missing_host_key_policy(paramiko.RejectPolicy())
104+
client.set_missing_host_key_policy(CustomHostKeyPolicy())
86105
client.connect(host, port=port)
87106
client.close()
88107
logger.info("Can connect to host %s", host)

0 commit comments

Comments
 (0)