Skip to content

Commit 1228f65

Browse files
GaryTu1020pengk19
authored andcommitted
feature: add git_config and git_clone, validate method (aws#832)
1 parent 1f0c265 commit 1228f65

File tree

6 files changed

+739
-4
lines changed

6 files changed

+739
-4
lines changed

doc/overview.rst

+60
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,65 @@ For more `information <https://boto3.amazonaws.com/v1/documentation/api/latest/r
8484
# Deletes the SageMaker model
8585
mxnet_predictor.delete_model()
8686
87+
Git Support
88+
~~~~~~~~~~~
89+
If you have your training scripts in your GitHub repository, you can use them directly without having to download
90+
them to your local machine. Git support can be enabled simply by providing the ``git_config`` parameter when initializing an
91+
estimator. If Git support is enabled, then ``entry_point``, ``source_dir`` and ``dependencies`` should all be relative
92+
paths in the Git repo. Note that if you decide to use Git support, then everything you need for ``entry_point``,
93+
``source_dir`` and ``dependencies`` should be in a single Git repo.
94+
95+
Here are ways to specify ``git_config``:
96+
97+
.. code:: python
98+
99+
# Specifies the git_config parameter
100+
git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git',
101+
'branch': 'branch1',
102+
'commit': '4893e528afa4a790331e1b5286954f073b0f14a2'}
103+
104+
# Alternatively, you can also specify git_config by providing only 'repo' and 'branch'.
105+
# If this is the case, the latest commit in the branch will be used.
106+
git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git',
107+
'branch': 'branch1'}
108+
109+
# Only providing 'repo' is also allowed. If this is the case, latest commit in
110+
# 'master' branch will be used.
111+
git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git'}
112+
113+
The following are some examples to define estimators with Git support:
114+
115+
.. code:: python
116+
117+
# In this example, the source directory 'pytorch' contains the entry point 'mnist.py' and other source code.
118+
# and it is relative path inside the Git repo.
119+
pytorch_estimator = PyTorch(entry_point='mnist.py',
120+
role='SageMakerRole',
121+
source_dir='pytorch',
122+
git_config=git_config,
123+
train_instance_count=1,
124+
train_instance_type='ml.c4.xlarge')
125+
126+
# In this example, the entry point 'mnist.py' is all we need for source code.
127+
# We need to specify the path to it in the Git repo.
128+
mx_estimator = MXNet(entry_point='mxnet/mnist.py',
129+
role='SageMakerRole',
130+
git_config=git_config,
131+
train_instance_count=1,
132+
train_instance_type='ml.c4.xlarge')
133+
134+
# In this example, besides entry point and other source code in source directory, we still need some
135+
# dependencies for the training job. Dependencies should also be paths inside the Git repo.
136+
pytorch_estimator = PyTorch(entry_point='mnist.py',
137+
role='SageMakerRole',
138+
source_dir='pytorch',
139+
dependencies=['dep.py', 'foo/bar.py'],
140+
git_config=git_config,
141+
train_instance_count=1,
142+
train_instance_type='ml.c4.xlarge')
143+
144+
When Git support is enabled, users can still use local mode in the same way.
145+
87146
Training Metrics
88147
~~~~~~~~~~~~~~~~
89148
The SageMaker Python SDK allows you to specify a name and a regular expression for metrics you want to track for training.
@@ -268,6 +327,7 @@ Currently, the following algorithms support incremental training:
268327
- Object Detection
269328
- Semantic Segmentation
270329
330+
271331
Using SageMaker AlgorithmEstimators
272332
-----------------------------------
273333

src/sagemaker/estimator.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from six import string_types
2323

2424
import sagemaker
25+
from sagemaker import git_utils
2526
from sagemaker.analytics import TrainingJobAnalytics
2627
from sagemaker.fw_utils import (
2728
create_image_uri,
@@ -933,6 +934,7 @@ class Framework(EstimatorBase):
933934
"""
934935

935936
__framework_name__ = None
937+
936938
LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
937939
LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
938940
MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
@@ -949,6 +951,7 @@ def __init__(
949951
code_location=None,
950952
image_name=None,
951953
dependencies=None,
954+
git_config=None,
952955
enable_network_isolation=False,
953956
**kwargs
954957
):
@@ -957,9 +960,47 @@ def __init__(
957960
Args:
958961
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
959962
as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
963+
If 'git_config' is provided, 'entry_point' should be a relative location to the Python source file in
964+
the Git repo.
965+
Example:
966+
967+
With the following GitHub repo directory structure:
968+
969+
>>> |----- README.md
970+
>>> |----- src
971+
>>> |----- train.py
972+
>>> |----- test.py
973+
974+
You can assign entry_point='src/train.py'.
975+
git_config (dict[str, str]): Git configurations used for cloning files, including 'repo', 'branch'
976+
and 'commit' (default: None).
977+
'branch' and 'commit' are optional. If 'branch' is not specified, 'master' branch will be used. If
978+
'commit' is not specified, the latest commit in the required branch will be used.
979+
Example:
980+
981+
The following config:
982+
983+
>>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git',
984+
>>> 'branch': 'test-branch-git-config',
985+
>>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}
986+
987+
results in cloning the repo specified in 'repo', then checkout the 'test-branch-git-config' branch, and checkout
988+
the specified commit.
960989
source_dir (str): Path (absolute or relative) to a directory with any other training
961990
source code dependencies aside from the entry point file (default: None). Structure within this
962-
directory are preserved when training on Amazon SageMaker.
991+
directory are preserved when training on Amazon SageMaker. If 'git_config' is provided,
992+
source_dir should be a relative location to a directory in the Git repo.
993+
Example:
994+
995+
With the following GitHub repo directory structure:
996+
997+
>>> |----- README.md
998+
>>> |----- src
999+
>>> |----- train.py
1000+
>>> |----- test.py
1001+
1002+
and you need 'train.py' as entry point and 'test.py' as training source code as well, you can
1003+
assign entry_point='train.py', source_dir='src'.
9631004
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
9641005
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
9651006
For convenience, this accepts other types for keys and values, but ``str()`` will be called
@@ -1006,6 +1047,7 @@ def __init__(
10061047
)
10071048
)
10081049
self.entry_point = entry_point
1050+
self.git_config = git_config
10091051
self.source_dir = source_dir
10101052
self.dependencies = dependencies or []
10111053
if enable_cloudwatch_metrics:
@@ -1038,6 +1080,14 @@ def _prepare_for_training(self, job_name=None):
10381080
"""
10391081
super(Framework, self)._prepare_for_training(job_name=job_name)
10401082

1083+
if self.git_config:
1084+
updates = git_utils.git_clone_repo(
1085+
self.git_config, self.entry_point, self.source_dir, self.dependencies
1086+
)
1087+
self.entry_point = updates["entry_point"]
1088+
self.source_dir = updates["source_dir"]
1089+
self.dependencies = updates["dependencies"]
1090+
10411091
# validate source dir will raise a ValueError if there is something wrong with the
10421092
# source directory. We are intentionally not handling it because this is a critical error.
10431093
if self.source_dir and not self.source_dir.lower().startswith("s3://"):

src/sagemaker/git_utils.py

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import subprocess
17+
import tempfile
18+
19+
20+
def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
    """Clone the Git repo holding the training/serving code and resolve paths inside the clone.

    Validates ``git_config``, clones ``git_config['repo']`` into a newly created temporary
    directory, checks out the requested branch and/or commit, and verifies that ``entry_point``,
    ``source_dir`` and ``dependencies`` exist in the cloned repo.

    The temporary directory is not removed by this function; the returned paths point into it.

    Args:
        git_config (dict[str, str]): Git configurations used for cloning files, including
            ``repo``, ``branch`` and ``commit``. ``branch`` and ``commit`` are optional; when one
            is omitted, the corresponding checkout step is skipped.
        entry_point (str): A relative location to the Python source file which should be executed
            as the entry point to training or model hosting. It is relative to ``source_dir`` when
            ``source_dir`` is given, otherwise relative to the root of the Git repo.
        source_dir (str): A relative location, inside the Git repo, to a directory with other
            training or model hosting source code dependencies aside from the entry point file
            (default: None).
        dependencies (list[str]): A list of relative locations, inside the Git repo, of any
            additional libraries that will be exported to the container (default: None, which is
            treated as an empty list).

    Raises:
        CalledProcessError: If cloning the repo, checking out the branch, or checking out the
            commit fails.
        ValueError: If ``git_config`` has no ``repo`` key, or if the source directory, the entry
            point, or one of the dependencies does not exist in the repo.

    Returns:
        dict: A dict with the updated values of ``entry_point``, ``source_dir`` and
            ``dependencies``. ``source_dir`` (when given) and every dependency are joined onto the
            clone directory. ``entry_point`` is joined onto the clone directory only when no
            ``source_dir`` is given; otherwise it is returned unchanged.
    """
    _validate_git_config(git_config)
    repo_dir = tempfile.mkdtemp()
    subprocess.check_call(["git", "clone", git_config["repo"], repo_dir])

    _checkout_branch_and_commit(git_config, repo_dir)

    ret = {"entry_point": entry_point, "source_dir": source_dir, "dependencies": dependencies}
    # check if the cloned repo contains entry point, source directory and dependencies
    if source_dir:
        if not os.path.isdir(os.path.join(repo_dir, source_dir)):
            raise ValueError("Source directory does not exist in the repo.")
        if not os.path.isfile(os.path.join(repo_dir, source_dir, entry_point)):
            raise ValueError("Entry point does not exist in the repo.")
        # entry_point stays relative to source_dir; only source_dir is rewritten.
        ret["source_dir"] = os.path.join(repo_dir, source_dir)
    else:
        if not os.path.isfile(os.path.join(repo_dir, entry_point)):
            raise ValueError("Entry point does not exist in the repo.")
        ret["entry_point"] = os.path.join(repo_dir, entry_point)

    ret["dependencies"] = []
    # ``dependencies`` defaults to None, so fall back to an empty list instead of iterating None.
    for path in dependencies or []:
        if not os.path.exists(os.path.join(repo_dir, path)):
            raise ValueError("Dependency {} does not exist in the repo.".format(path))
        ret["dependencies"].append(os.path.join(repo_dir, path))
    return ret
71+
72+
73+
def _validate_git_config(git_config):
    """Ensure that a ``git_config`` names the repo to clone.

    Only the presence of the ``repo`` key is checked; its value is not inspected.

    Args:
        git_config (dict[str, str]): Git configurations used for cloning files, including
            ``repo``, ``branch`` and ``commit``.

    Raises:
        ValueError: If ``git_config`` has no key 'repo'.
    """
    has_repo = "repo" in git_config
    if has_repo:
        return
    raise ValueError("Please provide a repo for git_config.")
87+
88+
89+
def _checkout_branch_and_commit(git_config, repo_dir):
    """Run ``git checkout`` for the configured branch and then the configured commit.

    Each step is skipped when its key is absent from ``git_config``.

    Args:
        git_config (dict[str, str]): Git configurations used for cloning files, including
            ``repo``, ``branch`` and ``commit``.
        repo_dir (str): the directory where the repo is cloned

    Raises:
        CalledProcessError: If ``git checkout`` exits with a non-zero status for the branch
            or for the commit.
    """
    # Order matters: the branch is checked out first, then the commit.
    for ref_key in ("branch", "commit"):
        if ref_key in git_config:
            subprocess.check_call(
                args=["git", "checkout", git_config[ref_key]], cwd=str(repo_dir)
            )

tests/integ/test_git.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
import numpy
18+
import tempfile
19+
20+
from tests.integ import lock as lock
21+
from sagemaker.mxnet.estimator import MXNet
22+
from sagemaker.pytorch.estimator import PyTorch
23+
from tests.integ import DATA_DIR, PYTHON_VERSION
24+
25+
# Repo, branch and commit that the tests below pass to the estimators as ``git_config``.
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
26+
BRANCH = "test-branch-git-config"
27+
COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a"
28+
29+
# endpoint tests all use the same port, so we use this lock to prevent concurrent execution
30+
LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_git_lock")
31+
32+
33+
def test_git_support_with_pytorch(sagemaker_local_session):
    """Train and deploy a PyTorch estimator in local mode with scripts taken from ``git_config``."""
    script_path = "mnist.py"
    data_path = os.path.join(DATA_DIR, "pytorch_mnist")
    git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
    pytorch = PyTorch(
        entry_point=script_path,
        role="SageMakerRole",
        source_dir="pytorch",
        framework_version=PyTorch.LATEST_VERSION,
        py_version=PYTHON_VERSION,
        train_instance_count=1,
        train_instance_type="local",
        sagemaker_session=sagemaker_local_session,
        git_config=git_config,
    )

    pytorch.fit({"training": "file://" + os.path.join(data_path, "training")})

    with lock.lock(LOCK_PATH):
        # Deploy before entering ``try``: if deploy fails there is no predictor to clean up,
        # and the ``finally`` clause must not mask the real error with a NameError.
        predictor = pytorch.deploy(initial_instance_count=1, instance_type="local")
        try:
            data = numpy.zeros(shape=(1, 1, 28, 28)).astype(numpy.float32)
            result = predictor.predict(data)
            assert result is not None
        finally:
            predictor.delete_endpoint()
60+
61+
62+
def test_git_support_with_mxnet(sagemaker_local_session, mxnet_full_version):
    """Train and deploy an MXNet estimator in local mode with scripts taken from ``git_config``.

    Also checks that, after ``fit``, ``source_dir`` lists the expected files and the first
    dependency path exists on disk.
    """
    script_path = "mnist.py"
    data_path = os.path.join(DATA_DIR, "mxnet_mnist")
    git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
    dependencies = ["foo/bar.py"]
    mx = MXNet(
        entry_point=script_path,
        role="SageMakerRole",
        source_dir="mxnet",
        dependencies=dependencies,
        framework_version=MXNet.LATEST_VERSION,
        py_version=PYTHON_VERSION,
        train_instance_count=1,
        train_instance_type="local",
        sagemaker_session=sagemaker_local_session,
        git_config=git_config,
    )

    mx.fit(
        {
            "train": "file://" + os.path.join(data_path, "train"),
            "test": "file://" + os.path.join(data_path, "test"),
        }
    )

    files = os.listdir(mx.source_dir)
    assert "some_file" in files
    assert "mnist.py" in files
    assert os.path.exists(mx.dependencies[0])

    with lock.lock(LOCK_PATH):
        # Deploy before entering ``try``: if deploy fails there is no predictor to clean up,
        # and the ``finally`` clause must not mask the real error with a NameError.
        predictor = mx.deploy(initial_instance_count=1, instance_type="local")
        try:
            data = numpy.zeros(shape=(1, 1, 28, 28))
            result = predictor.predict(data)
            assert result is not None
        finally:
            predictor.delete_endpoint()

0 commit comments

Comments
 (0)