diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 10ac37a322..4f91b73972 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -556,7 +556,7 @@ def __init__( self.code_location = code_location self.entry_point = entry_point self.dependencies = dependencies or [] - self.uploaded_code = None + self.uploaded_code: Optional[UploadedCode] = None self.tags = add_jumpstart_tags( tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir ) @@ -839,7 +839,7 @@ def _script_mode_hyperparam_update(self, code_dir: str, script: str) -> None: self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(hyperparams)) - def _stage_user_code_in_s3(self) -> str: + def _stage_user_code_in_s3(self) -> UploadedCode: """Uploads the user training script to S3 and returns the S3 URI. Returns: S3 URI @@ -3135,7 +3135,7 @@ def __init__( self.git_config = git_config self.source_dir = source_dir self.dependencies = dependencies or [] - self.uploaded_code = None + self.uploaded_code: Optional[UploadedCode] = None self.container_log_level = container_log_level self.code_location = code_location diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 96ff08e877..c23f9016d8 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -37,8 +37,8 @@ _TAR_SOURCE_FILENAME = "source.tar.gz" -UploadedCode = namedtuple("UserCode", ["s3_prefix", "script_name"]) -"""sagemaker.fw_utils.UserCode: An object containing the S3 prefix and script name. +UploadedCode = namedtuple("UploadedCode", ["s3_prefix", "script_name"]) +"""sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name. This is for the source code used for the entry point with an ``Estimator``. It can be instantiated with positional or keyword arguments. """ @@ -398,7 +398,7 @@ def tar_and_upload_dir( kms_key=None, s3_resource=None, settings: Optional[SessionSettings] = None, -): +) -> UploadedCode: """Package source files and upload a compress tar file to S3. The S3 location will be ``s3:///s3_key_prefix/sourcedir.tar.gz``. @@ -429,7 +429,7 @@ def tar_and_upload_dir( of the SageMaker ``Session``, can be used to override the default encryption behavior (default: None). Returns: - sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and + sagemaker.fw_utils.UploadedCode: An object with the S3 bucket and key (S3 prefix) and script name. """ if directory and (is_pipeline_variable(directory) or directory.lower().startswith("s3://")):