16
16
import abc
17
17
18
18
from enum import Enum
19
- from typing import Dict , List
19
+ from typing import Dict , List , Union
20
20
21
21
import attr
22
22
23
23
from sagemaker .estimator import EstimatorBase , _TrainingJob
24
- from sagemaker .inputs import (
25
- CreateModelInput ,
26
- TrainingInput ,
27
- TransformInput ,
28
- )
24
+ from sagemaker .inputs import CreateModelInput , TrainingInput , TransformInput , FileSystemInput
29
25
from sagemaker .model import Model
30
26
from sagemaker .processing import (
31
27
ProcessingInput ,
@@ -145,7 +141,7 @@ def __init__(
145
141
self ,
146
142
name : str ,
147
143
estimator : EstimatorBase ,
148
- inputs : TrainingInput = None ,
144
+ inputs : Union [ TrainingInput , dict , str , FileSystemInput ] = None ,
149
145
cache_config : CacheConfig = None ,
150
146
depends_on : List [str ] = None ,
151
147
):
@@ -157,7 +153,23 @@ def __init__(
157
153
Args:
158
154
name (str): The name of the training step.
159
155
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
160
- inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
156
+ inputs (str or dict or sagemaker.inputs.TrainingInput
157
+ or sagemaker.inputs.FileSystemInput): Information
158
+ about the training data. This can be one of three types:
159
+
160
+ * (str) the S3 location where training data is saved, or a file:// path in
161
+ local mode.
162
+ * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) If using multiple
163
+ channels for training data, you can specify a dict mapping channel names to
164
+ strings or :func:`~sagemaker.inputs.TrainingInput` objects.
165
+ * (sagemaker.inputs.TrainingInput) - channel configuration for S3 data sources
166
+ that can provide additional information as well as the path to the training
167
+ dataset.
168
+ See :func:`sagemaker.inputs.TrainingInput` for full details.
169
+ * (sagemaker.inputs.FileSystemInput) - channel configuration for
170
+ a file system data source that can provide additional information as well as
171
+ the path to the training dataset.
172
+
161
173
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
162
174
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
163
175
depends on
0 commit comments