Skip to content

Commit 3351ce9

Browse files
committed
change: fix docstrings, improve readability
1 parent a3cabe8 commit 3351ce9

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

src/sagemaker/estimator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import json
1717
import logging
1818
import os
19-
from typing import Dict
19+
from typing import Any, Dict
2020
import uuid
2121
from abc import ABCMeta, abstractmethod
2222

@@ -364,7 +364,7 @@ def __init__(
364364
try to use either CodeCommit credential helper or local
365365
credential storage for authentication.
366366
hyperparameters (dict): Dictionary containing the hyperparameters to
367-
initialize this estimator with.
367+
initialize this estimator with. (Default: None).
368368
container_log_level (int): Log level to use within the container
369369
(default: logging.INFO). Valid values are defined in the Python
370370
logging module.
@@ -375,7 +375,7 @@ def __init__(
375375
If not specified, the default ``code location`` is s3://output_bucket/job-name/.
376376
entry_point (str): Path (absolute or relative) to the local Python
377377
source file which should be executed as the entry point to
378-
training. If ``source_dir`` is specified, then ``entry_point``
378+
training. (Default: None). If ``source_dir`` is specified, then ``entry_point``
379379
must point to a file located at the root of ``source_dir``.
380380
If 'git_config' is provided, 'entry_point' should be
381381
a relative location to the Python source file in the Git repo.
@@ -583,7 +583,7 @@ def _get_or_create_name(self, name=None):
583583
return name_from_base(self.base_job_name)
584584

585585
@staticmethod
586-
def _json_encode_hyperparameters(hyperparameters: dict) -> dict:
586+
def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:
587587
"""Applies Json encoding for certain Hyperparameter types, returns hyperparameters.
588588
589589
Args:
@@ -679,7 +679,7 @@ def _script_mode_hyperparam_update(self, code_dir: str, script: str) -> None:
679679
self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(hyperparams))
680680

681681
def _stage_user_code_in_s3(self) -> str:
682-
"""Upload the user training script to s3 and return the location.
682+
"""Upload the user training script to s3 and return the s3 URI.
683683
684684
Returns: s3 uri
685685
"""
@@ -2691,7 +2691,7 @@ def _model_entry_point(self):
26912691
return None
26922692

26932693
def set_hyperparameters(self, **kwargs):
2694-
"""Sets hyperparameters."""
2694+
"""Escape the dict argument as JSON, update the private hyperparameter attribute."""
26952695
self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(kwargs))
26962696

26972697
def hyperparameters(self):

src/sagemaker/model.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -399,12 +399,10 @@ def prepare_container_def(
399399
)
400400
deploy_env = copy.deepcopy(self.env)
401401
if self.source_dir or self.dependencies or self.entry_point or self.git_config:
402-
self._upload_code(
403-
deploy_key_prefix,
404-
repack=self.source_dir
405-
and self.entry_point
406-
and not (self.key_prefix or self.git_config),
402+
is_repack = (
403+
self.source_dir and self.entry_point and not (self.key_prefix or self.git_config)
407404
)
405+
self._upload_code(deploy_key_prefix, repack=is_repack)
408406
deploy_env.update(self._script_mode_env_vars())
409407
return sagemaker.container_def(
410408
self.image_uri, self.model_data, deploy_env, image_config=self.image_config

0 commit comments

Comments
 (0)