Skip to content

Commit 6d42cc8

Browse files
zhangluomaAwesomeLeo Zhang
and
AwesomeLeo Zhang
authored
feat: Support uncompressed model upload (#3862)
Co-authored-by: AwesomeLeo Zhang <[email protected]>
1 parent 6503367 commit 6d42cc8

File tree

5 files changed

+62
-4
lines changed

5 files changed

+62
-4
lines changed

src/sagemaker/estimator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def __init__(
173173
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
174174
container_entry_point: Optional[List[str]] = None,
175175
container_arguments: Optional[List[str]] = None,
176+
disable_output_compression: bool = False,
176177
**kwargs,
177178
):
178179
"""Initialize an ``EstimatorBase`` instance.
@@ -531,6 +532,8 @@ def __init__(
531532
the default train processing instructions.
532533
container_arguments (List[str]): Optional. The arguments for a container used to run
533534
a training job.
535+
disable_output_compression (bool): Optional. When set to true, Model is uploaded
536+
to Amazon S3 without compression after training finishes.
534537
"""
535538
instance_count = renamed_kwargs(
536539
"train_instance_count", "instance_count", instance_count, kwargs
@@ -712,7 +715,7 @@ def __init__(
712715
self.profiler_rule_configs = None
713716
self.profiler_rules = None
714717
self.debugger_rules = None
715-
718+
self.disable_output_compression = disable_output_compression
716719
validate_source_code_input_against_pipeline_variables(
717720
entry_point=entry_point,
718721
source_dir=source_dir,
@@ -2507,6 +2510,7 @@ def __init__(
25072510
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
25082511
container_entry_point: Optional[List[str]] = None,
25092512
container_arguments: Optional[List[str]] = None,
2513+
disable_output_compression: bool = False,
25102514
**kwargs,
25112515
):
25122516
"""Initialize an ``Estimator`` instance.
@@ -2864,6 +2868,8 @@ def __init__(
28642868
the default train processing instructions.
28652869
container_arguments (List[str]): Optional. The arguments for a container used to run
28662870
a training job.
2871+
disable_output_compression (bool): Optional. When set to true, Model is uploaded
2872+
to Amazon S3 without compression after training finishes.
28672873
"""
28682874
self.image_uri = image_uri
28692875
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
@@ -2913,6 +2919,7 @@ def __init__(
29132919
training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # noqa: E501 # pylint: disable=line-too-long
29142920
container_entry_point=container_entry_point,
29152921
container_arguments=container_arguments,
2922+
disable_output_compression=disable_output_compression,
29162923
**kwargs,
29172924
)
29182925

src/sagemaker/job.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
7171
if (expand_role and not is_pipeline_variable(estimator.role))
7272
else estimator.role
7373
)
74-
output_config = _Job._prepare_output_config(estimator.output_path, estimator.output_kms_key)
74+
output_config = _Job._prepare_output_config(
75+
estimator.output_path,
76+
estimator.output_kms_key,
77+
disable_output_compression=estimator.disable_output_compression,
78+
)
7579
resource_config = _Job._prepare_resource_config(
7680
estimator.instance_count,
7781
estimator.instance_type,
@@ -273,11 +277,13 @@ def _format_record_set_list_input(inputs):
273277
return input_dict
274278

275279
@staticmethod
276-
def _prepare_output_config(s3_path, kms_key_id):
280+
def _prepare_output_config(s3_path, kms_key_id, disable_output_compression=False):
277281
"""Placeholder docstring"""
278282
config = {"S3OutputPath": s3_path}
279283
if kms_key_id is not None:
280284
config["KmsKeyId"] = kms_key_id
285+
if disable_output_compression:
286+
config["CompressionType"] = "NONE"
281287
return config
282288

283289
@staticmethod

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
522522
and cut a ticket sev-3 to JumpStart team: AWS > SageMaker > JumpStart"""
523523

524524
init_args_to_skip: Set[str] = set(
525-
["container_entry_point", "container_arguments", "kwargs"]
525+
["container_entry_point", "container_arguments", "disable_output_compression", "kwargs"]
526526
)
527527
fit_args_to_skip: Set[str] = set()
528528
deploy_args_to_skip: Set[str] = set(["kwargs"])

tests/unit/test_estimator.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,6 +1819,33 @@ def test_get_instance_type_gpu(sagemaker_session):
18191819
assert "ml.p3.16xlarge" == estimator._get_instance_type()
18201820

18211821

1822+
def test_estimator_with_output_compression_disabled(sagemaker_session):
1823+
estimator = Estimator(
1824+
image_uri="some-image",
1825+
role="some_image",
1826+
instance_count=INSTANCE_COUNT,
1827+
instance_type=INSTANCE_TYPE,
1828+
sagemaker_session=sagemaker_session,
1829+
base_job_name="base_job_name",
1830+
disable_output_compression=True,
1831+
)
1832+
1833+
assert estimator.disable_output_compression
1834+
1835+
1836+
def test_estimator_with_output_compression_as_default(sagemaker_session):
1837+
estimator = Estimator(
1838+
image_uri="some-image",
1839+
role="some_image",
1840+
instance_count=INSTANCE_COUNT,
1841+
instance_type=INSTANCE_TYPE,
1842+
sagemaker_session=sagemaker_session,
1843+
base_job_name="base_job_name",
1844+
)
1845+
1846+
assert not estimator.disable_output_compression
1847+
1848+
18221849
def test_get_instance_type_cpu(sagemaker_session):
18231850
estimator = Estimator(
18241851
image_uri="some-image",

tests/unit/test_job.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
REGION = "us-west-2"
3939
IMAGE_NAME = "fakeimage"
4040
SCRIPT_NAME = "script.py"
41+
NONE_COMPRESSION_TYPE = "NONE"
4142
JOB_NAME = "fakejob"
4243
VOLUME_KMS_KEY = "volkmskey"
4344
MODEL_CHANNEL_NAME = "testModelChannel"
@@ -144,6 +145,23 @@ def test_load_config(estimator):
144145
assert config["role"] == ROLE
145146
assert config["output_config"]["S3OutputPath"] == S3_OUTPUT_PATH
146147
assert "KmsKeyId" not in config["output_config"]
148+
assert "CompressionType" not in config["output_config"]
149+
assert config["resource_config"]["InstanceCount"] == INSTANCE_COUNT
150+
assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE
151+
assert config["resource_config"]["VolumeSizeInGB"] == VOLUME_SIZE
152+
assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME
153+
154+
155+
def test_load_config_with_output_compression_disabled(estimator):
156+
inputs = TrainingInput(BUCKET_NAME)
157+
estimator.disable_output_compression = True
158+
config = _Job._load_config(inputs, estimator)
159+
160+
assert config["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == BUCKET_NAME
161+
assert config["role"] == ROLE
162+
assert config["output_config"]["S3OutputPath"] == S3_OUTPUT_PATH
163+
assert "KmsKeyId" not in config["output_config"]
164+
assert config["output_config"]["CompressionType"] == NONE_COMPRESSION_TYPE
147165
assert config["resource_config"]["InstanceCount"] == INSTANCE_COUNT
148166
assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE
149167
assert config["resource_config"]["VolumeSizeInGB"] == VOLUME_SIZE

0 commit comments

Comments
 (0)