Commit ae3e3d2

beniericpintaoz-aws authored and committed
Add path to set Additional Settings in ModelTrainer (#1555)
1 parent 25486e6 commit ae3e3d2

File tree: 6 files changed (+699, -208 lines)

src/sagemaker/modules/configs.py (+167, -12)
@@ -12,43 +12,84 @@
 # language governing permissions and limitations under the License.
 """This module provides the configuration classes used in `sagemaker.modules`.
 
-Some of these classes are re-exported from `sagemaker-core.shapes`. For convinence,
+Some of these classes are re-exported from `sagemaker_core.shapes`. For convenience,
 users can import these classes directly from `sagemaker.modules.configs`.
 
-For more documentation on `sagemaker-core.shapes`, see:
+For more documentation on `sagemaker_core.shapes`, see:
 - https://sagemaker-core.readthedocs.io/en/stable/#sagemaker-core-shapes
 """
 
 from __future__ import absolute_import
 
-from typing import Optional, Dict, Any, List
+from typing import Optional, Union, Dict, Any, List
 from pydantic import BaseModel, model_validator
 
+import sagemaker_core.shapes as shapes
+
+# TODO: Can we add custom logic to some of these to set better defaults?
 from sagemaker_core.shapes import (
-    ResourceConfig,
     StoppingCondition,
+    RetryStrategy,
     OutputDataConfig,
     Channel,
+    ShuffleConfig,
     DataSource,
     S3DataSource,
     FileSystemDataSource,
     TrainingImageConfig,
-    VpcConfig,
+    TrainingRepositoryAuthConfig,
+    Tag,
+    MetricDefinition,
+    DebugHookConfig,
+    CollectionConfiguration,
+    DebugRuleConfiguration,
+    ExperimentConfig,
+    InfraCheckConfig,
+    ProfilerConfig,
+    ProfilerRuleConfiguration,
+    RemoteDebugConfig,
+    SessionChainingConfig,
+    InstanceGroup,
+    TensorBoardOutputConfig,
+    CheckpointConfig,
 )
 
 from sagemaker.modules import logger
+from sagemaker.modules.utils import convert_unassigned_to_none
 
 __all__ = [
     "SourceCodeConfig",
-    "ResourceConfig",
+    "TorchDistributionConfig",
+    "MPIDistributionConfig",
+    "SMDistributedSettings",
+    "DistributionConfig",
     "StoppingCondition",
+    "RetryStrategy",
     "OutputDataConfig",
     "Channel",
+    "ShuffleConfig",
     "DataSource",
     "S3DataSource",
     "FileSystemDataSource",
     "TrainingImageConfig",
-    "VpcConfig",
+    "TrainingRepositoryAuthConfig",
+    "Tag",
+    "MetricDefinition",
+    "DebugHookConfig",
+    "CollectionConfiguration",
+    "DebugRuleConfiguration",
+    "ExperimentConfig",
+    "InfraCheckConfig",
+    "ProfilerConfig",
+    "ProfilerRuleConfiguration",
+    "RemoteDebugConfig",
+    "SessionChainingConfig",
+    "InstanceGroup",
+    "TensorBoardOutputConfig",
+    "CheckpointConfig",
+    "ComputeConfig",
+    "NetworkingConfig",
+    "InputData",
 ]
 
 
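As the updated module docstring notes, the re-exported shapes and the new config classes introduced below can all be imported from one place. A minimal sketch for reference, not part of this commit:

```python
# Minimal import sketch based on the __all__ list above.
from sagemaker.modules.configs import (
    ComputeConfig,
    NetworkingConfig,
    InputData,
    ShuffleConfig,
    CheckpointConfig,
)
```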
@@ -161,14 +202,128 @@ class SourceCodeConfig(BaseModel):
         command (Optional[str]):
             The command(s) to execute in the training job container. Example: "python my_script.py".
             If not specified, entry_script must be provided.
-        distribution (Optional[Union[
-            MPIDistributionConfig,
-            TorchDistributionConfig,
-        ]]):
-            The distribution configuration for the training job.
     """
 
     source_dir: Optional[str] = None
     requirements: Optional[str] = None
     entry_script: Optional[str] = None
     command: Optional[str] = None
+
+
+class ComputeConfig(shapes.ResourceConfig):
+    """ComputeConfig.
+
+    The ComputeConfig is a subclass of `sagemaker_core.shapes.ResourceConfig`
+    and allows the user to specify the compute resources for the training job.
+
+    Attributes:
+        instance_type (Optional[str]):
+            The ML compute instance type. For information about available instance types,
+            see https://aws.amazon.com/sagemaker/pricing/. Default: ml.m5.xlarge
+        instance_count (Optional[int]): The number of ML compute instances to use. For distributed
+            training, provide a value greater than 1. Default: 1
+        volume_size_in_gb (Optional[int]):
+            The size of the ML storage volume that you want to provision. ML storage volumes store
+            model artifacts and incremental states. Training algorithms might also use the ML
+            storage volume for scratch space. Default: 30
+        volume_kms_key_id (Optional[str]):
+            The Amazon Web Services KMS key that SageMaker uses to encrypt data on the storage
+            volume attached to the ML compute instance(s) that run the training job.
+        keep_alive_period_in_seconds (Optional[int]):
+            The duration of time in seconds to retain configured resources in a warm pool for
+            subsequent training jobs.
+        instance_groups (Optional[List[InstanceGroup]]):
+            A list of instance groups for heterogeneous clusters to be used in the training job.
+        enable_managed_spot_training (Optional[bool]):
+            To train models using managed spot training, choose True. Managed spot training
+            provides a fully managed and scalable infrastructure for training machine learning
+            models. This option is useful when training jobs can be interrupted and when there
+            is flexibility in when the training job is run.
+    """
+
+    volume_size_in_gb: Optional[int] = 30
+    enable_managed_spot_training: Optional[bool] = None
+
+    @model_validator(mode="after")
+    def _model_validator(self) -> "ComputeConfig":
+        """Convert Unassigned values to None."""
+        return convert_unassigned_to_none(self)
+
+    def _to_resource_config(self) -> shapes.ResourceConfig:
+        """Convert to a sagemaker_core.shapes.ResourceConfig object."""
+        compute_config_dict = self.model_dump()
+        resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys())
+        filtered_dict = {
+            k: v for k, v in compute_config_dict.items() if k in resource_config_fields
+        }
+        return shapes.ResourceConfig(**filtered_dict)
+
+
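For illustration, a hedged usage sketch of the new ComputeConfig; the instance values below are placeholders, and `_to_resource_config()` is the internal helper defined above (presumably invoked by the ModelTrainer rather than by users):

```python
from sagemaker.modules.configs import ComputeConfig

# Hypothetical example values; only the attributes documented above are used.
compute = ComputeConfig(
    instance_type="ml.m5.xlarge",
    instance_count=1,
    volume_size_in_gb=30,
    keep_alive_period_in_seconds=3600,
    enable_managed_spot_training=False,
)

# The conversion drops fields that sagemaker_core.shapes.ResourceConfig does not
# declare (e.g. enable_managed_spot_training) before the API shape is built.
resource_config = compute._to_resource_config()
```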
+class NetworkingConfig(shapes.VpcConfig):
+    """NetworkingConfig.
+
+    The NetworkingConfig is a subclass of `sagemaker_core.shapes.VpcConfig` and
+    allows the user to specify the networking configuration for the training job.
+
+    Attributes:
+        security_group_ids (Optional[List[str]]):
+            The VPC security group IDs, in the form sg-xxxxxxxx. Specify the
+            security groups for the VPC that is specified in the Subnets field.
+        subnets (Optional[List[str]]):
+            The ID of the subnets in the VPC to which you want to connect your
+            training job or model.
+        enable_network_isolation (Optional[bool]):
+            Isolates the training container. No inbound or outbound network calls can be made,
+            except for calls between peers within a training cluster for distributed training.
+            If you enable network isolation for training jobs that are configured to use a VPC,
+            SageMaker downloads and uploads customer data and model artifacts through the
+            specified VPC, but the training container does not have network access.
+        enable_inter_container_traffic_encryption (Optional[bool]):
+            To encrypt all communications between ML compute instances in distributed training,
+            choose True. Encryption provides greater security for distributed training, but
+            training might take longer. How long it takes depends on the amount of
+            communication between compute instances, especially if you use a deep learning
+            algorithm in distributed training.
+    """
+
+    enable_network_isolation: Optional[bool] = None
+    enable_inter_container_traffic_encryption: Optional[bool] = None
+
+    @model_validator(mode="after")
+    def _model_validator(self) -> "NetworkingConfig":
+        """Convert Unassigned values to None."""
+        return convert_unassigned_to_none(self)
+
+    def _to_vpc_config(self) -> shapes.VpcConfig:
+        """Convert to a sagemaker_core.shapes.VpcConfig object."""
+        compute_config_dict = self.model_dump()
+        resource_config_fields = set(shapes.VpcConfig.__annotations__.keys())
+        filtered_dict = {
+            k: v for k, v in compute_config_dict.items() if k in resource_config_fields
+        }
+        return shapes.VpcConfig(**filtered_dict)
+
+
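Similarly, a hedged sketch of NetworkingConfig; the security group and subnet IDs are placeholders:

```python
from sagemaker.modules.configs import NetworkingConfig

# Hypothetical VPC settings for a training job with network isolation and
# inter-container traffic encryption enabled.
networking = NetworkingConfig(
    security_group_ids=["sg-0123456789abcdef0"],
    subnets=["subnet-0123456789abcdef0"],
    enable_network_isolation=True,
    enable_inter_container_traffic_encryption=True,
)

# Only the fields declared on sagemaker_core.shapes.VpcConfig are forwarded.
vpc_config = networking._to_vpc_config()
```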
+class InputData(BaseModel):
+    """InputData.
+
+    This config allows the user to specify an input data source for the training job.
+
+    The data will be found at `/opt/ml/input/data/<channel_name>` within the training container.
+    For convenience, it can be referenced inside the training container like:
+
+    ```python
+    import os
+    input_data_dir = os.environ['SM_CHANNEL_<channel_name>']
+    ```
+
+    Attributes:
+        channel_name (str):
+            The name of the input data source channel.
+        data_source (Union[str, S3DataSource, FileSystemDataSource]):
+            The data source for the channel. Can be an S3 URI string, local file path string,
+            S3DataSource object, or FileSystemDataSource object.
+    """
+
+    channel_name: str = None
+    data_source: Union[str, FileSystemDataSource, S3DataSource] = None
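Finally, a hedged sketch of InputData channels; the bucket and local path are placeholders, and how each string is resolved is handled by the ModelTrainer wiring introduced elsewhere in this PR:

```python
from sagemaker.modules.configs import InputData

# Hypothetical channels: one backed by an S3 prefix, one by a local directory.
train_data = InputData(
    channel_name="train",
    data_source="s3://my-bucket/datasets/train/",  # placeholder S3 URI
)
validation_data = InputData(
    channel_name="validation",
    data_source="./data/validation/",  # placeholder local path
)

# Inside the container, each channel is mounted at /opt/ml/input/data/<channel_name>
# and exposed via the SM_CHANNEL_<channel_name> environment variable (see docstring above).
```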
