Skip to content

Commit 1441e9f

Browse files
committed
change: add git test for estimator class, update typing
1 parent f27d5a7 commit 1441e9f

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

src/sagemaker/estimator.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import logging
1818
import os
19+
from typing import Dict
1920
import uuid
2021
from abc import ABCMeta, abstractmethod
2122

@@ -582,7 +583,7 @@ def _get_or_create_name(self, name=None):
582583
return name_from_base(self.base_job_name)
583584

584585
@staticmethod
585-
def _json_encode_hyperparameters(hyperparameters):
586+
def _json_encode_hyperparameters(hyperparameters: dict) -> dict:
586587
"""Applies Json encoding for certain Hyperparameter types, returns hyperparameters.
587588
588589
Args:
@@ -661,14 +662,14 @@ def _prepare_for_training(self, job_name=None):
661662
self._prepare_debugger_for_training()
662663
self._prepare_profiler_for_training()
663664

664-
def _script_mode_hyperparam_update(self, code_dir, script):
665+
def _script_mode_hyperparam_update(self, code_dir: str, script: str) -> None:
665666
"""Applies in-place update to hyperparameters required for script mode with training.
666667
667668
Args:
668669
code_dir (str): The directory hosting the training scripts.
669670
script (str): The relative filepath of the training entry-point script.
670671
"""
671-
hyperparams = {}
672+
hyperparams: Dict[str, str] = {}
672673
hyperparams[DIR_PARAM_NAME] = code_dir
673674
hyperparams[SCRIPT_PARAM_NAME] = script
674675
hyperparams[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
@@ -677,7 +678,7 @@ def _script_mode_hyperparam_update(self, code_dir, script):
677678

678679
self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(hyperparams))
679680

680-
def _stage_user_code_in_s3(self):
681+
def _stage_user_code_in_s3(self) -> str:
681682
"""Upload the user training script to s3 and return the location.
682683
683684
Returns: s3 uri
@@ -2615,14 +2616,14 @@ def _prepare_for_training(self, job_name=None):
26152616

26162617
self._validate_and_set_debugger_configs()
26172618

2618-
def _script_mode_hyperparam_update(self, code_dir, script):
2619+
def _script_mode_hyperparam_update(self, code_dir: str, script: str) -> None:
26192620
"""Applies in-place update to hyperparameters required for script mode with training.
26202621
26212622
Args:
26222623
code_dir (str): The directory hosting the training scripts.
26232624
script (str): The relative filepath of the training entry-point script.
26242625
"""
2625-
hyperparams = {}
2626+
hyperparams: Dict[str, str] = {}
26262627
hyperparams[DIR_PARAM_NAME] = code_dir
26272628
hyperparams[SCRIPT_PARAM_NAME] = script
26282629
hyperparams[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level

tests/unit/test_estimator.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3354,10 +3354,34 @@ def test_image_name_map(sagemaker_session):
33543354
assert e.image_uri == IMAGE_URI
33553355

33563356

3357+
@patch("sagemaker.git_utils.git_clone_repo")
3358+
def test_git_support_with_branch_and_commit_succeed_estimator_class(
3359+
git_clone_repo, sagemaker_session
3360+
):
3361+
git_clone_repo.side_effect = lambda gitconfig, entrypoint, source_dir=None, dependencies=None: {
3362+
"entry_point": "/tmp/repo_dir/entry_point",
3363+
"source_dir": None,
3364+
"dependencies": None,
3365+
}
3366+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
3367+
entry_point = "entry_point"
3368+
fw = Estimator(
3369+
entry_point=entry_point,
3370+
git_config=git_config,
3371+
role=ROLE,
3372+
sagemaker_session=sagemaker_session,
3373+
instance_count=INSTANCE_COUNT,
3374+
instance_type=INSTANCE_TYPE,
3375+
image_uri=IMAGE_URI,
3376+
)
3377+
fw.fit()
3378+
git_clone_repo.assert_called_once_with(git_config, entry_point, None, None)
3379+
3380+
33573381
@patch("sagemaker.estimator.Estimator._stage_user_code_in_s3")
33583382
def test_script_mode_estimator(patched_stage_user_code, sagemaker_session):
33593383
patched_stage_user_code.return_value = UploadedCode(
3360-
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
3384+
s3_prefix="s3://bucket/key", script_name="script_name"
33613385
)
33623386
script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz"
33633387
image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38"

0 commit comments

Comments
 (0)