Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 6f5f107

Browse files
GaryTu1020chuyang-deng
authored andcommittedJul 8, 2019
feature: git support for hosting models (#878)
* git integration for serving
1 parent a32a846 commit 6f5f107

File tree

8 files changed

+354
-30
lines changed

8 files changed

+354
-30
lines changed
 

‎doc/overview.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,23 @@ After that, invoke the ``deploy()`` method on the ``Model``:
799799
800800
This returns a predictor the same way an ``Estimator`` does when ``deploy()`` is called. You can now get inferences just like with any other model deployed on Amazon SageMaker.
801801
802+
Git support is also available when you bring your own model, through which you can use inference scripts stored in your
803+
Git repositories. The process is similar to using Git support for training jobs. You can simply provide ``git_config``
804+
when create the ``Model`` object, and let ``entry_point``, ``source_dir`` and ``dependencies`` (if needed) be relative
805+
paths inside the Git repository:
806+
807+
.. code:: python
808+
809+
git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git',
810+
'branch': 'branch1',
811+
'commit': '4893e528afa4a790331e1b5286954f073b0f14a2'}
812+
813+
sagemaker_model = MXNetModel(model_data='s3://path/to/model.tar.gz',
814+
role='arn:aws:iam::accid:sagemaker-role',
815+
entry_point='inference.py',
816+
source_dir='mxnet',
817+
git_config=git_config)
818+
802819
A full example is available in the `Amazon SageMaker examples repository <https://github.com/awslabs/amazon-sagemaker-examples/tree/master/advanced_functionality/mxnet_mnist_byom>`__.
803820
804821
You can also find this notebook in the **Advanced Functionality** section of the **SageMaker Examples** section in a notebook instance.

‎src/sagemaker/estimator.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -955,8 +955,8 @@ def __init__(
955955
code_location=None,
956956
image_name=None,
957957
dependencies=None,
958-
git_config=None,
959958
enable_network_isolation=False,
959+
git_config=None,
960960
**kwargs
961961
):
962962
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
@@ -993,7 +993,7 @@ def __init__(
993993
source_dir (str): Path (absolute or relative) to a directory with any other training
994994
source code dependencies aside from the entry point file (default: None). Structure within this
995995
directory are preserved when training on Amazon SageMaker. If 'git_config' is provided,
996-
source_dir should be a relative location to a directory in the Git repo.
996+
'source_dir' should be a relative location to a directory in the Git repo.
997997
Example:
998998
999999
With the following GitHub repo directory structure:
@@ -1023,6 +1023,8 @@ def __init__(
10231023
dependencies (list[str]): A list of paths to directories (absolute or relative) with
10241024
any additional libraries that will be exported to the container (default: []).
10251025
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
1026+
If 'git_config' is provided, 'dependencies' should be a list of relative locations to directories
1027+
with any additional libraries needed in the Git repo.
10261028
Example:
10271029
10281030
The following call
@@ -1085,12 +1087,12 @@ def _prepare_for_training(self, job_name=None):
10851087
super(Framework, self)._prepare_for_training(job_name=job_name)
10861088

10871089
if self.git_config:
1088-
updates = git_utils.git_clone_repo(
1090+
updated_paths = git_utils.git_clone_repo(
10891091
self.git_config, self.entry_point, self.source_dir, self.dependencies
10901092
)
1091-
self.entry_point = updates["entry_point"]
1092-
self.source_dir = updates["source_dir"]
1093-
self.dependencies = updates["dependencies"]
1093+
self.entry_point = updated_paths["entry_point"]
1094+
self.source_dir = updated_paths["source_dir"]
1095+
self.dependencies = updated_paths["dependencies"]
10941096

10951097
# validate source dir will raise a ValueError if there is something wrong with the
10961098
# source directory. We are intentionally not handling it because this is a critical error.

‎src/sagemaker/git_utils.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import os
16+
import six
1617
import subprocess
1718
import tempfile
1819

@@ -39,35 +40,46 @@ def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
3940
3. failed to checkout the required commit
4041
ValueError: If 1. entry point specified does not exist in the repo
4142
2. source dir specified does not exist in the repo
43+
3. dependencies specified do not exist in the repo
44+
4. git_config is in bad format
4245
4346
Returns:
4447
dict: A dict that contains the updated values of entry_point, source_dir and dependencies
4548
"""
49+
if entry_point is None:
50+
raise ValueError("Please provide an entry point.")
4651
_validate_git_config(git_config)
4752
repo_dir = tempfile.mkdtemp()
4853
subprocess.check_call(["git", "clone", git_config["repo"], repo_dir])
4954

5055
_checkout_branch_and_commit(git_config, repo_dir)
5156

52-
ret = {"entry_point": entry_point, "source_dir": source_dir, "dependencies": dependencies}
57+
updated_paths = {
58+
"entry_point": entry_point,
59+
"source_dir": source_dir,
60+
"dependencies": dependencies,
61+
}
62+
5363
# check if the cloned repo contains entry point, source directory and dependencies
5464
if source_dir:
5565
if not os.path.isdir(os.path.join(repo_dir, source_dir)):
5666
raise ValueError("Source directory does not exist in the repo.")
5767
if not os.path.isfile(os.path.join(repo_dir, source_dir, entry_point)):
5868
raise ValueError("Entry point does not exist in the repo.")
59-
ret["source_dir"] = os.path.join(repo_dir, source_dir)
69+
updated_paths["source_dir"] = os.path.join(repo_dir, source_dir)
6070
else:
61-
if not os.path.isfile(os.path.join(repo_dir, entry_point)):
71+
if os.path.isfile(os.path.join(repo_dir, entry_point)):
72+
updated_paths["entry_point"] = os.path.join(repo_dir, entry_point)
73+
else:
6274
raise ValueError("Entry point does not exist in the repo.")
63-
ret["entry_point"] = os.path.join(repo_dir, entry_point)
6475

65-
ret["dependencies"] = []
76+
updated_paths["dependencies"] = []
6677
for path in dependencies:
67-
if not os.path.exists(os.path.join(repo_dir, path)):
78+
if os.path.exists(os.path.join(repo_dir, path)):
79+
updated_paths["dependencies"].append(os.path.join(repo_dir, path))
80+
else:
6881
raise ValueError("Dependency {} does not exist in the repo.".format(path))
69-
ret["dependencies"].append(os.path.join(repo_dir, path))
70-
return ret
82+
return updated_paths
7183

7284

7385
def _validate_git_config(git_config):
@@ -84,6 +96,13 @@ def _validate_git_config(git_config):
8496
"""
8597
if "repo" not in git_config:
8698
raise ValueError("Please provide a repo for git_config.")
99+
allowed_keys = ["repo", "branch", "commit"]
100+
for key in allowed_keys:
101+
if key in git_config and not isinstance(git_config[key], six.string_types):
102+
raise ValueError("'{}' should be a string".format(key))
103+
for key in git_config:
104+
if key not in allowed_keys:
105+
raise ValueError("Unexpected argument(s) provided for git_config!")
87106

88107

89108
def _checkout_branch_and_commit(git_config, repo_dir):
@@ -95,8 +114,8 @@ def _checkout_branch_and_commit(git_config, repo_dir):
95114
repo_dir (str): the directory where the repo is cloned
96115
97116
Raises:
98-
ValueError: If 1. entry point specified does not exist in the repo
99-
2. source dir specified does not exist in the repo
117+
CalledProcessError: If 1. failed to checkout the required branch
118+
2. failed to checkout the required commit
100119
"""
101120
if "branch" in git_config:
102121
subprocess.check_call(args=["git", "checkout", git_config["branch"]], cwd=str(repo_dir))

‎src/sagemaker/model.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import os
1818

1919
import sagemaker
20-
from sagemaker import fw_utils, local, session, utils
20+
from sagemaker import fw_utils, local, session, utils, git_utils
2121
from sagemaker.fw_utils import UploadedCode
2222
from sagemaker.transformer import Transformer
2323

@@ -494,6 +494,7 @@ def __init__(
494494
code_location=None,
495495
sagemaker_session=None,
496496
dependencies=None,
497+
git_config=None,
497498
**kwargs
498499
):
499500
"""Initialize a ``FrameworkModel``.
@@ -504,15 +505,54 @@ def __init__(
504505
role (str): An IAM role name or ARN for SageMaker to access AWS resources on your behalf.
505506
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
506507
as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5.
508+
If 'git_config' is provided, 'entry_point' should be a relative location to the Python source file in
509+
the Git repo.
510+
Example:
511+
512+
With the following GitHub repo directory structure:
513+
514+
>>> |----- README.md
515+
>>> |----- src
516+
>>> |----- inference.py
517+
>>> |----- test.py
518+
519+
You can assign entry_point='src/inference.py'.
520+
git_config (dict[str, str]): Git configurations used for cloning files, including 'repo', 'branch'
521+
and 'commit' (default: None).
522+
'branch' and 'commit' are optional. If 'branch' is not specified, 'master' branch will be used. If
523+
'commit' is not specified, the latest commit in the required branch will be used.
524+
Example:
525+
526+
The following config:
527+
528+
>>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git',
529+
>>> 'branch': 'test-branch-git-config',
530+
>>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}
531+
532+
results in cloning the repo specified in 'repo', then checkout the 'master' branch, and checkout
533+
the specified commit.
507534
source_dir (str): Path (absolute or relative) to a directory with any other training
508535
source code dependencies aside from the entry point file (default: None). Structure within this
509-
directory will be preserved when training on SageMaker.
510-
If the directory points to S3, no code will be uploaded and the S3 location will be used instead.
536+
directory will be preserved when training on SageMaker. If 'git_config' is provided,
537+
'source_dir' should be a relative location to a directory in the Git repo. If the directory points
538+
to S3, no code will be uploaded and the S3 location will be used instead.
539+
Example:
540+
541+
With the following GitHub repo directory structure:
542+
543+
>>> |----- README.md
544+
>>> |----- src
545+
>>> |----- inference.py
546+
>>> |----- test.py
547+
548+
You can assign entry_point='inference.py', source_dir='src'.
511549
dependencies (list[str]): A list of paths to directories (absolute or relative) with
512550
any additional libraries that will be exported to the container (default: []).
513551
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
514-
If the ```source_dir``` points to S3, code will be uploaded and the S3 location will be used
515-
instead. Example:
552+
If 'git_config' is provided, 'dependencies' should be a list of relative locations to directories
553+
with any additional libraries needed in the Git repo. If the ```source_dir``` points to S3, code
554+
will be uploaded and the S3 location will be used instead.
555+
Example:
516556
517557
The following call
518558
>>> Estimator(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
@@ -554,12 +594,20 @@ def __init__(
554594
self.entry_point = entry_point
555595
self.source_dir = source_dir
556596
self.dependencies = dependencies or []
597+
self.git_config = git_config
557598
self.enable_cloudwatch_metrics = enable_cloudwatch_metrics
558599
self.container_log_level = container_log_level
559600
if code_location:
560601
self.bucket, self.key_prefix = fw_utils.parse_s3_url(code_location)
561602
else:
562603
self.bucket, self.key_prefix = None, None
604+
if self.git_config:
605+
updates = git_utils.git_clone_repo(
606+
self.git_config, self.entry_point, self.source_dir, self.dependencies
607+
)
608+
self.entry_point = updates["entry_point"]
609+
self.source_dir = updates["source_dir"]
610+
self.dependencies = updates["dependencies"]
563611
self.uploaded_code = None
564612
self.repacked_model_data = None
565613

‎tests/integ/test_git.py

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@
2121
from tests.integ import lock as lock
2222
from sagemaker.mxnet.estimator import MXNet
2323
from sagemaker.pytorch.estimator import PyTorch
24+
from sagemaker.sklearn.estimator import SKLearn
25+
from sagemaker.mxnet.model import MXNetModel
26+
from sagemaker.sklearn.model import SKLearnModel
2427
from tests.integ import DATA_DIR, PYTHON_VERSION
2528

2629
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
2730
BRANCH = "test-branch-git-config"
28-
COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a"
31+
COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
2932

3033
# endpoint tests all use the same port, so we use this lock to prevent concurrent execution
3134
LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_git_lock")
@@ -62,15 +65,16 @@ def test_git_support_with_pytorch(sagemaker_local_session):
6265

6366

6467
@pytest.mark.local_mode
65-
def test_git_support_with_mxnet(sagemaker_local_session, mxnet_full_version):
68+
def test_git_support_with_mxnet(sagemaker_local_session):
6669
script_path = "mnist.py"
6770
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
6871
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
72+
source_dir = "mxnet"
6973
dependencies = ["foo/bar.py"]
7074
mx = MXNet(
7175
entry_point=script_path,
7276
role="SageMakerRole",
73-
source_dir="mxnet",
77+
source_dir=source_dir,
7478
dependencies=dependencies,
7579
framework_version=MXNet.LATEST_VERSION,
7680
py_version=PYTHON_VERSION,
@@ -94,10 +98,76 @@ def test_git_support_with_mxnet(sagemaker_local_session, mxnet_full_version):
9498

9599
with lock.lock(LOCK_PATH):
96100
try:
97-
predictor = mx.deploy(initial_instance_count=1, instance_type="local")
101+
serving_script_path = "mnist_hosting_with_custom_handlers.py"
102+
client = sagemaker_local_session.sagemaker_client
103+
desc = client.describe_training_job(TrainingJobName=mx.latest_training_job.name)
104+
model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
105+
model = MXNetModel(
106+
model_data,
107+
"SageMakerRole",
108+
entry_point=serving_script_path,
109+
source_dir=source_dir,
110+
dependencies=dependencies,
111+
py_version=PYTHON_VERSION,
112+
sagemaker_session=sagemaker_local_session,
113+
framework_version=MXNet.LATEST_VERSION,
114+
git_config=git_config,
115+
)
116+
predictor = model.deploy(initial_instance_count=1, instance_type="local")
98117

99118
data = numpy.zeros(shape=(1, 1, 28, 28))
100119
result = predictor.predict(data)
101120
assert result is not None
102121
finally:
103122
predictor.delete_endpoint()
123+
124+
125+
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.")
126+
@pytest.mark.local_mode
127+
def test_git_support_with_sklearn(sagemaker_local_session, sklearn_full_version):
128+
script_path = "mnist.py"
129+
data_path = os.path.join(DATA_DIR, "sklearn_mnist")
130+
git_config = {
131+
"repo": "https://github.com/GaryTu1020/python-sdk-testing.git",
132+
"branch": "branch1",
133+
"commit": "aafa4e96237dd78a015d5df22bfcfef46845c3c5",
134+
}
135+
source_dir = "sklearn"
136+
sklearn = SKLearn(
137+
entry_point=script_path,
138+
role="SageMakerRole",
139+
source_dir=source_dir,
140+
py_version=PYTHON_VERSION,
141+
train_instance_count=1,
142+
train_instance_type="local",
143+
sagemaker_session=sagemaker_local_session,
144+
framework_version=sklearn_full_version,
145+
hyperparameters={"epochs": 1},
146+
git_config=git_config,
147+
)
148+
train_input = "file://" + os.path.join(data_path, "train")
149+
test_input = "file://" + os.path.join(data_path, "test")
150+
sklearn.fit({"train": train_input, "test": test_input})
151+
152+
assert os.path.isdir(sklearn.source_dir)
153+
154+
with lock.lock(LOCK_PATH):
155+
try:
156+
client = sagemaker_local_session.sagemaker_client
157+
desc = client.describe_training_job(TrainingJobName=sklearn.latest_training_job.name)
158+
model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
159+
model = SKLearnModel(
160+
model_data,
161+
"SageMakerRole",
162+
entry_point=script_path,
163+
source_dir=source_dir,
164+
sagemaker_session=sagemaker_local_session,
165+
git_config=git_config,
166+
)
167+
predictor = model.deploy(1, "local")
168+
169+
data = numpy.zeros((100, 784), dtype="float32")
170+
result = predictor.predict(data)
171+
assert result is not None
172+
finally:
173+
predictor.delete_endpoint()

‎tests/unit/test_estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
OUTPUT_PATH = "s3://bucket/prefix"
5151
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
5252
BRANCH = "test-branch-git-config"
53-
COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a"
53+
COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
5454

5555
DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}}
5656
INSTANCE_TYPE = "c4.4xlarge"
@@ -898,12 +898,12 @@ def test_git_support_bad_repo_url_format(sagemaker_session):
898898

899899

900900
@patch(
901-
"subprocess.check_call",
901+
"sagemaker.git_utils.git_clone_repo",
902902
side_effect=subprocess.CalledProcessError(
903-
returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git"
903+
returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir"
904904
),
905905
)
906-
def test_git_support_git_clone_fail(check_call, sagemaker_session):
906+
def test_git_support_git_clone_fail(sagemaker_session):
907907
git_config = {"repo": "https://github.com/aws/no-such-repo.git", "branch": BRANCH}
908908
fw = DummyFramework(
909909
entry_point="entry_point",

‎tests/unit/test_git_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
REPO_DIR = "/tmp/repo_dir"
2222
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
2323
BRANCH = "test-branch-git-config"
24-
COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a"
24+
COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
2525

2626

2727
@patch("subprocess.check_call")
@@ -44,6 +44,14 @@ def test_git_clone_repo_succeed(exists, isdir, isfile, mkdtemp, check_call):
4444
assert ret["dependencies"] == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"]
4545

4646

47+
def test_git_clone_repo_entry_point_not_provided():
48+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
49+
source_dir = "source_dir"
50+
with pytest.raises(ValueError) as error:
51+
git_utils.git_clone_repo(git_config=git_config, entry_point=None, source_dir=source_dir)
52+
assert "Please provide an entry point." in str(error)
53+
54+
4755
@patch("subprocess.check_call")
4856
@patch("tempfile.mkdtemp", return_value=REPO_DIR)
4957
@patch("os.path.isfile", return_value=True)

‎tests/unit/test_model.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import copy
1616
import os
17+
import subprocess
1718

1819
import sagemaker
1920
from sagemaker.model import FrameworkModel, ModelPackage
@@ -39,6 +40,9 @@
3940
IMAGE_NAME = "fakeimage"
4041
REGION = "us-west-2"
4142
MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)
43+
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
44+
BRANCH = "test-branch-git-config"
45+
COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
4246

4347

4448
DESCRIBE_MODEL_PACKAGE_RESPONSE = {
@@ -94,6 +98,21 @@ def create_predictor(self, endpoint_name):
9498
return RealTimePredictor(endpoint_name, sagemaker_session=self.sagemaker_session)
9599

96100

101+
class DummyFrameworkModelForGit(FrameworkModel):
102+
def __init__(self, sagemaker_session, entry_point, **kwargs):
103+
super(DummyFrameworkModelForGit, self).__init__(
104+
MODEL_DATA,
105+
MODEL_IMAGE,
106+
ROLE,
107+
entry_point=entry_point,
108+
sagemaker_session=sagemaker_session,
109+
**kwargs
110+
)
111+
112+
def create_predictor(self, endpoint_name):
113+
return RealTimePredictor(endpoint_name, sagemaker_session=self.sagemaker_session)
114+
115+
97116
@pytest.fixture()
98117
def sagemaker_session():
99118
boto_mock = Mock(name="boto_session", region_name=REGION)
@@ -506,3 +525,144 @@ def test_check_neo_region(sagemaker_session, tmpdir):
506525
assert model.check_neo_region(region_name) is True
507526
else:
508527
assert model.check_neo_region(region_name) is False
528+
529+
530+
@patch("sagemaker.git_utils.git_clone_repo")
531+
@patch("sagemaker.model.fw_utils.tar_and_upload_dir")
532+
def test_git_support_succeed(tar_and_upload_dir, git_clone_repo, sagemaker_session):
533+
git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: {
534+
"entry_point": "entry_point",
535+
"source_dir": "/tmp/repo_dir/source_dir",
536+
"dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"],
537+
}
538+
entry_point = "entry_point"
539+
source_dir = "source_dir"
540+
dependencies = ["foo", "bar"]
541+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
542+
model = DummyFrameworkModelForGit(
543+
sagemaker_session=sagemaker_session,
544+
entry_point=entry_point,
545+
source_dir=source_dir,
546+
dependencies=dependencies,
547+
git_config=git_config,
548+
)
549+
model.prepare_container_def(instance_type=INSTANCE_TYPE)
550+
git_clone_repo.assert_called_with(git_config, entry_point, source_dir, dependencies)
551+
assert model.entry_point == "entry_point"
552+
assert model.source_dir == "/tmp/repo_dir/source_dir"
553+
assert model.dependencies == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"]
554+
555+
556+
def test_git_support_repo_not_provided(sagemaker_session):
557+
entry_point = "source_dir/entry_point"
558+
git_config = {"branch": BRANCH, "commit": COMMIT}
559+
with pytest.raises(ValueError) as error:
560+
model = DummyFrameworkModelForGit(
561+
sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config
562+
)
563+
model.prepare_container_def(instance_type=INSTANCE_TYPE)
564+
assert "Please provide a repo for git_config." in str(error)
565+
566+
567+
@patch(
568+
"sagemaker.git_utils.git_clone_repo",
569+
side_effect=subprocess.CalledProcessError(
570+
returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir"
571+
),
572+
)
573+
def test_git_support_git_clone_fail(sagemaker_session):
574+
entry_point = "source_dir/entry_point"
575+
git_config = {"repo": "https://github.com/aws/no-such-repo.git", "branch": BRANCH}
576+
with pytest.raises(subprocess.CalledProcessError) as error:
577+
model = DummyFrameworkModelForGit(
578+
sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config
579+
)
580+
model.prepare_container_def(instance_type=INSTANCE_TYPE)
581+
assert "returned non-zero exit status" in str(error)
582+
583+
584+
@patch(
585+
"sagemaker.git_utils.git_clone_repo",
586+
side_effect=subprocess.CalledProcessError(
587+
returncode=1, cmd="git checkout branch-that-does-not-exist"
588+
),
589+
)
590+
def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session):
591+
entry_point = "source_dir/entry_point"
592+
git_config = {"repo": GIT_REPO, "branch": "branch-that-does-not-exist", "commit": COMMIT}
593+
with pytest.raises(subprocess.CalledProcessError) as error:
594+
model = DummyFrameworkModelForGit(
595+
sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config
596+
)
597+
model.prepare_container_def(instance_type=INSTANCE_TYPE)
598+
assert "returned non-zero exit status" in str(error)
599+
600+
601+
@patch(
602+
"sagemaker.git_utils.git_clone_repo",
603+
side_effect=subprocess.CalledProcessError(
604+
returncode=1, cmd="git checkout commit-sha-that-does-not-exist"
605+
),
606+
)
607+
def test_git_support_commit_not_exist(git_clone_repo, sagemaker_session):
608+
entry_point = "source_dir/entry_point"
609+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": "commit-sha-that-does-not-exist"}
610+
with pytest.raises(subprocess.CalledProcessError) as error:
611+
model = DummyFrameworkModelForGit(
612+
sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config
613+
)
614+
model.prepare_container_def(instance_type=INSTANCE_TYPE)
615+
assert "returned non-zero exit status" in str(error)
616+
617+
618+
@patch(
619+
"sagemaker.git_utils.git_clone_repo",
620+
side_effect=ValueError("Entry point does not exist in the repo."),
621+
)
622+
def test_git_support_entry_point_not_exist(sagemaker_session):
623+
entry_point = "source_dir/entry_point"
624+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
625+
with pytest.raises(ValueError) as error:
626+
model = DummyFrameworkModelForGit(
627+
sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config
628+
)
629+
model.prepare_container_def(instance_type=INSTANCE_TYPE)
630+
assert "Entry point does not exist in the repo." in str(error)
631+
632+
633+
@patch(
634+
"sagemaker.git_utils.git_clone_repo",
635+
side_effect=ValueError("Source directory does not exist in the repo."),
636+
)
637+
def test_git_support_source_dir_not_exist(sagemaker_session):
638+
entry_point = "entry_point"
639+
source_dir = "source_dir_that_does_not_exist"
640+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
641+
with pytest.raises(ValueError) as error:
642+
model = DummyFrameworkModelForGit(
643+
sagemaker_session=sagemaker_session,
644+
entry_point=entry_point,
645+
source_dir=source_dir,
646+
git_config=git_config,
647+
)
648+
model.prepare_container_def(instance_type=INSTANCE_TYPE)
649+
assert "Source directory does not exist in the repo." in str(error)
650+
651+
652+
@patch(
653+
"sagemaker.git_utils.git_clone_repo",
654+
side_effect=ValueError("Dependency no-such-dir does not exist in the repo."),
655+
)
656+
def test_git_support_dependencies_not_exist(sagemaker_session):
657+
entry_point = "entry_point"
658+
dependencies = ["foo", "no_such_dir"]
659+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
660+
with pytest.raises(ValueError) as error:
661+
model = DummyFrameworkModelForGit(
662+
sagemaker_session=sagemaker_session,
663+
entry_point=entry_point,
664+
dependencies=dependencies,
665+
git_config=git_config,
666+
)
667+
model.prepare_container_def(instance_type=INSTANCE_TYPE)
668+
assert "Dependency", "does not exist in the repo." in str(error)

0 commit comments

Comments
 (0)
Please sign in to comment.