Skip to content

Commit 13ca522

Browse files
authored
Merge branch 'master' into fix-pipeline-caching
2 parents df2143f + 34dd43a commit 13ca522

File tree

5 files changed

+157
-78
lines changed

5 files changed

+157
-78
lines changed

ci-scripts/queue_build.py

+94-56
Original file line numberDiff line numberDiff line change
@@ -23,100 +23,138 @@
2323
).get_caller_identity()["Account"]
2424
bucket_name = "sagemaker-us-west-2-%s" % account
2525

26+
MAX_IN_PROGRESS_BUILDS = 3
27+
INTERVAL_BETWEEN_CONCURRENT_RUNS = 15 # minutes
28+
CLEAN_UP_TICKETS_OLDER_THAN = 8 # hours
29+
2630

2731
def queue_build():
28-
build_id = re.sub("[_/]", "-", os.environ.get("CODEBUILD_BUILD_ID", "CODEBUILD-BUILD-ID"))
29-
source_version = re.sub(
30-
"[_/]",
31-
"-",
32-
os.environ.get("CODEBUILD_SOURCE_VERSION", "CODEBUILD-SOURCE-VERSION"),
33-
)
3432
ticket_number = int(1000 * time.time())
35-
filename = "%s_%s_%s" % (ticket_number, build_id, source_version)
36-
37-
print("Created queue ticket %s" % ticket_number)
38-
39-
_write_ticket(filename)
4033
files = _list_tickets()
41-
_cleanup_tickets_older_than_8_hours(files)
42-
_wait_for_other_builds(files, ticket_number)
34+
_cleanup_tickets_older_than(files)
35+
_wait_for_other_builds(ticket_number)
4336

4437

4538
def _build_info_from_file(file):
46-
filename = file.key.split("/")[1]
39+
filename = file.key.split("/")[2]
4740
ticket_number, build_id, source_version = filename.split("_")
4841
return int(ticket_number), build_id, source_version
4942

5043

51-
def _wait_for_other_builds(files, ticket_number):
52-
newfiles = list(filter(lambda file: not _file_older_than(file), files))
53-
sorted_files = list(sorted(newfiles, key=lambda y: y.key))
44+
def _wait_for_other_builds(ticket_number):
45+
sorted_files = _list_tickets()
5446

5547
print("build queue status:")
5648
print()
5749

5850
for order, file in enumerate(sorted_files):
5951
file_ticket_number, build_id, source_version = _build_info_from_file(file)
6052
print(
61-
"%s -> %s %s, ticket number: %s" % (order, build_id, source_version, file_ticket_number)
53+
"%s -> %s %s, ticket number: %s status: %s"
54+
% (order, build_id, source_version, file_ticket_number, file.key.split("/")[1])
6255
)
56+
print()
57+
build_id = re.sub("[_/]", "-", os.environ.get("CODEBUILD_BUILD_ID", "CODEBUILD-BUILD-ID"))
58+
source_version = re.sub(
59+
"[_/]",
60+
"-",
61+
os.environ.get("CODEBUILD_SOURCE_VERSION", "CODEBUILD-SOURCE-VERSION"),
62+
)
63+
filename = "%s_%s_%s" % (ticket_number, build_id, source_version)
64+
s3_file_obj = _write_ticket(filename, status="waiting")
65+
print("Build %s waiting to be scheduled" % filename)
66+
67+
while True:
68+
_cleanup_tickets_with_terminal_states()
69+
waiting_tickets = _list_tickets("waiting")
70+
if waiting_tickets:
71+
first_waiting_ticket_number, _, _ = _build_info_from_file(_list_tickets("waiting")[0])
72+
else:
73+
first_waiting_ticket_number = ticket_number
74+
75+
if (
76+
len(_list_tickets(status="in-progress")) < 3
77+
and last_in_progress_elapsed_time_check()
78+
and first_waiting_ticket_number == ticket_number
79+
):
80+
# put the build in progress
81+
print("Scheduling build %s for running.." % filename)
82+
s3_file_obj.delete()
83+
_write_ticket(filename, status="in-progress")
84+
break
85+
else:
86+
# wait
87+
time.sleep(30)
6388

64-
for file in sorted_files:
65-
file_ticket_number, build_id, source_version = _build_info_from_file(file)
6689

67-
if file_ticket_number == ticket_number:
90+
def last_in_progress_elapsed_time_check():
91+
in_progress_tickets = _list_tickets("in-progress")
92+
if not in_progress_tickets:
93+
return True
94+
last_in_progress_ticket, _, _ = _build_info_from_file(_list_tickets("in-progress")[-1])
95+
_elapsed_time = int(1000 * time.time()) - last_in_progress_ticket
96+
last_in_progress_elapsed_time = int(_elapsed_time / (1000 * 60)) # in minutes
97+
return last_in_progress_elapsed_time > INTERVAL_BETWEEN_CONCURRENT_RUNS
6898

69-
break
70-
else:
71-
while True:
72-
client = boto3.client("codebuild")
73-
response = client.batch_get_builds(ids=[build_id])
74-
build_status = response["builds"][0]["buildStatus"]
75-
76-
if build_status == "IN_PROGRESS":
77-
print(
78-
"waiting on build %s %s %s" % (build_id, source_version, file_ticket_number)
79-
)
80-
time.sleep(30)
81-
else:
82-
print("build %s finished, deleting lock" % build_id)
83-
file.delete()
84-
break
85-
86-
87-
def _cleanup_tickets_older_than_8_hours(files):
99+
100+
def _cleanup_tickets_with_terminal_states():
101+
files = _list_tickets()
102+
build_ids = []
103+
for file in files:
104+
_, build_id, _ = _build_info_from_file(file)
105+
build_ids.append(build_id)
106+
107+
client = boto3.client("codebuild")
108+
response = client.batch_get_builds(ids=build_ids)
109+
110+
for file, build_details in zip(files, response["builds"]):
111+
_, _build_id_from_file, _ = _build_info_from_file(file)
112+
build_status = build_details["buildStatus"]
113+
114+
if build_status != "IN_PROGRESS" and _build_id_from_file == build_details["id"]:
115+
print(
116+
"Build %s in terminal state: %s, deleting lock"
117+
% (_build_id_from_file, build_status)
118+
)
119+
file.delete()
120+
121+
122+
def _cleanup_tickets_older_than(files):
88123
oldfiles = list(filter(_file_older_than, files))
89124
for file in oldfiles:
90125
print("object %s older than 8 hours. Deleting" % file.key)
91126
file.delete()
92127
return files
93128

94129

95-
def _list_tickets():
130+
def _list_tickets(status=None):
96131
s3 = boto3.resource("s3")
97132
bucket = s3.Bucket(bucket_name)
98-
objects = [file for file in bucket.objects.filter(Prefix="ci-lock/")]
99-
files = list(filter(lambda x: x != "ci-lock/", objects))
100-
return files
133+
prefix = "ci-integ-queue/{}/".format(status) if status else "ci-integ-queue/"
134+
objects = [file for file in bucket.objects.filter(Prefix=prefix)]
135+
files = list(filter(lambda x: x != prefix, objects))
136+
sorted_files = list(sorted(files, key=lambda y: y.key))
137+
return sorted_files
101138

102139

103140
def _file_older_than(file):
104-
timelimit = 1000 * 60 * 60 * 8
105-
141+
timelimit = 1000 * 60 * 60 * CLEAN_UP_TICKETS_OLDER_THAN
106142
file_ticket_number, build_id, source_version = _build_info_from_file(file)
143+
return int(1000 * time.time()) - file_ticket_number > timelimit
107144

108-
return int(time.time()) - file_ticket_number > timelimit
109-
110-
111-
def _write_ticket(ticket_number):
112145

113-
if not os.path.exists("ci-lock"):
114-
os.mkdir("ci-lock")
146+
def _write_ticket(filename, status="waiting"):
147+
file_path = "ci-integ-queue/{}".format(status)
148+
if not os.path.exists(file_path):
149+
os.makedirs(file_path)
115150

116-
filename = "ci-lock/" + ticket_number
117-
with open(filename, "w") as file:
118-
file.write(ticket_number)
119-
boto3.Session().resource("s3").Object(bucket_name, filename).upload_file(filename)
151+
file_full_path = file_path + "/" + filename
152+
with open(file_full_path, "w") as file:
153+
file.write(filename)
154+
s3_file_obj = boto3.Session().resource("s3").Object(bucket_name, file_full_path)
155+
s3_file_obj.upload_file(file_full_path)
156+
print("Build %s is now in state %s" % (filename, status))
157+
return s3_file_obj
120158

121159

122160
if __name__ == "__main__":

src/sagemaker/local/image.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,8 @@ def serve(self, model_dir, environment):
277277
script_dir = environment[sagemaker.estimator.DIR_PARAM_NAME.upper()]
278278
parsed_uri = urlparse(script_dir)
279279
if parsed_uri.scheme == "file":
280-
volumes.append(_Volume(parsed_uri.path, "/opt/ml/code"))
280+
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
281+
volumes.append(_Volume(host_dir, "/opt/ml/code"))
281282
# Update path to mount location
282283
environment = environment.copy()
283284
environment[sagemaker.estimator.DIR_PARAM_NAME.upper()] = "/opt/ml/code"
@@ -495,7 +496,8 @@ def _prepare_training_volumes(
495496
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
496497
parsed_uri = urlparse(training_dir)
497498
if parsed_uri.scheme == "file":
498-
volumes.append(_Volume(parsed_uri.path, "/opt/ml/code"))
499+
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
500+
volumes.append(_Volume(host_dir, "/opt/ml/code"))
499501
# Also mount a directory that all the containers can access.
500502
volumes.append(_Volume(shared_dir, "/opt/ml/shared"))
501503

@@ -504,7 +506,8 @@ def _prepare_training_volumes(
504506
parsed_uri.scheme == "file"
505507
and sagemaker.model.SAGEMAKER_OUTPUT_LOCATION in hyperparameters
506508
):
507-
intermediate_dir = os.path.join(parsed_uri.path, "output", "intermediate")
509+
dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
510+
intermediate_dir = os.path.join(dir_path, "output", "intermediate")
508511
if not os.path.exists(intermediate_dir):
509512
os.makedirs(intermediate_dir)
510513
volumes.append(_Volume(intermediate_dir, "/opt/ml/output/intermediate"))

src/sagemaker/local/utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
6464
"""
6565
parsed_uri = urlparse(destination)
6666
if parsed_uri.scheme == "file":
67-
recursive_copy(source, parsed_uri.path)
67+
dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
68+
recursive_copy(source, dir_path)
6869
final_uri = destination
6970
elif parsed_uri.scheme == "s3":
7071
bucket = parsed_uri.netloc
@@ -116,9 +117,8 @@ def get_child_process_ids(pid):
116117
(List[int]): Child process ids
117118
"""
118119
cmd = f"pgrep -P {pid}".split()
119-
output, err = subprocess.Popen(
120-
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
121-
).communicate()
120+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
121+
output, err = process.communicate()
122122
if err:
123123
return []
124124
pids = [int(pid) for pid in output.decode("utf-8").split()]

src/sagemaker/session.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -3556,19 +3556,17 @@ def endpoint_from_production_variants(
35563556
Returns:
35573557
str: The name of the created ``Endpoint``.
35583558
"""
3559-
if not _deployment_entity_exists(
3560-
lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)
3561-
):
3562-
config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
3563-
tags = _append_project_tags(tags)
3564-
if tags:
3565-
config_options["Tags"] = tags
3566-
if kms_key:
3567-
config_options["KmsKeyId"] = kms_key
3568-
if data_capture_config_dict is not None:
3569-
config_options["DataCaptureConfig"] = data_capture_config_dict
3570-
3571-
self.sagemaker_client.create_endpoint_config(**config_options)
3559+
config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
3560+
tags = _append_project_tags(tags)
3561+
if tags:
3562+
config_options["Tags"] = tags
3563+
if kms_key:
3564+
config_options["KmsKeyId"] = kms_key
3565+
if data_capture_config_dict is not None:
3566+
config_options["DataCaptureConfig"] = data_capture_config_dict
3567+
3568+
self.sagemaker_client.create_endpoint_config(**config_options)
3569+
35723570
return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait)
35733571

35743572
def expand_role(self, role):

tests/unit/test_local_utils.py

+42-2
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,31 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import os
1516
import pytest
1617
from mock import patch, Mock
1718

1819
import sagemaker.local.utils
1920

2021

22+
@patch("sagemaker.local.utils.os.path")
23+
@patch("sagemaker.local.utils.os")
24+
def test_copy_directory_structure(m_os, m_os_path):
25+
m_os_path.exists.return_value = False
26+
sagemaker.local.utils.copy_directory_structure("/tmp/", "code/")
27+
m_os.makedirs.assert_called_with("/tmp/", "code/")
28+
29+
2130
@patch("shutil.rmtree", Mock())
2231
@patch("sagemaker.local.utils.recursive_copy")
2332
def test_move_to_destination_local(recursive_copy):
2433
# local files will just be recursively copied
25-
sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir/", "job", None)
26-
recursive_copy.assert_called_with("/tmp/data", "/target/dir/")
34+
# given absolute path
35+
sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir", "job", None)
36+
recursive_copy.assert_called_with("/tmp/data", "/target/dir")
37+
# given relative path
38+
sagemaker.local.utils.move_to_destination("/tmp/data", "file://root/target/dir", "job", None)
39+
recursive_copy.assert_called_with("/tmp/data", os.path.abspath("./root/target/dir"))
2740

2841

2942
@patch("shutil.rmtree", Mock())
@@ -52,3 +65,30 @@ def test_move_to_destination_s3(recursive_copy):
5265
def test_move_to_destination_illegal_destination():
5366
with pytest.raises(ValueError):
5467
sagemaker.local.utils.move_to_destination("/tmp/data", "ftp://ftp/in/2018", "job", None)
68+
69+
70+
@patch("sagemaker.local.utils.os.path")
71+
@patch("sagemaker.local.utils.copy_tree")
72+
def test_recursive_copy(copy_tree, m_os_path):
73+
m_os_path.isdir.return_value = True
74+
sagemaker.local.utils.recursive_copy("source", "destination")
75+
copy_tree.assert_called_with("source", "destination")
76+
77+
78+
@patch("sagemaker.local.utils.os")
79+
@patch("sagemaker.local.utils.get_child_process_ids")
80+
def test_kill_child_processes(m_get_child_process_ids, m_os):
81+
m_get_child_process_ids.return_value = ["child_pids"]
82+
sagemaker.local.utils.kill_child_processes("pid")
83+
m_os.kill.assert_called_with("child_pids", 15)
84+
85+
86+
@patch("sagemaker.local.utils.subprocess")
87+
def test_get_child_process_ids(m_subprocess):
88+
cmd = "pgrep -P pid".split()
89+
process_mock = Mock()
90+
attrs = {"communicate.return_value": (b"\n", False), "returncode": 0}
91+
process_mock.configure_mock(**attrs)
92+
m_subprocess.Popen.return_value = process_mock
93+
sagemaker.local.utils.get_child_process_ids("pid")
94+
m_subprocess.Popen.assert_called_with(cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE)

0 commit comments

Comments (0)