|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | """This module provides the configuration classes used in `sagemaker.modules`.
|
14 | 14 |
|
15 |
| -Some of these classes are re-exported from `sagemaker-core.shapes`. For convinence, |
| 15 | +Some of these classes are re-exported from `sagemaker_core.shapes`. For convenience, |
16 | 16 | users can import these classes directly from `sagemaker.modules.configs`.
|
17 | 17 |
|
18 |
| -For more documentation on `sagemaker-core.shapes`, see: |
| 18 | +For more documentation on `sagemaker_core.shapes`, see: |
19 | 19 | - https://sagemaker-core.readthedocs.io/en/stable/#sagemaker-core-shapes
|
20 | 20 | """
|
21 | 21 |
|
22 | 22 | from __future__ import absolute_import
|
23 | 23 |
|
24 |
| -from typing import Optional, Dict, Any, List |
| 24 | +from typing import Optional, Union, Dict, Any, List |
25 | 25 | from pydantic import BaseModel, model_validator
|
26 | 26 |
|
| 27 | +import sagemaker_core.shapes as shapes |
| 28 | + |
| 29 | +# TODO: Can we add custom logic to some of these to set better defaults? |
27 | 30 | from sagemaker_core.shapes import (
|
28 |
| - ResourceConfig, |
29 | 31 | StoppingCondition,
|
| 32 | + RetryStrategy, |
30 | 33 | OutputDataConfig,
|
31 | 34 | Channel,
|
| 35 | + ShuffleConfig, |
32 | 36 | DataSource,
|
33 | 37 | S3DataSource,
|
34 | 38 | FileSystemDataSource,
|
35 | 39 | TrainingImageConfig,
|
36 |
| - VpcConfig, |
| 40 | + TrainingRepositoryAuthConfig, |
| 41 | + Tag, |
| 42 | + MetricDefinition, |
| 43 | + DebugHookConfig, |
| 44 | + CollectionConfiguration, |
| 45 | + DebugRuleConfiguration, |
| 46 | + ExperimentConfig, |
| 47 | + InfraCheckConfig, |
| 48 | + ProfilerConfig, |
| 49 | + ProfilerRuleConfiguration, |
| 50 | + RemoteDebugConfig, |
| 51 | + SessionChainingConfig, |
| 52 | + InstanceGroup, |
| 53 | + TensorBoardOutputConfig, |
| 54 | + CheckpointConfig, |
37 | 55 | )
|
38 | 56 |
|
39 | 57 | from sagemaker.modules import logger
|
| 58 | +from sagemaker.modules.utils import convert_unassigned_to_none |
40 | 59 |
|
41 | 60 | __all__ = [
|
42 | 61 | "SourceCodeConfig",
|
43 |
| - "ResourceConfig", |
| 62 | + "TorchDistributionConfig", |
| 63 | + "MPIDistributionConfig", |
| 64 | + "SMDistributedSettings", |
| 65 | + "DistributionConfig", |
44 | 66 | "StoppingCondition",
|
| 67 | + "RetryStrategy", |
45 | 68 | "OutputDataConfig",
|
46 | 69 | "Channel",
|
| 70 | + "ShuffleConfig", |
47 | 71 | "DataSource",
|
48 | 72 | "S3DataSource",
|
49 | 73 | "FileSystemDataSource",
|
50 | 74 | "TrainingImageConfig",
|
51 |
| - "VpcConfig", |
| 75 | + "TrainingRepositoryAuthConfig", |
| 76 | + "Tag", |
| 77 | + "MetricDefinition", |
| 78 | + "DebugHookConfig", |
| 79 | + "CollectionConfiguration", |
| 80 | + "DebugRuleConfiguration", |
| 81 | + "ExperimentConfig", |
| 82 | + "InfraCheckConfig", |
| 83 | + "ProfilerConfig", |
| 84 | + "ProfilerRuleConfiguration", |
| 85 | + "RemoteDebugConfig", |
| 86 | + "SessionChainingConfig", |
| 87 | + "InstanceGroup", |
| 88 | + "TensorBoardOutputConfig", |
| 89 | + "CheckpointConfig", |
| 90 | + "ComputeConfig", |
| 91 | + "NetworkingConfig", |
| 92 | + "InputData", |
52 | 93 | ]
|
53 | 94 |
|
54 | 95 |
|
@@ -161,14 +202,128 @@ class SourceCodeConfig(BaseModel):
|
161 | 202 | command (Optional[str]):
|
162 | 203 | The command(s) to execute in the training job container. Example: "python my_script.py".
|
163 | 204 | If not specified, entry_script must be provided.
|
164 |
| - distribution (Optional[Union[ |
165 |
| - MPIDistributionConfig, |
166 |
| - TorchDistributionConfig, |
167 |
| - ]]): |
168 |
| - The distribution configuration for the training job. |
169 | 205 | """
|
170 | 206 |
|
171 | 207 | source_dir: Optional[str] = None
|
172 | 208 | requirements: Optional[str] = None
|
173 | 209 | entry_script: Optional[str] = None
|
174 | 210 | command: Optional[str] = None
|
| 211 | + |
| 212 | + |
class ComputeConfig(shapes.ResourceConfig):
    """ComputeConfig.

    The ComputeConfig is a subclass of `sagemaker_core.shapes.ResourceConfig`
    and allows the user to specify the compute resources for the training job.

    Attributes:
        instance_type (Optional[str]):
            The ML compute instance type. For information about available instance types,
            see https://aws.amazon.com/sagemaker/pricing/. Default: ml.m5.xlarge
        instance_count (Optional[int]): The number of ML compute instances to use. For distributed
            training, provide a value greater than 1. Default: 1
        volume_size_in_gb (Optional[int]):
            The size of the ML storage volume that you want to provision. ML storage volumes store
            model artifacts and incremental states. Training algorithms might also use the ML
            storage volume for scratch space. Default: 30
        volume_kms_key_id (Optional[str]):
            The Amazon Web Services KMS key that SageMaker uses to encrypt data on the storage
            volume attached to the ML compute instance(s) that run the training job.
        keep_alive_period_in_seconds (Optional[int]):
            The duration of time in seconds to retain configured resources in a warm pool for
            subsequent training jobs.
        instance_groups (Optional[List[InstanceGroup]]):
            A list of instance groups for heterogeneous clusters to be used in the training job.
        enable_managed_spot_training (Optional[bool]):
            To train models using managed spot training, choose True. Managed spot training
            provides a fully managed and scalable infrastructure for training machine learning
            models. This option is useful when training jobs can be interrupted and when there
            is flexibility when the training job is run.
    """

    volume_size_in_gb: Optional[int] = 30
    enable_managed_spot_training: Optional[bool] = None

    @model_validator(mode="after")
    def _model_validator(self) -> "ComputeConfig":
        """Convert Unassigned values to None."""
        return convert_unassigned_to_none(self)

    def _to_resource_config(self) -> shapes.ResourceConfig:
        """Convert to a sagemaker_core.shapes.ResourceConfig object.

        Drops the fields that exist only on ComputeConfig (e.g.
        enable_managed_spot_training) so the result is a valid ResourceConfig.
        """
        compute_config_dict = self.model_dump()
        # Use pydantic's model_fields rather than __annotations__: __annotations__
        # only lists fields declared directly on the class and would silently miss
        # any fields ResourceConfig inherits from a base shape.
        resource_config_fields = set(shapes.ResourceConfig.model_fields.keys())
        filtered_dict = {
            k: v for k, v in compute_config_dict.items() if k in resource_config_fields
        }
        return shapes.ResourceConfig(**filtered_dict)
| 260 | + |
| 261 | + |
class NetworkingConfig(shapes.VpcConfig):
    """NetworkingConfig.

    The NetworkingConfig is a subclass of `sagemaker_core.shapes.VpcConfig` and
    allows the user to specify the networking configuration for the training job.

    Attributes:
        security_group_ids (Optional[List[str]]):
            The VPC security group IDs, in the form sg-xxxxxxxx. Specify the
            security groups for the VPC that is specified in the Subnets field.
        subnets (Optional[List[str]]):
            The ID of the subnets in the VPC to which you want to connect your
            training job or model.
        enable_network_isolation (Optional[bool]):
            Isolates the training container. No inbound or outbound network calls can be made,
            except for calls between peers within a training cluster for distributed training.
            If you enable network isolation for training jobs that are configured to use a VPC,
            SageMaker downloads and uploads customer data and model artifacts through the
            specified VPC, but the training container does not have network access.
        enable_inter_container_traffic_encryption (Optional[bool]):
            To encrypt all communications between ML compute instances in distributed training,
            choose True. Encryption provides greater security for distributed training, but
            training might take longer. How long it takes depends on the amount of
            communication between compute instances, especially if you use a deep learning
            algorithm in distributed training.
    """

    enable_network_isolation: Optional[bool] = None
    enable_inter_container_traffic_encryption: Optional[bool] = None

    @model_validator(mode="after")
    def _model_validator(self) -> "NetworkingConfig":
        """Convert Unassigned values to None."""
        return convert_unassigned_to_none(self)

    def _to_vpc_config(self) -> shapes.VpcConfig:
        """Convert to a sagemaker_core.shapes.VpcConfig object.

        Drops the fields that exist only on NetworkingConfig (the two enable_*
        flags) so the result is a valid VpcConfig.
        """
        networking_config_dict = self.model_dump()
        # model_fields (not __annotations__) so inherited pydantic fields on
        # VpcConfig, if any, are included in the filter set.
        vpc_config_fields = set(shapes.VpcConfig.model_fields.keys())
        filtered_dict = {
            k: v for k, v in networking_config_dict.items() if k in vpc_config_fields
        }
        return shapes.VpcConfig(**filtered_dict)
| 305 | + |
| 306 | + |
class InputData(BaseModel):
    """InputData.

    This config allows the user to specify an input data source for the training job.

    The data will be found at `/opt/ml/input/data/<channel_name>` within the training
    container. For convenience, it can be referenced inside the training container like:

    ```python
    import os
    input_data_dir = os.environ['SM_CHANNEL_<channel_name>']
    ```

    Attributes:
        channel_name (Optional[str]):
            The name of the input data source channel.
        data_source (Optional[Union[str, S3DataSource, FileSystemDataSource]]):
            The data source for the channel. Can be an S3 URI string, local file path string,
            S3DataSource object, or FileSystemDataSource object.
    """

    # Annotated as Optional to match the None defaults — a bare `str = None`
    # annotation contradicts its own default and fails pydantic v2 validation
    # when None is passed explicitly. Both fields are expected to be set before
    # the config is used.
    channel_name: Optional[str] = None
    data_source: Optional[Union[str, FileSystemDataSource, S3DataSource]] = None
0 commit comments