Skip to content

documentation: Correct type annotation for TrainingStep inputs #2468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 18, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions src/sagemaker/workflow/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@
import abc

from enum import Enum
from typing import Dict, List
from typing import Dict, List, Union

import attr

from sagemaker.estimator import EstimatorBase, _TrainingJob
from sagemaker.inputs import (
CreateModelInput,
TrainingInput,
TransformInput,
)
from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput, FileSystemInput
from sagemaker.model import Model
from sagemaker.processing import (
ProcessingInput,
Expand Down Expand Up @@ -145,7 +141,7 @@ def __init__(
self,
name: str,
estimator: EstimatorBase,
inputs: TrainingInput = None,
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
cache_config: CacheConfig = None,
depends_on: List[str] = None,
):
Expand All @@ -157,7 +153,23 @@ def __init__(
Args:
name (str): The name of the training step.
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
inputs (str or dict or sagemaker.inputs.TrainingInput
or sagemaker.inputs.FileSystemInput): Information
about the training data. This can be one of three types:

* (str) the S3 location where training data is saved, or a file:// path in
local mode.
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) If using multiple
channels for training data, you can specify a dict mapping channel names to
strings or :func:`~sagemaker.inputs.TrainingInput` objects.
* (sagemaker.inputs.TrainingInput) - channel configuration for S3 data sources
that can provide additional information as well as the path to the training
dataset.
See :func:`sagemaker.inputs.TrainingInput` for full details.
* (sagemaker.inputs.FileSystemInput) - channel configuration for
a file system data source that can provide additional information as well as
the path to the training dataset.

cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
depends on
Expand Down