Skip to content

Commit b33d6d2

Browse files
authored
fix: typo and merge with master branch (#4649)
1 parent 72342be commit b33d6d2

23 files changed

+205
-27
lines changed

CHANGELOG.md

+23
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,28 @@
11
# Changelog
22

3+
## v2.218.0 (2024-05-01)
4+
5+
### Features
6+
7+
* set default allow_pickle param to False
8+
9+
### Bug Fixes and Other Changes
10+
11+
* properly close files in lineage queries and tests
12+
13+
## v2.217.0 (2024-04-24)
14+
15+
### Features
16+
17+
* support session tag chaining for training job
18+
19+
### Bug Fixes and Other Changes
20+
21+
* Add Triton v24.03 URI
22+
* mainline alt config parsing
23+
* Fix tox installs
24+
* Add PT 2.2 Graviton Inference DLC
25+
326
## v2.216.1 (2024-04-22)
427

528
### Bug Fixes and Other Changes

VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.216.2.dev0
1+
2.218.1.dev0

src/sagemaker/base_deserializers.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,14 @@ class NumpyDeserializer(SimpleBaseDeserializer):
196196
single array.
197197
"""
198198

199-
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
199+
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=False):
200200
"""Initialize a ``NumpyDeserializer`` instance.
201201
202202
Args:
203203
dtype (str): The dtype of the data (default: None).
204204
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
205205
is expected from the inference endpoint (default: "application/x-npy").
206-
allow_pickle (bool): Allow loading pickled object arrays (default: True).
206+
allow_pickle (bool): Allow loading pickled object arrays (default: False).
207207
"""
208208
super(NumpyDeserializer, self).__init__(accept=accept)
209209
self.dtype = dtype
@@ -227,10 +227,21 @@ def deserialize(self, stream, content_type):
227227
if content_type == "application/json":
228228
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
229229
if content_type == "application/x-npy":
230-
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
230+
try:
231+
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
232+
except ValueError as ve:
233+
raise ValueError(
234+
"Please set the param allow_pickle=True \
235+
to deserialize pickle objects in NumpyDeserializer"
236+
).with_traceback(ve.__traceback__)
231237
if content_type == "application/x-npz":
232238
try:
233239
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
240+
except ValueError as ve:
241+
raise ValueError(
242+
"Please set the param allow_pickle=True \
243+
to deserialize pickle objectsin NumpyDeserializer"
244+
).with_traceback(ve.__traceback__)
234245
finally:
235246
stream.close()
236247
finally:

src/sagemaker/estimator.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def __init__(
181181
container_arguments: Optional[List[str]] = None,
182182
disable_output_compression: bool = False,
183183
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
184+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
184185
**kwargs,
185186
):
186187
"""Initialize an ``EstimatorBase`` instance.
@@ -544,7 +545,9 @@ def __init__(
544545
enable_infra_check (bool or PipelineVariable): Optional.
545546
Specifies whether it is running Sagemaker built-in infra check jobs.
546547
enable_remote_debug (bool or PipelineVariable): Optional.
547-
Specifies whether RemoteDebug is enabled for the training job
548+
Specifies whether RemoteDebug is enabled for the training job.
549+
enable_session_tag_chaining (bool or PipelineVariable): Optional.
550+
Specifies whether SessionTagChaining is enabled for the training job.
548551
"""
549552
instance_count = renamed_kwargs(
550553
"train_instance_count", "instance_count", instance_count, kwargs
@@ -785,6 +788,8 @@ def __init__(
785788

786789
self._enable_remote_debug = enable_remote_debug
787790

791+
self._enable_session_tag_chaining = enable_session_tag_chaining
792+
788793
@abstractmethod
789794
def training_image_uri(self):
790795
"""Return the Docker image to use for training.
@@ -2318,6 +2323,14 @@ def get_remote_debug_config(self):
23182323
else {"EnableRemoteDebug": self._enable_remote_debug}
23192324
)
23202325

2326+
def get_session_chaining_config(self):
2327+
"""dict: Return the configuration of SessionChaining"""
2328+
return (
2329+
None
2330+
if self._enable_session_tag_chaining is None
2331+
else {"EnableSessionTagChaining": self._enable_session_tag_chaining}
2332+
)
2333+
23212334
def enable_remote_debug(self):
23222335
"""Enable remote debug for a training job."""
23232336
self._update_remote_debug(True)
@@ -2574,6 +2587,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25742587
if estimator.get_remote_debug_config() is not None:
25752588
train_args["remote_debug_config"] = estimator.get_remote_debug_config()
25762589

2590+
if estimator.get_session_chaining_config() is not None:
2591+
train_args["session_chaining_config"] = estimator.get_session_chaining_config()
2592+
25772593
return train_args
25782594

25792595
@classmethod
@@ -2766,6 +2782,7 @@ def __init__(
27662782
disable_output_compression: bool = False,
27672783
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
27682784
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
2785+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
27692786
**kwargs,
27702787
):
27712788
"""Initialize an ``Estimator`` instance.
@@ -3129,6 +3146,8 @@ def __init__(
31293146
Specifies whether it is running Sagemaker built-in infra check jobs.
31303147
enable_remote_debug (bool or PipelineVariable): Optional.
31313148
Specifies whether RemoteDebug is enabled for the training job
3149+
enable_session_tag_chaining (bool or PipelineVariable): Optional.
3150+
Specifies whether SessionTagChaining is enabled for the training job
31323151
"""
31333152
self.image_uri = image_uri
31343153
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
@@ -3181,6 +3200,7 @@ def __init__(
31813200
container_arguments=container_arguments,
31823201
disable_output_compression=disable_output_compression,
31833202
enable_remote_debug=enable_remote_debug,
3203+
enable_session_tag_chaining=enable_session_tag_chaining,
31843204
**kwargs,
31853205
)
31863206

src/sagemaker/image_uri_config/sagemaker-tritonserver.json

+34-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"inference"
88
],
99
"versions": {
10-
"23.12": {
10+
"24.03": {
1111
"registries": {
1212
"af-south-1": "626614931356",
1313
"il-central-1": "780543022126",
@@ -37,7 +37,7 @@
3737
"ca-west-1": "204538143572"
3838
},
3939
"repository": "sagemaker-tritonserver",
40-
"tag_prefix": "23.12-py3"
40+
"tag_prefix": "24.03-py3"
4141
},
4242
"24.01": {
4343
"registries": {
@@ -70,6 +70,38 @@
7070
},
7171
"repository": "sagemaker-tritonserver",
7272
"tag_prefix": "24.01-py3"
73+
},
74+
"23.12": {
75+
"registries": {
76+
"af-south-1": "626614931356",
77+
"il-central-1": "780543022126",
78+
"ap-east-1": "871362719292",
79+
"ap-northeast-1": "763104351884",
80+
"ap-northeast-2": "763104351884",
81+
"ap-northeast-3": "364406365360",
82+
"ap-south-1": "763104351884",
83+
"ap-southeast-1": "763104351884",
84+
"ap-southeast-2": "763104351884",
85+
"ap-southeast-3": "907027046896",
86+
"ca-central-1": "763104351884",
87+
"cn-north-1": "727897471807",
88+
"cn-northwest-1": "727897471807",
89+
"eu-central-1": "763104351884",
90+
"eu-north-1": "763104351884",
91+
"eu-west-1": "763104351884",
92+
"eu-west-2": "763104351884",
93+
"eu-west-3": "763104351884",
94+
"eu-south-1": "692866216735",
95+
"me-south-1": "217643126080",
96+
"sa-east-1": "763104351884",
97+
"us-east-1": "763104351884",
98+
"us-east-2": "763104351884",
99+
"us-west-1": "763104351884",
100+
"us-west-2": "763104351884",
101+
"ca-west-1": "204538143572"
102+
},
103+
"repository": "sagemaker-tritonserver",
104+
"tag_prefix": "23.12-py3"
73105
}
74106
}
75107
}

src/sagemaker/jumpstart/estimator.py

+4
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __init__(
112112
disable_output_compression: Optional[bool] = None,
113113
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
114114
config_name: Optional[str] = None,
115+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
115116
):
116117
"""Initializes a ``JumpStartEstimator``.
117118
@@ -505,6 +506,8 @@ def __init__(
505506
Specifies whether RemoteDebug is enabled for the training job
506507
config_name (Optional[str]):
507508
Name of the training configuration to apply to the Estimator. (Default: None).
509+
enable_session_tag_chaining (bool or PipelineVariable): Optional.
510+
Specifies whether SessionTagChaining is enabled for the training job
508511
509512
Raises:
510513
ValueError: If the model ID is not recognized by JumpStart.
@@ -584,6 +587,7 @@ def _validate_model_id_and_get_type_hook():
584587
enable_infra_check=enable_infra_check,
585588
enable_remote_debug=enable_remote_debug,
586589
config_name=config_name,
590+
enable_session_tag_chaining=enable_session_tag_chaining,
587591
)
588592

589593
self.model_id = estimator_init_kwargs.model_id

src/sagemaker/jumpstart/factory/estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def get_init_kwargs(
131131
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
132132
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
133133
config_name: Optional[str] = None,
134+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
134135
) -> JumpStartEstimatorInitKwargs:
135136
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
136137

@@ -190,6 +191,7 @@ def get_init_kwargs(
190191
enable_infra_check=enable_infra_check,
191192
enable_remote_debug=enable_remote_debug,
192193
config_name=config_name,
194+
enable_session_tag_chaining=enable_session_tag_chaining,
193195
)
194196

195197
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)

src/sagemaker/jumpstart/session_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def get_model_info_from_training_job(
219219
model_id,
220220
inferred_model_version,
221221
inference_config_name,
222-
trainig_config_name,
222+
training_config_name,
223223
) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session)
224224

225225
model_version = inferred_model_version or None
@@ -231,4 +231,4 @@ def get_model_info_from_training_job(
231231
"for this training job."
232232
)
233233

234-
return model_id, model_version, inference_config_name, trainig_config_name
234+
return model_id, model_version, inference_config_name, training_config_name

src/sagemaker/jumpstart/types.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,7 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
10781078
"resolved_metadata_config",
10791079
"config_name",
10801080
"default_inference_config",
1081-
"default_incremental_trainig_config",
1081+
"default_incremental_training_config",
10821082
"supported_inference_configs",
10831083
"supported_incremental_training_configs",
10841084
]
@@ -1114,7 +1114,7 @@ def __init__(
11141114
self.resolved_metadata_config: Optional[Dict[str, Any]] = None
11151115
self.config_name: Optional[str] = config_name
11161116
self.default_inference_config: Optional[str] = config.get("default_inference_config")
1117-
self.default_incremental_trainig_config: Optional[str] = config.get(
1117+
self.default_incremental_training_config: Optional[str] = config.get(
11181118
"default_incremental_training_config"
11191119
)
11201120
self.supported_inference_configs: Optional[List[str]] = config.get(
@@ -1775,6 +1775,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
17751775
"enable_infra_check",
17761776
"enable_remote_debug",
17771777
"config_name",
1778+
"enable_session_tag_chaining",
17781779
]
17791780

17801781
SERIALIZATION_EXCLUSION_SET = {
@@ -1844,6 +1845,7 @@ def __init__(
18441845
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
18451846
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
18461847
config_name: Optional[str] = None,
1848+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
18471849
) -> None:
18481850
"""Instantiates JumpStartEstimatorInitKwargs object."""
18491851

@@ -1904,6 +1906,7 @@ def __init__(
19041906
self.enable_infra_check = enable_infra_check
19051907
self.enable_remote_debug = enable_remote_debug
19061908
self.config_name = config_name
1909+
self.enable_session_tag_chaining = enable_session_tag_chaining
19071910

19081911

19091912
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

src/sagemaker/lineage/query.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ def _get_legend_line(self, component_name):
335335

336336
def _add_legend(self, path):
337337
"""Embed legend to html file generated by pyvis."""
338-
f = open(path, "r")
339-
content = self.BeautifulSoup(f, "html.parser")
338+
with open(path, "r") as f:
339+
content = self.BeautifulSoup(f, "html.parser")
340340

341341
legend = """
342342
<div style="display: inline-block; font-size: 1vw; font-family: verdana;

src/sagemaker/session.py

+24
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,7 @@ def train( # noqa: C901
758758
environment: Optional[Dict[str, str]] = None,
759759
retry_strategy=None,
760760
remote_debug_config=None,
761+
session_chaining_config=None,
761762
):
762763
"""Create an Amazon SageMaker training job.
763764
@@ -877,6 +878,15 @@ def train( # noqa: C901
877878
remote_debug_config = {
878879
"EnableRemoteDebug": True,
879880
}
881+
session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``)
882+
The dict can contain 'EnableSessionTagChaining'(bool).
883+
For example,
884+
885+
.. code:: python
886+
887+
session_chaining_config = {
888+
"EnableSessionTagChaining": True,
889+
}
880890
environment (dict[str, str]) : Environment variables to be set for
881891
use during training job (default: ``None``)
882892
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
@@ -970,6 +980,7 @@ def train( # noqa: C901
970980
profiler_rule_configs=profiler_rule_configs,
971981
profiler_config=inferred_profiler_config,
972982
remote_debug_config=remote_debug_config,
983+
session_chaining_config=session_chaining_config,
973984
environment=environment,
974985
retry_strategy=retry_strategy,
975986
)
@@ -1013,6 +1024,7 @@ def _get_train_request( # noqa: C901
10131024
profiler_rule_configs=None,
10141025
profiler_config=None,
10151026
remote_debug_config=None,
1027+
session_chaining_config=None,
10161028
environment=None,
10171029
retry_strategy=None,
10181030
):
@@ -1133,6 +1145,15 @@ def _get_train_request( # noqa: C901
11331145
remote_debug_config = {
11341146
"EnableRemoteDebug": True,
11351147
}
1148+
session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``)
1149+
The dict can contain 'EnableSessionTagChaining'(bool).
1150+
For example,
1151+
1152+
.. code:: python
1153+
1154+
session_chaining_config = {
1155+
"EnableSessionTagChaining": True,
1156+
}
11361157
environment (dict[str, str]) : Environment variables to be set for
11371158
use during training job (default: ``None``)
11381159
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
@@ -1239,6 +1260,9 @@ def _get_train_request( # noqa: C901
12391260
if remote_debug_config is not None:
12401261
train_request["RemoteDebugConfig"] = remote_debug_config
12411262

1263+
if session_chaining_config is not None:
1264+
train_request["SessionChainingConfig"] = session_chaining_config
1265+
12421266
if retry_strategy is not None:
12431267
train_request["RetryStrategy"] = retry_strategy
12441268

tests/data/sip/training.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def main():
7373
)
7474

7575
model_dir = os.environ.get("SM_MODEL_DIR")
76-
pkl.dump(bst, open(model_dir + "/model.bin", "wb"))
76+
with open(model_dir + "/model.bin", "wb") as f:
77+
pkl.dump(bst, f)
7778

7879

7980
if __name__ == "__main__":

0 commit comments

Comments
 (0)