16
16
import abc
17
17
18
18
from enum import Enum
19
- from typing import Dict , List
19
+ from typing import Dict , List , Union
20
20
21
21
import attr
22
22
@@ -145,7 +145,7 @@ def __init__(
145
145
self ,
146
146
name : str ,
147
147
estimator : EstimatorBase ,
148
- inputs : TrainingInput = None ,
148
+ inputs : Union [ TrainingInput , dict , str ] = None ,
149
149
cache_config : CacheConfig = None ,
150
150
depends_on : List [str ] = None ,
151
151
):
@@ -157,7 +157,22 @@ def __init__(
157
157
Args:
158
158
name (str): The name of the training step.
159
159
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
160
- inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
160
+ inputs (str or dict or sagemaker.inputs.TrainingInput): Information
161
+ about the training data. This can be one of three types:
162
+
163
+ * (str) the S3 location where training data is saved, or a file:// path in
164
+ local mode.
165
+ * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) If using multiple
166
+ channels for training data, you can specify a dict mapping channel names to
167
+ strings or :func:`~sagemaker.inputs.TrainingInput` objects.
168
+ * (sagemaker.inputs.TrainingInput) - channel configuration for S3 data sources
169
+ that can provide additional information as well as the path to the training
170
+ dataset.
171
+ See :func:`sagemaker.inputs.TrainingInput` for full details.
172
+ * (sagemaker.session.FileSystemInput) - channel configuration for
173
+ a file system data source that can provide additional information as well as
174
+ the path to the training dataset.
175
+
161
176
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
162
177
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
163
178
depends on
0 commit comments