Skip to content

Change: Allow extra_args to be passed to uploader #4338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/sagemaker/experiments/_helper.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/bot run all

Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,15 @@ def __init__(
self.artifact_prefix = artifact_prefix
self._s3_client = self.sagemaker_session.boto_session.client("s3")

def upload_artifact(self, file_path):
def upload_artifact(self, file_path, extra_args=None):
"""Upload an artifact file to S3.

Args:
file_path (str): the file path of the artifact
extra_args (dict): Optional extra arguments that may be passed to the upload operation.
Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
ExtraArgs parameter documentation here:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter

Returns:
(str, str): The s3 URI of the uploaded file and the etag of the file.
Expand Down Expand Up @@ -91,7 +95,12 @@ def upload_artifact(self, file_path):
artifact_s3_key = "{}/{}/{}".format(
self.artifact_prefix, self.trial_component_name, artifact_name
)
self._s3_client.upload_file(file_path, self.artifact_bucket, artifact_s3_key)
self._s3_client.upload_file(
file_path,
self.artifact_bucket,
artifact_s3_key,
ExtraArgs=extra_args,
)
etag = self._try_get_etag(artifact_s3_key)
return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag

Expand Down
9 changes: 7 additions & 2 deletions src/sagemaker/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,8 @@ def log_file(
file_path: str,
name: Optional[str] = None,
media_type: Optional[str] = None,
is_output: bool = True,
is_output: Optional[bool] = True,
extra_args: Optional[dict] = None,
):
"""Upload a file to s3 and store it as an input/output artifact in this run.

Expand All @@ -521,11 +522,15 @@ def log_file(
is_output (bool): Determines direction of association to the
run. Defaults to True (output artifact).
If set to False then represented as input association.
extra_args (dict): Optional extra arguments that may be passed to the upload operation.
Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
ExtraArgs parameter documentation here:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter
"""
self._verify_trial_component_artifacts_length(is_output)
media_type = media_type or guess_media_type(file_path)
name = name or resolve_artifact_name(file_path)
s3_uri, _ = self._artifact_uploader.upload_artifact(file_path)
s3_uri, _ = self._artifact_uploader.upload_artifact(file_path, extra_args=extra_args)
if is_output:
self._trial_component.output_artifacts[name] = TrialComponentArtifact(
value=s3_uri, media_type=media_type
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/experiments/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_artifact_uploader_upload_artifact(tempdir, artifact_uploader):
)

artifact_uploader._s3_client.upload_file.assert_called_with(
path, artifact_uploader.artifact_bucket, expected_key
path, artifact_uploader.artifact_bucket, expected_key, ExtraArgs=None
)

expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key)
Expand Down
16 changes: 10 additions & 6 deletions tests/unit/sagemaker/experiments/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,11 +592,11 @@ def test_log_output_artifact(run_obj):
run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
with run_obj:
run_obj.log_file("foo.txt", "name", "whizz/bang")
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
assert "whizz/bang" == run_obj._trial_component.output_artifacts["name"].media_type

run_obj.log_file("foo.txt")
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
assert "foo.txt" in run_obj._trial_component.output_artifacts
assert "text/plain" == run_obj._trial_component.output_artifacts["foo.txt"].media_type

Expand All @@ -611,11 +611,11 @@ def test_log_input_artifact(run_obj):
run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
with run_obj:
run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False)
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
assert "whizz/bang" == run_obj._trial_component.input_artifacts["name"].media_type

run_obj.log_file("foo.txt", is_output=False)
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt", extra_args=None)
assert "foo.txt" in run_obj._trial_component.input_artifacts
assert "text/plain" == run_obj._trial_component.input_artifacts["foo.txt"].media_type

Expand Down Expand Up @@ -655,7 +655,9 @@ def test_log_multiple_input_artifacts(run_obj):
run_obj.log_file(
file_path, "name" + str(index), "whizz/bang" + str(index), is_output=False
)
run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path)
run_obj._artifact_uploader.upload_artifact.assert_called_with(
file_path, extra_args=None
)

run_obj._artifact_uploader.upload_artifact.return_value = (
"s3uri_value",
Expand All @@ -680,7 +682,9 @@ def test_log_multiple_output_artifacts(run_obj):
"etag_value" + str(index),
)
run_obj.log_file(file_path, "name" + str(index), "whizz/bang" + str(index))
run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path)
run_obj._artifact_uploader.upload_artifact.assert_called_with(
file_path, extra_args=None
)

run_obj._artifact_uploader.upload_artifact.return_value = (
"s3uri_value",
Expand Down