Skip to content

Commit 8c2012b

Browse files
authored
fix: Fix writing into non-closed file with git clone command (#4176)
1 parent 0b13d2a commit 8c2012b

File tree

2 files changed

+55
-24
lines changed

2 files changed

+55
-24
lines changed

src/sagemaker/git_utils.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import os
17+
from pathlib import Path
1718
import subprocess
1819
import tempfile
1920
import warnings
@@ -279,11 +280,13 @@ def _run_clone_command(repo_url, dest_dir):
279280
subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env)
280281
elif repo_url.startswith("git@") or repo_url.startswith("ssh://"):
281282
try:
282-
with tempfile.NamedTemporaryFile() as sshnoprompt:
283-
with open(sshnoprompt.name, "w") as write_pipe:
284-
write_pipe.write("ssh -oBatchMode=yes $@")
285-
os.chmod(sshnoprompt.name, 0o511)
286-
my_env["GIT_SSH"] = sshnoprompt.name
283+
with tempfile.TemporaryDirectory() as tmp_dir:
284+
custom_ssh_executable = Path(tmp_dir) / "ssh_batch"
285+
with open(custom_ssh_executable, "w") as pipe:
286+
print("#!/bin/sh", file=pipe)
287+
print("ssh -oBatchMode=yes $@", file=pipe)
288+
os.chmod(custom_ssh_executable, 0o511)
289+
my_env["GIT_SSH"] = str(custom_ssh_executable)
287290
subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env)
288291
except subprocess.CalledProcessError:
289292
del my_env["GIT_SSH"]

tests/unit/test_git_utils.py

+47-19
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import pytest
1616
import os
17+
from pathlib import Path
1718
import subprocess
1819
from mock import patch, ANY
1920

@@ -34,10 +35,11 @@
3435

3536
@patch("subprocess.check_call")
3637
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
38+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
3739
@patch("os.path.isfile", return_value=True)
3840
@patch("os.path.isdir", return_value=True)
3941
@patch("os.path.exists", return_value=True)
40-
def test_git_clone_repo_succeed(exists, isdir, isfile, mkdtemp, check_call):
42+
def test_git_clone_repo_succeed(exists, isdir, isfile, tempdir, mkdtemp, check_call):
4143
git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT}
4244
entry_point = "entry_point"
4345
source_dir = "source_dir"
@@ -88,7 +90,8 @@ def test_git_clone_repo_git_argument_wrong_format():
8890
),
8991
)
9092
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
91-
def test_git_clone_repo_clone_fail(mkdtemp, check_call):
93+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
94+
def test_git_clone_repo_clone_fail(tempdir, mkdtemp, check_call):
9295
git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT}
9396
entry_point = "entry_point"
9497
source_dir = "source_dir"
@@ -103,7 +106,8 @@ def test_git_clone_repo_clone_fail(mkdtemp, check_call):
103106
side_effect=[True, subprocess.CalledProcessError(returncode=1, cmd="git checkout banana")],
104107
)
105108
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
106-
def test_git_clone_repo_branch_not_exist(mkdtemp, check_call):
109+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
110+
def test_git_clone_repo_branch_not_exist(tempdir, mkdtemp, check_call):
107111
git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT}
108112
entry_point = "entry_point"
109113
source_dir = "source_dir"
@@ -122,7 +126,8 @@ def test_git_clone_repo_branch_not_exist(mkdtemp, check_call):
122126
],
123127
)
124128
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
125-
def test_git_clone_repo_commit_not_exist(mkdtemp, check_call):
129+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
130+
def test_git_clone_repo_commit_not_exist(tempdir, mkdtemp, check_call):
126131
git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT}
127132
entry_point = "entry_point"
128133
source_dir = "source_dir"
@@ -134,10 +139,11 @@ def test_git_clone_repo_commit_not_exist(mkdtemp, check_call):
134139

135140
@patch("subprocess.check_call")
136141
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
142+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
137143
@patch("os.path.isfile", return_value=False)
138144
@patch("os.path.isdir", return_value=True)
139145
@patch("os.path.exists", return_value=True)
140-
def test_git_clone_repo_entry_point_not_exist(exists, isdir, isfile, mkdtemp, heck_call):
146+
def test_git_clone_repo_entry_point_not_exist(exists, isdir, isfile, tempdir, mkdtemp, heck_call):
141147
git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT}
142148
entry_point = "entry_point_that_does_not_exist"
143149
source_dir = "source_dir"
@@ -149,10 +155,11 @@ def test_git_clone_repo_entry_point_not_exist(exists, isdir, isfile, mkdtemp, he
149155

150156
@patch("subprocess.check_call")
151157
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
158+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
152159
@patch("os.path.isfile", return_value=True)
153160
@patch("os.path.isdir", return_value=False)
154161
@patch("os.path.exists", return_value=True)
155-
def test_git_clone_repo_source_dir_not_exist(exists, isdir, isfile, mkdtemp, check_call):
162+
def test_git_clone_repo_source_dir_not_exist(exists, isdir, isfile, tempdir, mkdtemp, check_call):
156163
git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT}
157164
entry_point = "entry_point"
158165
source_dir = "source_dir_that_does_not_exist"
@@ -164,10 +171,11 @@ def test_git_clone_repo_source_dir_not_exist(exists, isdir, isfile, mkdtemp, che
164171

165172
@patch("subprocess.check_call")
166173
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
174+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
167175
@patch("os.path.isfile", return_value=True)
168176
@patch("os.path.isdir", return_value=True)
169177
@patch("os.path.exists", side_effect=[True, False])
170-
def test_git_clone_repo_dependencies_not_exist(exists, isdir, isfile, mkdtemp, check_call):
178+
def test_git_clone_repo_dependencies_not_exist(exists, isdir, isfile, tempdir, mkdtemp, check_call):
171179
git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT}
172180
entry_point = "entry_point"
173181
source_dir = "source_dir"
@@ -179,8 +187,9 @@ def test_git_clone_repo_dependencies_not_exist(exists, isdir, isfile, mkdtemp, c
179187

180188
@patch("subprocess.check_call")
181189
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
190+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
182191
@patch("os.path.isfile", return_value=True)
183-
def test_git_clone_repo_with_username_password_no_2fa(isfile, mkdtemp, check_call):
192+
def test_git_clone_repo_with_username_password_no_2fa(isfile, tempdir, mkdtemp, check_call):
184193
git_config = {
185194
"repo": PRIVATE_GIT_REPO,
186195
"branch": PRIVATE_BRANCH,
@@ -210,8 +219,9 @@ def test_git_clone_repo_with_username_password_no_2fa(isfile, mkdtemp, check_cal
210219

211220
@patch("subprocess.check_call")
212221
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
222+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
213223
@patch("os.path.isfile", return_value=True)
214-
def test_git_clone_repo_with_token_no_2fa(isfile, mkdtemp, check_call):
224+
def test_git_clone_repo_with_token_no_2fa(isfile, tempdir, mkdtemp, check_call):
215225
git_config = {
216226
"repo": PRIVATE_GIT_REPO,
217227
"branch": PRIVATE_BRANCH,
@@ -236,8 +246,9 @@ def test_git_clone_repo_with_token_no_2fa(isfile, mkdtemp, check_call):
236246

237247
@patch("subprocess.check_call")
238248
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
249+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
239250
@patch("os.path.isfile", return_value=True)
240-
def test_git_clone_repo_with_token_2fa(isfile, mkdtemp, check_call):
251+
def test_git_clone_repo_with_token_2fa(isfile, tempdirm, mkdtemp, check_call):
241252
git_config = {
242253
"repo": PRIVATE_GIT_REPO,
243254
"branch": PRIVATE_BRANCH,
@@ -264,8 +275,10 @@ def test_git_clone_repo_with_token_2fa(isfile, mkdtemp, check_call):
264275
@patch("subprocess.check_call")
265276
@patch("os.chmod")
266277
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
278+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
267279
@patch("os.path.isfile", return_value=True)
268-
def test_git_clone_repo_ssh(isfile, mkdtemp, chmod, check_call):
280+
def test_git_clone_repo_ssh(isfile, tempdir, mkdtemp, chmod, check_call):
281+
Path(REPO_DIR).mkdir(parents=True, exist_ok=True)
269282
git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT}
270283
entry_point = "entry_point"
271284
ret = git_utils.git_clone_repo(git_config, entry_point)
@@ -277,8 +290,11 @@ def test_git_clone_repo_ssh(isfile, mkdtemp, chmod, check_call):
277290

278291
@patch("subprocess.check_call")
279292
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
293+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
280294
@patch("os.path.isfile", return_value=True)
281-
def test_git_clone_repo_with_token_no_2fa_unnecessary_creds_provided(isfile, mkdtemp, check_call):
295+
def test_git_clone_repo_with_token_no_2fa_unnecessary_creds_provided(
296+
isfile, tempdir, mkdtemp, check_call
297+
):
282298
git_config = {
283299
"repo": PRIVATE_GIT_REPO,
284300
"branch": PRIVATE_BRANCH,
@@ -309,8 +325,11 @@ def test_git_clone_repo_with_token_no_2fa_unnecessary_creds_provided(isfile, mkd
309325

310326
@patch("subprocess.check_call")
311327
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
328+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
312329
@patch("os.path.isfile", return_value=True)
313-
def test_git_clone_repo_with_token_2fa_unnecessary_creds_provided(isfile, mkdtemp, check_call):
330+
def test_git_clone_repo_with_token_2fa_unnecessary_creds_provided(
331+
isfile, tempdir, mkdtemp, check_call
332+
):
314333
git_config = {
315334
"repo": PRIVATE_GIT_REPO,
316335
"branch": PRIVATE_BRANCH,
@@ -346,7 +365,8 @@ def test_git_clone_repo_with_token_2fa_unnecessary_creds_provided(isfile, mkdtem
346365
),
347366
)
348367
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
349-
def test_git_clone_repo_with_username_and_password_wrong_creds(mkdtemp, check_call):
368+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
369+
def test_git_clone_repo_with_username_and_password_wrong_creds(tempdir, mkdtemp, check_call):
350370
git_config = {
351371
"repo": PRIVATE_GIT_REPO,
352372
"branch": PRIVATE_BRANCH,
@@ -370,7 +390,8 @@ def test_git_clone_repo_with_username_and_password_wrong_creds(mkdtemp, check_ca
370390
),
371391
)
372392
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
373-
def test_git_clone_repo_with_token_wrong_creds(mkdtemp, check_call):
393+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
394+
def test_git_clone_repo_with_token_wrong_creds(tempdir, mkdtemp, check_call):
374395
git_config = {
375396
"repo": PRIVATE_GIT_REPO,
376397
"branch": PRIVATE_BRANCH,
@@ -393,7 +414,8 @@ def test_git_clone_repo_with_token_wrong_creds(mkdtemp, check_call):
393414
),
394415
)
395416
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
396-
def test_git_clone_repo_with_and_token_2fa_wrong_creds(mkdtemp, check_call):
417+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
418+
def test_git_clone_repo_with_and_token_2fa_wrong_creds(tempdir, mkdtemp, check_call):
397419
git_config = {
398420
"repo": PRIVATE_GIT_REPO,
399421
"branch": PRIVATE_BRANCH,
@@ -411,8 +433,11 @@ def test_git_clone_repo_with_and_token_2fa_wrong_creds(mkdtemp, check_call):
411433

412434
@patch("subprocess.check_call")
413435
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
436+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
414437
@patch("os.path.isfile", return_value=True)
415-
def test_git_clone_repo_codecommit_https_with_username_and_password(isfile, mkdtemp, check_call):
438+
def test_git_clone_repo_codecommit_https_with_username_and_password(
439+
isfile, tempdir, mkdtemp, check_call
440+
):
416441
git_config = {
417442
"repo": CODECOMMIT_REPO,
418443
"branch": CODECOMMIT_BRANCH,
@@ -445,7 +470,9 @@ def test_git_clone_repo_codecommit_https_with_username_and_password(isfile, mkdt
445470
),
446471
)
447472
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
448-
def test_git_clone_repo_codecommit_ssh_passphrase_required(mkdtemp, check_call):
473+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
474+
def test_git_clone_repo_codecommit_ssh_passphrase_required(tempdir, mkdtemp, check_call):
475+
Path(REPO_DIR).mkdir(parents=True, exist_ok=True)
449476
git_config = {"repo": CODECOMMIT_REPO_SSH, "branch": CODECOMMIT_BRANCH}
450477
entry_point = "entry_point"
451478
with pytest.raises(subprocess.CalledProcessError) as error:
@@ -460,7 +487,8 @@ def test_git_clone_repo_codecommit_ssh_passphrase_required(mkdtemp, check_call):
460487
),
461488
)
462489
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
463-
def test_git_clone_repo_codecommit_https_creds_not_stored_locally(mkdtemp, check_call):
490+
@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR)
491+
def test_git_clone_repo_codecommit_https_creds_not_stored_locally(tempdir, mkdtemp, check_call):
464492
git_config = {"repo": CODECOMMIT_REPO, "branch": CODECOMMIT_BRANCH}
465493
entry_point = "entry_point"
466494
with pytest.raises(subprocess.CalledProcessError) as error:

0 commit comments

Comments
 (0)