diff --git a/src/sagemaker/git_utils.py b/src/sagemaker/git_utils.py index c424753286..adde1b5585 100644 --- a/src/sagemaker/git_utils.py +++ b/src/sagemaker/git_utils.py @@ -174,7 +174,7 @@ def _clone_command_for_github_like(git_config, dest_dir): CalledProcessError: If failed to clone git repo. """ is_https = git_config["repo"].startswith("https://") - is_ssh = git_config["repo"].startswith("git@") + is_ssh = git_config["repo"].startswith("git@") or git_config["repo"].startswith("ssh://") if not is_https and not is_ssh: raise ValueError("Invalid Git url provided.") if is_ssh: @@ -277,12 +277,16 @@ def _run_clone_command(repo_url, dest_dir): if repo_url.startswith("https://"): my_env["GIT_TERMINAL_PROMPT"] = "0" subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) - elif repo_url.startswith("git@"): - with tempfile.NamedTemporaryFile() as sshnoprompt: - with open(sshnoprompt.name, "w") as write_pipe: - write_pipe.write("ssh -oBatchMode=yes $@") - os.chmod(sshnoprompt.name, 0o511) - my_env["GIT_SSH"] = sshnoprompt.name + elif repo_url.startswith("git@") or repo_url.startswith("ssh://"): + try: + with tempfile.NamedTemporaryFile() as sshnoprompt: + with open(sshnoprompt.name, "w") as write_pipe: + write_pipe.write("ssh -oBatchMode=yes $@") + os.chmod(sshnoprompt.name, 0o511) + my_env["GIT_SSH"] = sshnoprompt.name + subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) + except subprocess.CalledProcessError: + del my_env["GIT_SSH"] subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env)