Skip to content

Commit 77fae44

Browse files
Captainiabenieric
authored andcommitted
fix: typo and merge with master branch (aws#4649)
1 parent 5a424ea commit 77fae44

File tree

5 files changed

+14
-5
lines changed

5 files changed

+14
-5
lines changed

src/sagemaker/jumpstart/estimator.py

+4
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(
113113
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
114114
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
115115
config_name: Optional[str] = None,
116+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
116117
):
117118
"""Initializes a ``JumpStartEstimator``.
118119
@@ -508,6 +509,8 @@ def __init__(
508509
Specifies whether SessionTagChaining is enabled for the training job
509510
config_name (Optional[str]):
510511
Name of the training configuration to apply to the Estimator. (Default: None).
512+
enable_session_tag_chaining (bool or PipelineVariable): Optional.
513+
Specifies whether SessionTagChaining is enabled for the training job
511514
512515
Raises:
513516
ValueError: If the model ID is not recognized by JumpStart.
@@ -588,6 +591,7 @@ def _validate_model_id_and_get_type_hook():
588591
enable_remote_debug=enable_remote_debug,
589592
enable_session_tag_chaining=enable_session_tag_chaining,
590593
config_name=config_name,
594+
enable_session_tag_chaining=enable_session_tag_chaining,
591595
)
592596

593597
self.model_id = estimator_init_kwargs.model_id

src/sagemaker/jumpstart/factory/estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def get_init_kwargs(
132132
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
133133
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
134134
config_name: Optional[str] = None,
135+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
135136
) -> JumpStartEstimatorInitKwargs:
136137
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
137138

@@ -192,6 +193,7 @@ def get_init_kwargs(
192193
enable_remote_debug=enable_remote_debug,
193194
enable_session_tag_chaining=enable_session_tag_chaining,
194195
config_name=config_name,
196+
enable_session_tag_chaining=enable_session_tag_chaining,
195197
)
196198

197199
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)

src/sagemaker/jumpstart/session_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def get_model_info_from_training_job(
219219
model_id,
220220
inferred_model_version,
221221
inference_config_name,
222-
trainig_config_name,
222+
training_config_name,
223223
) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session)
224224

225225
model_version = inferred_model_version or None
@@ -231,4 +231,4 @@ def get_model_info_from_training_job(
231231
"for this training job."
232232
)
233233

234-
return model_id, model_version, inference_config_name, trainig_config_name
234+
return model_id, model_version, inference_config_name, training_config_name

src/sagemaker/jumpstart/types.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,7 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
10781078
"resolved_metadata_config",
10791079
"config_name",
10801080
"default_inference_config",
1081-
"default_incremental_trainig_config",
1081+
"default_incremental_training_config",
10821082
"supported_inference_configs",
10831083
"supported_incremental_training_configs",
10841084
]
@@ -1114,7 +1114,7 @@ def __init__(
11141114
self.resolved_metadata_config: Optional[Dict[str, Any]] = None
11151115
self.config_name: Optional[str] = config_name
11161116
self.default_inference_config: Optional[str] = config.get("default_inference_config")
1117-
self.default_incremental_trainig_config: Optional[str] = config.get(
1117+
self.default_incremental_training_config: Optional[str] = config.get(
11181118
"default_incremental_training_config"
11191119
)
11201120
self.supported_inference_configs: Optional[List[str]] = config.get(
@@ -1776,6 +1776,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
17761776
"enable_remote_debug",
17771777
"enable_session_tag_chaining",
17781778
"config_name",
1779+
"enable_session_tag_chaining",
17791780
]
17801781

17811782
SERIALIZATION_EXCLUSION_SET = {
@@ -1846,6 +1847,7 @@ def __init__(
18461847
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
18471848
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
18481849
config_name: Optional[str] = None,
1850+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
18491851
) -> None:
18501852
"""Instantiates JumpStartEstimatorInitKwargs object."""
18511853

@@ -1907,6 +1909,7 @@ def __init__(
19071909
self.enable_remote_debug = enable_remote_debug
19081910
self.enable_session_tag_chaining = enable_session_tag_chaining
19091911
self.config_name = config_name
1912+
self.enable_session_tag_chaining = enable_session_tag_chaining
19101913

19111914

19121915
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_gated_model_training_v2(setup):
150150
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
151151
environment={"accept_eula": "true"},
152152
max_run=259200, # avoid exceeding resource limits
153-
tolerate_vulnerable_model=True, # TODO: remove once vulnerbility is patched
153+
tolerate_vulnerable_model=True, # tolerate old version of model
154154
)
155155

156156
# uses ml.g5.12xlarge instance

0 commit comments

Comments
 (0)