17
17
import shutil
18
18
import tarfile
19
19
import tempfile
20
- from typing import List , Union , Optional
20
+ from typing import List , Union , Optional , TYPE_CHECKING
21
21
from sagemaker import image_uris
22
22
from sagemaker .inputs import TrainingInput
23
23
from sagemaker .estimator import EstimatorBase
34
34
from sagemaker .utils import _save_model , download_file_from_url
35
35
from sagemaker .workflow .retry import RetryPolicy
36
36
37
+ if TYPE_CHECKING :
38
+ from sagemaker .workflow .step_collections import StepCollection
39
+
37
40
FRAMEWORK_VERSION = "0.23-1"
38
41
INSTANCE_TYPE = "ml.m5.large"
39
42
REPACK_SCRIPT = "_repack_model.py"
@@ -57,7 +60,7 @@ def __init__(
57
60
description : str = None ,
58
61
source_dir : str = None ,
59
62
dependencies : List = None ,
60
- depends_on : Union [List [str ], List [ Step ]] = None ,
63
+ depends_on : Optional [List [Union [ str , Step , "StepCollection" ] ]] = None ,
61
64
retry_policies : List [RetryPolicy ] = None ,
62
65
subnets = None ,
63
66
security_group_ids = None ,
@@ -124,8 +127,9 @@ def __init__(
124
127
>>> |------ virtual-env
125
128
126
129
This is not supported with "local code" in Local Mode.
127
- depends_on (List[str] or List[Step]): A list of step names or instances
128
- this step depends on (default: None).
130
+ depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection`
131
+ names or `Step` instances or `StepCollection` instances that the current `Step`
132
+ depends on (default: None).
129
133
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
130
134
(default: None).
131
135
subnets (list[str]): List of subnet ids. If not specified, the re-packing
@@ -274,7 +278,7 @@ def __init__(
274
278
compile_model_family = None ,
275
279
display_name : str = None ,
276
280
description = None ,
277
- depends_on : Optional [Union [ List [str ], List [ Step ]]] = None ,
281
+ depends_on : Optional [List [Union [ str , Step , "StepCollection" ]]] = None ,
278
282
retry_policies : Optional [List [RetryPolicy ]] = None ,
279
283
tags = None ,
280
284
container_def_list = None ,
@@ -311,8 +315,9 @@ def __init__(
311
315
if specified, a compiled model will be used (default: None).
312
316
display_name (str): The display name of this `_RegisterModelStep` step (default: None).
313
317
description (str): Model Package description (default: None).
314
- depends_on (List[str] or List[Step]): A list of step names or instances
315
- this step depends on (default: None).
318
+ depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection`
319
+ names or `Step` instances or `StepCollection` instances that the current `Step`
320
+ depends on (default: None).
316
321
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
317
322
(default: None).
318
323
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs used to
0 commit comments