Skip to content

Commit 4b0e14c

Browse files
committed
Merge branch 'master' of https://github.com/yl-to/sagemaker-python-sdk into torchrun_gpu_sup
2 parents ea567d4 + 471ee25 commit 4b0e14c

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

requirements/extras/test_requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ black==22.3.0
1313
stopit==1.1.2
1414
# Update tox.ini to have correct version of airflow constraints file
1515
apache-airflow==2.4.1
16-
apache-airflow-providers-amazon==4.0.0
16+
apache-airflow-providers-amazon==7.2.1
1717
attrs==22.1.0
1818
fabric==2.6.0
1919
requests==2.27.1

src/sagemaker/feature_store/feature_group.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -473,11 +473,12 @@ class FeatureGroup:
473473
Attributes:
474474
name (str): name of the FeatureGroup instance.
475475
sagemaker_session (Session): session instance to perform boto calls.
476+
If None, a new Session will be created.
476477
feature_definitions (Sequence[FeatureDefinition]): list of FeatureDefinitions.
477478
"""
478479

479480
name: str = attr.ib(factory=str)
480-
sagemaker_session: Session = attr.ib(default=Session)
481+
sagemaker_session: Session = attr.ib(factory=Session)
481482
feature_definitions: Sequence[FeatureDefinition] = attr.ib(factory=list)
482483

483484
_INTEGER_TYPES = [

src/sagemaker/git_utils.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def _clone_command_for_github_like(git_config, dest_dir):
174174
CalledProcessError: If failed to clone git repo.
175175
"""
176176
is_https = git_config["repo"].startswith("https://")
177-
is_ssh = git_config["repo"].startswith("git@")
177+
is_ssh = git_config["repo"].startswith("git@") or git_config["repo"].startswith("ssh://")
178178
if not is_https and not is_ssh:
179179
raise ValueError("Invalid Git url provided.")
180180
if is_ssh:
@@ -277,12 +277,16 @@ def _run_clone_command(repo_url, dest_dir):
277277
if repo_url.startswith("https://"):
278278
my_env["GIT_TERMINAL_PROMPT"] = "0"
279279
subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env)
280-
elif repo_url.startswith("git@"):
281-
with tempfile.NamedTemporaryFile() as sshnoprompt:
282-
with open(sshnoprompt.name, "w") as write_pipe:
283-
write_pipe.write("ssh -oBatchMode=yes $@")
284-
os.chmod(sshnoprompt.name, 0o511)
285-
my_env["GIT_SSH"] = sshnoprompt.name
280+
elif repo_url.startswith("git@") or repo_url.startswith("ssh://"):
281+
try:
282+
with tempfile.NamedTemporaryFile() as sshnoprompt:
283+
with open(sshnoprompt.name, "w") as write_pipe:
284+
write_pipe.write("ssh -oBatchMode=yes $@")
285+
os.chmod(sshnoprompt.name, 0o511)
286+
my_env["GIT_SSH"] = sshnoprompt.name
287+
subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env)
288+
except subprocess.CalledProcessError:
289+
del my_env["GIT_SSH"]
286290
subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env)
287291

288292

0 commit comments

Comments
 (0)