Skip to content

Commit e9abb1d

Browse files
authored
Merge branch 'master' into tgi131
2 parents 5e72c7e + 2432b26 commit e9abb1d

File tree

10 files changed

+178
-55
lines changed

10 files changed

+178
-55
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def get_deploy_kwargs(
336336
tolerate_vulnerable_model=tolerate_vulnerable_model,
337337
tolerate_deprecated_model=tolerate_deprecated_model,
338338
training_instance_type=training_instance_type,
339+
disable_instance_type_logging=True,
339340
)
340341

341342
estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(

src/sagemaker/jumpstart/factory/model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ def _add_vulnerable_and_deprecated_status_to_kwargs(
171171
return kwargs
172172

173173

174-
def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
174+
def _add_instance_type_to_kwargs(
175+
kwargs: JumpStartModelInitKwargs, disable_instance_type_logging: bool = False
176+
) -> JumpStartModelInitKwargs:
175177
"""Sets instance type based on default or override, returns full kwargs."""
176178

177179
orig_instance_type = kwargs.instance_type
@@ -187,7 +189,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM
187189
training_instance_type=kwargs.training_instance_type,
188190
)
189191

190-
if orig_instance_type is None:
192+
if not disable_instance_type_logging and orig_instance_type is None:
191193
JUMPSTART_LOGGER.info(
192194
"No instance type selected for inference hosting endpoint. Defaulting to %s.",
193195
kwargs.instance_type,
@@ -551,9 +553,7 @@ def get_deploy_kwargs(
551553

552554
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
553555

554-
deploy_kwargs = _add_instance_type_to_kwargs(
555-
kwargs=deploy_kwargs,
556-
)
556+
deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
557557

558558
deploy_kwargs.initial_instance_count = initial_instance_count or 1
559559

@@ -677,6 +677,7 @@ def get_init_kwargs(
677677
git_config: Optional[Dict[str, str]] = None,
678678
model_package_arn: Optional[str] = None,
679679
training_instance_type: Optional[str] = None,
680+
disable_instance_type_logging: bool = False,
680681
resources: Optional[ResourceRequirements] = None,
681682
) -> JumpStartModelInitKwargs:
682683
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
@@ -720,7 +721,7 @@ def get_init_kwargs(
720721
model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs)
721722

722723
model_init_kwargs = _add_instance_type_to_kwargs(
723-
kwargs=model_init_kwargs,
724+
kwargs=model_init_kwargs, disable_instance_type_logging=disable_instance_type_logging
724725
)
725726

726727
model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs)

src/sagemaker/remote_function/job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ def wait(self, timeout: int = None):
891891
"""
892892

893893
self._last_describe_response = _logs_for_job(
894-
boto_session=self.sagemaker_session.boto_session,
894+
sagemaker_session=self.sagemaker_session,
895895
job_name=self.job_name,
896896
wait=True,
897897
timeout=timeout,

src/sagemaker/session.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
472472

473473
# Initialize the variables used to loop through the contents of the S3 bucket.
474474
keys = []
475+
directories = []
475476
next_token = ""
476477
base_parameters = {"Bucket": bucket, "Prefix": key_prefix}
477478

@@ -490,20 +491,26 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
490491
return []
491492
# For each object, save its key or directory.
492493
for s3_object in contents:
493-
key = s3_object.get("Key")
494-
keys.append(key)
494+
key: str = s3_object.get("Key")
495+
obj_size = s3_object.get("Size")
496+
if key.endswith("/") and int(obj_size) == 0:
497+
directories.append(os.path.join(path, key))
498+
else:
499+
keys.append(key)
495500
next_token = response.get("NextContinuationToken")
496501

497502
# For each object key, create the directory on the local machine if needed, and then
498503
# download the file.
499504
downloaded_paths = []
505+
for dir_path in directories:
506+
os.makedirs(os.path.dirname(dir_path), exist_ok=True)
500507
for key in keys:
501508
tail_s3_uri_path = os.path.basename(key)
502509
if not os.path.splitext(key_prefix)[1]:
503510
tail_s3_uri_path = os.path.relpath(key, key_prefix)
504511
destination_path = os.path.join(path, tail_s3_uri_path)
505512
if not os.path.exists(os.path.dirname(destination_path)):
506-
os.makedirs(os.path.dirname(destination_path))
513+
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
507514
s3.download_file(
508515
Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
509516
)
@@ -5447,7 +5454,7 @@ def logs_for_job(self, job_name, wait=False, poll=10, log_type="All", timeout=No
54475454
exceptions.CapacityError: If the training job fails with CapacityError.
54485455
exceptions.UnexpectedStatusException: If waiting and the training job fails.
54495456
"""
5450-
_logs_for_job(self.boto_session, job_name, wait, poll, log_type, timeout)
5457+
_logs_for_job(self, job_name, wait, poll, log_type, timeout)
54515458

54525459
def logs_for_processing_job(self, job_name, wait=False, poll=10):
54535460
"""Display logs for a given processing job, optionally tailing them until the is complete.
@@ -7330,17 +7337,16 @@ def _rule_statuses_changed(current_statuses, last_statuses):
73307337

73317338

73327339
def _logs_for_job( # noqa: C901 - suppress complexity warning for this method
7333-
boto_session, job_name, wait=False, poll=10, log_type="All", timeout=None
7340+
sagemaker_session, job_name, wait=False, poll=10, log_type="All", timeout=None
73347341
):
73357342
"""Display logs for a given training job, optionally tailing them until job is complete.
73367343
73377344
If the output is a tty or a Jupyter cell, it will be color-coded
73387345
based on which instance the log entry is from.
73397346
73407347
Args:
7341-
boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
7342-
calls are delegated to (default: None). If not provided, one is created with
7343-
default AWS configuration chain.
7348+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
7349+
object, used for SageMaker interactions.
73447350
job_name (str): Name of the training job to display the logs for.
73457351
wait (bool): Whether to keep looking for new log entries until the job completes
73467352
(default: False).
@@ -7357,13 +7363,13 @@ def _logs_for_job( # noqa: C901 - suppress complexity warning for this method
73577363
exceptions.CapacityError: If the training job fails with CapacityError.
73587364
exceptions.UnexpectedStatusException: If waiting and the training job fails.
73597365
"""
7360-
sagemaker_client = boto_session.client("sagemaker")
7366+
sagemaker_client = sagemaker_session.sagemaker_client
73617367
request_end_time = time.time() + timeout if timeout else None
73627368
description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
73637369
print(secondary_training_status_message(description, None), end="")
73647370

73657371
instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
7366-
boto_session, description, job="Training"
7372+
sagemaker_session.boto_session, description, job="Training"
73677373
)
73687374

73697375
state = _get_initial_job_state(description, "TrainingJobStatus", wait)

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def test_with_additional_dependencies(
207207
def cuberoot(x):
208208
from scipy.special import cbrt
209209

210-
return cbrt(27)
210+
return cbrt(x)
211211

212212
assert cuberoot(27) == 3
213213

@@ -742,7 +742,7 @@ def test_with_user_and_workdir_set_in_the_image(
742742
def cuberoot(x):
743743
from scipy.special import cbrt
744744

745-
return cbrt(27)
745+
return cbrt(x)
746746

747747
assert cuberoot(27) == 3
748748

0 commit comments

Comments
 (0)