Skip to content

Commit 9e7b4b5

Browse files
staubhpPayton Staubicywang86rui
authored
documentation: Correct type annotation for TrainingStep inputs (#2468)
* Correct type annotation for training step inputs * Add missing type hint * black-format Co-authored-by: Payton Staub <[email protected]> Co-authored-by: icywang86rui <[email protected]>
1 parent 3309f17 commit 9e7b4b5

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

src/sagemaker/workflow/steps.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,12 @@
1616
import abc
1717

1818
from enum import Enum
19-
from typing import Dict, List
19+
from typing import Dict, List, Union
2020

2121
import attr
2222

2323
from sagemaker.estimator import EstimatorBase, _TrainingJob
24-
from sagemaker.inputs import (
25-
CreateModelInput,
26-
TrainingInput,
27-
TransformInput,
28-
)
24+
from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput, FileSystemInput
2925
from sagemaker.model import Model
3026
from sagemaker.processing import (
3127
ProcessingInput,
@@ -145,7 +141,7 @@ def __init__(
145141
self,
146142
name: str,
147143
estimator: EstimatorBase,
148-
inputs: TrainingInput = None,
144+
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
149145
cache_config: CacheConfig = None,
150146
depends_on: List[str] = None,
151147
):
@@ -157,7 +153,23 @@ def __init__(
157153
Args:
158154
name (str): The name of the training step.
159155
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
160-
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
156+
inputs (str or dict or sagemaker.inputs.TrainingInput
157+
or sagemaker.inputs.FileSystemInput): Information
158+
about the training data. This can be one of three types:
159+
160+
* (str) the S3 location where training data is saved, or a file:// path in
161+
local mode.
162+
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) If using multiple
163+
channels for training data, you can specify a dict mapping channel names to
164+
strings or :func:`~sagemaker.inputs.TrainingInput` objects.
165+
* (sagemaker.inputs.TrainingInput) - channel configuration for S3 data sources
166+
that can provide additional information as well as the path to the training
167+
dataset.
168+
See :func:`sagemaker.inputs.TrainingInput` for full details.
169+
* (sagemaker.inputs.FileSystemInput) - channel configuration for
170+
a file system data source that can provide additional information as well as
171+
the path to the training dataset.
172+
161173
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
162174
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
163175
depends on

0 commit comments

Comments
 (0)