fix: Set ProcessingStep upload locations deterministically to avoid c… #2790

Merged
merged 6 commits into from Dec 8, 2021
Changes from all commits
150 changes: 94 additions & 56 deletions ci-scripts/queue_build.py
@@ -23,100 +23,138 @@
).get_caller_identity()["Account"]
bucket_name = "sagemaker-us-west-2-%s" % account

MAX_IN_PROGRESS_BUILDS = 3
INTERVAL_BETWEEN_CONCURRENT_RUNS = 15 # minutes
CLEAN_UP_TICKETS_OLDER_THAN = 8 # hours


def queue_build():
build_id = re.sub("[_/]", "-", os.environ.get("CODEBUILD_BUILD_ID", "CODEBUILD-BUILD-ID"))
source_version = re.sub(
"[_/]",
"-",
os.environ.get("CODEBUILD_SOURCE_VERSION", "CODEBUILD-SOURCE-VERSION"),
)
ticket_number = int(1000 * time.time())
filename = "%s_%s_%s" % (ticket_number, build_id, source_version)

print("Created queue ticket %s" % ticket_number)

_write_ticket(filename)
files = _list_tickets()
_cleanup_tickets_older_than_8_hours(files)
_wait_for_other_builds(files, ticket_number)
_cleanup_tickets_older_than(files)
_wait_for_other_builds(ticket_number)


def _build_info_from_file(file):
filename = file.key.split("/")[1]
filename = file.key.split("/")[2]
ticket_number, build_id, source_version = filename.split("_")
return int(ticket_number), build_id, source_version


def _wait_for_other_builds(files, ticket_number):
newfiles = list(filter(lambda file: not _file_older_than(file), files))
sorted_files = list(sorted(newfiles, key=lambda y: y.key))
def _wait_for_other_builds(ticket_number):
sorted_files = _list_tickets()

print("build queue status:")
print()

for order, file in enumerate(sorted_files):
file_ticket_number, build_id, source_version = _build_info_from_file(file)
print(
"%s -> %s %s, ticket number: %s" % (order, build_id, source_version, file_ticket_number)
"%s -> %s %s, ticket number: %s status: %s"
% (order, build_id, source_version, file_ticket_number, file.key.split("/")[1])
)
print()
build_id = re.sub("[_/]", "-", os.environ.get("CODEBUILD_BUILD_ID", "CODEBUILD-BUILD-ID"))
source_version = re.sub(
"[_/]",
"-",
os.environ.get("CODEBUILD_SOURCE_VERSION", "CODEBUILD-SOURCE-VERSION"),
)
filename = "%s_%s_%s" % (ticket_number, build_id, source_version)
s3_file_obj = _write_ticket(filename, status="waiting")
print("Build %s waiting to be scheduled" % filename)

while True:
_cleanup_tickets_with_terminal_states()
waiting_tickets = _list_tickets("waiting")
if waiting_tickets:
first_waiting_ticket_number, _, _ = _build_info_from_file(_list_tickets("waiting")[0])
else:
first_waiting_ticket_number = ticket_number

if (
len(_list_tickets(status="in-progress")) < 3
and last_in_progress_elapsed_time_check()
and first_waiting_ticket_number == ticket_number
):
# put the build in progress
print("Scheduling build %s for running.." % filename)
s3_file_obj.delete()
_write_ticket(filename, status="in-progress")
break
else:
# wait
time.sleep(30)

for file in sorted_files:
file_ticket_number, build_id, source_version = _build_info_from_file(file)

if file_ticket_number == ticket_number:
def last_in_progress_elapsed_time_check():
in_progress_tickets = _list_tickets("in-progress")
if not in_progress_tickets:
return True
last_in_progress_ticket, _, _ = _build_info_from_file(_list_tickets("in-progress")[-1])
_elapsed_time = int(1000 * time.time()) - last_in_progress_ticket
last_in_progress_elapsed_time = int(_elapsed_time / (1000 * 60)) # in minutes
return last_in_progress_elapsed_time > INTERVAL_BETWEEN_CONCURRENT_RUNS

break
else:
while True:
client = boto3.client("codebuild")
response = client.batch_get_builds(ids=[build_id])
build_status = response["builds"][0]["buildStatus"]

if build_status == "IN_PROGRESS":
print(
"waiting on build %s %s %s" % (build_id, source_version, file_ticket_number)
)
time.sleep(30)
else:
print("build %s finished, deleting lock" % build_id)
file.delete()
break


def _cleanup_tickets_older_than_8_hours(files):

def _cleanup_tickets_with_terminal_states():
files = _list_tickets()
build_ids = []
for file in files:
_, build_id, _ = _build_info_from_file(file)
build_ids.append(build_id)

client = boto3.client("codebuild")
response = client.batch_get_builds(ids=build_ids)

for file, build_details in zip(files, response["builds"]):
_, _build_id_from_file, _ = _build_info_from_file(file)
build_status = build_details["buildStatus"]

if build_status != "IN_PROGRESS" and _build_id_from_file == build_details["id"]:
print(
"Build %s in terminal state: %s, deleting lock"
% (_build_id_from_file, build_status)
)
file.delete()


def _cleanup_tickets_older_than(files):
oldfiles = list(filter(_file_older_than, files))
for file in oldfiles:
print("object %s older than 8 hours. Deleting" % file.key)
file.delete()
return files


def _list_tickets():
def _list_tickets(status=None):
s3 = boto3.resource("s3")
bucket = s3.Bucket(bucket_name)
objects = [file for file in bucket.objects.filter(Prefix="ci-lock/")]
files = list(filter(lambda x: x != "ci-lock/", objects))
return files
prefix = "ci-integ-queue/{}/".format(status) if status else "ci-integ-queue/"
objects = [file for file in bucket.objects.filter(Prefix=prefix)]
files = list(filter(lambda x: x != prefix, objects))
sorted_files = list(sorted(files, key=lambda y: y.key))
return sorted_files


def _file_older_than(file):
timelimit = 1000 * 60 * 60 * 8

timelimit = 1000 * 60 * 60 * CLEAN_UP_TICKETS_OLDER_THAN
file_ticket_number, build_id, source_version = _build_info_from_file(file)
return int(1000 * time.time()) - file_ticket_number > timelimit

return int(time.time()) - file_ticket_number > timelimit


def _write_ticket(ticket_number):

if not os.path.exists("ci-lock"):
os.mkdir("ci-lock")
def _write_ticket(filename, status="waiting"):
file_path = "ci-integ-queue/{}".format(status)
if not os.path.exists(file_path):
os.makedirs(file_path)

filename = "ci-lock/" + ticket_number
with open(filename, "w") as file:
file.write(ticket_number)
boto3.Session().resource("s3").Object(bucket_name, filename).upload_file(filename)
file_full_path = file_path + "/" + filename
with open(file_full_path, "w") as file:
file.write(filename)
s3_file_obj = boto3.Session().resource("s3").Object(bucket_name, file_full_path)
s3_file_obj.upload_file(file_full_path)
print("Build %s is now in state %s" % (filename, status))
return s3_file_obj


if __name__ == "__main__":
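The queuing logic above keys everything off the S3 object layout written by _write_ticket. A minimal sketch of how a ticket key is built and parsed follows; the key below is illustrative (using the CODEBUILD-* fallbacks from the script), not taken from a real build.

import time

# Ticket key layout: ci-integ-queue/<status>/<ticket_number>_<build_id>_<source_version>
# build_id and source_version already have "_" and "/" replaced with "-", so splitting
# the filename on "_" yields exactly three fields, as _build_info_from_file expects.
key = "ci-integ-queue/waiting/%d_CODEBUILD-BUILD-ID_CODEBUILD-SOURCE-VERSION" % int(1000 * time.time())

status = key.split("/")[1]                                    # "waiting"
ticket_number, build_id, source_version = key.split("/")[2].split("_")
print(int(ticket_number), build_id, source_version, status)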
9 changes: 6 additions & 3 deletions src/sagemaker/local/image.py
@@ -277,7 +277,8 @@ def serve(self, model_dir, environment):
script_dir = environment[sagemaker.estimator.DIR_PARAM_NAME.upper()]
parsed_uri = urlparse(script_dir)
if parsed_uri.scheme == "file":
volumes.append(_Volume(parsed_uri.path, "/opt/ml/code"))
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
volumes.append(_Volume(host_dir, "/opt/ml/code"))
# Update path to mount location
environment = environment.copy()
environment[sagemaker.estimator.DIR_PARAM_NAME.upper()] = "/opt/ml/code"
@@ -495,7 +496,8 @@ def _prepare_training_volumes(
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
parsed_uri = urlparse(training_dir)
if parsed_uri.scheme == "file":
volumes.append(_Volume(parsed_uri.path, "/opt/ml/code"))
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
volumes.append(_Volume(host_dir, "/opt/ml/code"))
# Also mount a directory that all the containers can access.
volumes.append(_Volume(shared_dir, "/opt/ml/shared"))

@@ -504,7 +506,8 @@
parsed_uri.scheme == "file"
and sagemaker.model.SAGEMAKER_OUTPUT_LOCATION in hyperparameters
):
intermediate_dir = os.path.join(parsed_uri.path, "output", "intermediate")
dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
intermediate_dir = os.path.join(dir_path, "output", "intermediate")
if not os.path.exists(intermediate_dir):
os.makedirs(intermediate_dir)
volumes.append(_Volume(intermediate_dir, "/opt/ml/output/intermediate"))
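The image.py changes above concatenate netloc and path before resolving the host directory. A small sketch of why, using a made-up relative file URI rather than anything from the PR: urlparse puts the leading "." of a relative file URI into netloc, so using path alone would drop it and mount a directory from the filesystem root.

import os
from urllib.parse import urlparse

# Hypothetical relative script_dir; not taken from the PR.
parsed_uri = urlparse("file://./my_code")
print(parsed_uri.netloc, parsed_uri.path)            # "."  "/my_code"

# Old behaviour: parsed_uri.path alone would point at /my_code from the root.
# New behaviour: netloc + path keeps the relative prefix and resolves against the CWD.
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
print(host_dir)                                       # e.g. /current/working/dir/my_code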
8 changes: 4 additions & 4 deletions src/sagemaker/local/utils.py
@@ -64,7 +64,8 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
"""
parsed_uri = urlparse(destination)
if parsed_uri.scheme == "file":
recursive_copy(source, parsed_uri.path)
dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
recursive_copy(source, dir_path)
final_uri = destination
elif parsed_uri.scheme == "s3":
bucket = parsed_uri.netloc
@@ -116,9 +117,8 @@ def get_child_process_ids(pid):
(List[int]): Child process ids
"""
cmd = f"pgrep -P {pid}".split()
output, err = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
).communicate()
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err = process.communicate()
if err:
return []
pids = [int(pid) for pid in output.decode("utf-8").split()]
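The get_child_process_ids refactor above only splits the Popen call out of the one-liner; behaviour is unchanged. A short usage sketch, assuming a system where pgrep is available (the parent PID is a placeholder):

import subprocess

pid = 1  # placeholder parent PID
cmd = f"pgrep -P {pid}".split()
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err = process.communicate()

# Any stderr output is treated as "no children"; otherwise stdout holds one PID per line.
child_pids = [] if err else [int(p) for p in output.decode("utf-8").split()]
print(child_pids)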
24 changes: 11 additions & 13 deletions src/sagemaker/session.py
@@ -3556,19 +3556,17 @@ def endpoint_from_production_variants(
Returns:
str: The name of the created ``Endpoint``.
"""
if not _deployment_entity_exists(
lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)
):
config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
tags = _append_project_tags(tags)
if tags:
config_options["Tags"] = tags
if kms_key:
config_options["KmsKeyId"] = kms_key
if data_capture_config_dict is not None:
config_options["DataCaptureConfig"] = data_capture_config_dict

self.sagemaker_client.create_endpoint_config(**config_options)
config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
tags = _append_project_tags(tags)
if tags:
config_options["Tags"] = tags
if kms_key:
config_options["KmsKeyId"] = kms_key
if data_capture_config_dict is not None:
config_options["DataCaptureConfig"] = data_capture_config_dict

self.sagemaker_client.create_endpoint_config(**config_options)

return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait)

def expand_role(self, role):
32 changes: 32 additions & 0 deletions src/sagemaker/workflow/steps.py
@@ -14,8 +14,10 @@
from __future__ import absolute_import

import abc
import warnings
from enum import Enum
from typing import Dict, List, Union
from urllib.parse import urlparse

import attr

@@ -270,6 +272,16 @@ def __init__(
)
self.cache_config = cache_config

if self.cache_config is not None and not self.estimator.disable_profiler:
msg = (
"Profiling is enabled on the provided estimator. "
"The default profiler rule includes a timestamp "
"which will change each time the pipeline is "
"upserted, causing cache misses. If profiling "
"is not needed, set disable_profiler to True on the estimator."
)
warnings.warn(msg)

@property
def arguments(self) -> RequestType:
"""The arguments dict that is used to call `create_training_job`.
@@ -498,6 +510,7 @@ def __init__(
self.job_arguments = job_arguments
self.code = code
self.property_files = property_files
self.job_name = None

# Examine why run method in sagemaker.processing.Processor mutates the processor instance
# by setting the instance's arguments attribute. Refactor Processor.run, if possible.
@@ -508,6 +521,17 @@
)
self.cache_config = cache_config

if code:
code_url = urlparse(code)
if code_url.scheme == "" or code_url.scheme == "file":
# By default, Processor will upload the local code to an S3 path
# containing a timestamp. This causes cache misses whenever a
# pipeline is updated, even if the underlying script hasn't changed.
# To avoid this, hash the contents of the script and include it
# in the job_name passed to the Processor, which will be used
# instead of the timestamped path.
self.job_name = self._generate_code_upload_path()

@property
def arguments(self) -> RequestType:
"""The arguments dict that is used to call `create_processing_job`.
@@ -516,6 +540,7 @@ def arguments(self) -> RequestType:
ProcessingJobName and ExperimentConfig cannot be included in the arguments.
"""
normalized_inputs, normalized_outputs = self.processor._normalize_args(
job_name=self.job_name,
arguments=self.job_arguments,
inputs=self.inputs,
outputs=self.outputs,
@@ -546,6 +571,13 @@ def to_request(self) -> RequestType:
]
return request_dict

def _generate_code_upload_path(self) -> str:
"""Generate an upload path for local processing scripts based on its contents"""
from sagemaker.workflow.utilities import hash_file

code_hash = hash_file(self.code)
return f"{self.name}-{code_hash}"[:1024]


class TuningStep(ConfigurableRetryStep):
"""Tuning step for workflow."""
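The net effect of the ProcessingStep changes is that a local code file now maps to a deterministic job name instead of a timestamped upload path. A short sketch mirroring _generate_code_upload_path, assuming an illustrative step name and a local preprocess.py:

from sagemaker.workflow.utilities import hash_file

step_name = "MyProcessingStep"           # illustrative step name
code_hash = hash_file("preprocess.py")   # assumes preprocess.py exists locally
job_name = f"{step_name}-{code_hash}"[:1024]

# Unchanged script contents always produce the same job_name, so re-upserting the
# pipeline reuses the same S3 code prefix and the step can be served from cache.
print(job_name)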
21 changes: 21 additions & 0 deletions src/sagemaker/workflow/utilities.py
@@ -14,6 +14,7 @@
from __future__ import absolute_import

from typing import List, Sequence, Union
import hashlib

from sagemaker.workflow.entities import (
Entity,
@@ -37,3 +38,23 @@ def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[R
elif isinstance(entity, StepCollection):
request_dicts.extend(entity.request_dicts())
return request_dicts


def hash_file(path: str) -> str:
"""Get the MD5 hash of a file.

Args:
path (str): The local path for the file.
Returns:
str: The MD5 hash of the file.
"""
BUF_SIZE = 65536 # read in 64KiB chunks
md5 = hashlib.md5()
with open(path, "rb") as f:
while True:
data = f.read(BUF_SIZE)
if not data:
break
md5.update(data)

return md5.hexdigest()
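A quick check of hash_file, reusing the illustrative preprocess.py from above: the chunked read exists only to avoid loading large files into memory, and its digest matches hashing the whole file at once.

import hashlib

from sagemaker.workflow.utilities import hash_file

with open("preprocess.py", "rb") as f:                  # illustrative local file
    whole_file_digest = hashlib.md5(f.read()).hexdigest()

assert hash_file("preprocess.py") == whole_file_digest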