Skip to content

Commit 574b9a2

Browse files
committed
change: update _upload_code typing, add git test for model class
1 parent 169d8d1 commit 574b9a2

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

src/sagemaker/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def prepare_container_def(
405405
self.image_uri, self.model_data, deploy_env, image_config=self.image_config
406406
)
407407

408-
def _upload_code(self, key_prefix, repack=False):
408+
def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
409409
"""Uploads code to S3 to be used with script mode with SageMaker inference.
410410
411411
Args:

tests/unit/test_model.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
2929
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
3030
MODEL_DATA = "s3://someprefix2/models/model.tar.gz"
31+
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
32+
BRANCH = "test-branch-git-config"
33+
COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
3134

3235

3336
class DummyFrameworkModel(FrameworkModel):
@@ -97,3 +100,30 @@ def test_script_mode_model_same_calls_as_framework(repack_model, sagemaker_sessi
97100
== sagemaker_session.endpoint_from_production_variants.call_args_list
98101
)
99102
assert generic_model_repack_model_args == repack_model.call_args_list
103+
104+
105+
@patch("sagemaker.git_utils.git_clone_repo")
106+
@patch("sagemaker.model.fw_utils.tar_and_upload_dir")
107+
def test_git_support_succeed_model_class(tar_and_upload_dir, git_clone_repo, sagemaker_session):
108+
git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: {
109+
"entry_point": "entry_point",
110+
"source_dir": "/tmp/repo_dir/source_dir",
111+
"dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"],
112+
}
113+
entry_point = "entry_point"
114+
source_dir = "source_dir"
115+
dependencies = ["foo", "bar"]
116+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
117+
model = Model(
118+
sagemaker_session=sagemaker_session,
119+
entry_point=entry_point,
120+
source_dir=source_dir,
121+
dependencies=dependencies,
122+
git_config=git_config,
123+
image_uri=IMAGE_URI,
124+
)
125+
model.prepare_container_def(instance_type=INSTANCE_TYPE)
126+
git_clone_repo.assert_called_with(git_config, entry_point, source_dir, dependencies)
127+
assert model.entry_point == "entry_point"
128+
assert model.source_dir == "/tmp/repo_dir/source_dir"
129+
assert model.dependencies == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"]

0 commit comments

Comments
 (0)