-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: Allow custom output for RepackModelStep #2804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dde8d00
0f72907
0bae071
17fe93e
f0efd27
7a1f4f8
ee6afcf
faf4ad5
8210375
972a6d2
7206b9e
127c964
554d735
88e4d68
b3c19d8
fd7a335
ccfcbe7
71c5617
975e031
b377b52
9d259b3
b82fb8a
0489b59
ed9131b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -63,6 +63,7 @@ def __init__( | |||||||
estimator: EstimatorBase = None, | ||||||||
model_data=None, | ||||||||
depends_on: Union[List[str], List[Step]] = None, | ||||||||
repack_output_path=None, | ||||||||
repack_model_step_retry_policies: List[RetryPolicy] = None, | ||||||||
register_model_step_retry_policies: List[RetryPolicy] = None, | ||||||||
model_package_group_name=None, | ||||||||
|
@@ -92,6 +93,9 @@ def __init__( | |||||||
job can be run or on which an endpoint can be deployed (default: None). | ||||||||
depends_on (List[str] or List[Step]): The list of step names or step instances | ||||||||
the first step in the collection depends on | ||||||||
repack_output_path (str): The S3 prefix URI where the repacked model will be | ||||||||
uploaded (default: None) - don't include a trailing slash. | ||||||||
If not specified, the default location is s3://default-bucket/job-name. | ||||||||
repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies | ||||||||
for the repack model step | ||||||||
register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies | ||||||||
|
@@ -155,6 +159,7 @@ def __init__( | |||||||
security_group_ids=security_group_ids, | ||||||||
description=description, | ||||||||
display_name=display_name, | ||||||||
repack_output_path=repack_output_path, | ||||||||
**kwargs, | ||||||||
) | ||||||||
steps.append(repack_model_step) | ||||||||
|
@@ -199,6 +204,7 @@ def __init__( | |||||||
security_group_ids=security_group_ids, | ||||||||
description=description, | ||||||||
display_name=display_name, | ||||||||
repack_output_path=repack_output_path, | ||||||||
**kwargs, | ||||||||
) | ||||||||
steps.append(repack_model_step) | ||||||||
|
@@ -261,6 +267,7 @@ def __init__( | |||||||
image_uri=None, | ||||||||
predictor_cls=None, | ||||||||
env=None, | ||||||||
repack_output_path=None, | ||||||||
# transformer arguments | ||||||||
strategy=None, | ||||||||
assemble_with=None, | ||||||||
|
@@ -282,8 +289,8 @@ def __init__( | |||||||
|
||||||||
An estimator-centric step collection. It models what happens in workflows | ||||||||
when invoking the `transform()` method on an estimator instance: | ||||||||
First, if custom | ||||||||
model artifacts are required, a `_RepackModelStep` is included. | ||||||||
First, if a custom | ||||||||
entry point script is required, a `_RepackModelStep` is included. | ||||||||
Second, a | ||||||||
`CreateModelStep` with the model data passed in from a training step or other | ||||||||
training job output. | ||||||||
|
@@ -312,6 +319,9 @@ def __init__( | |||||||
it will be the format of the batch transform output. | ||||||||
env (dict): The Environment variables to be set for use during the | ||||||||
transform job (default: None). | ||||||||
repack_output_path (str): The S3 prefix URI where the repacked model will be | ||||||||
uploaded (default: None) - don't include a trailing slash. | ||||||||
If not specified, the default location is s3://default-bucket/job-name. | ||||||||
depends_on (List[str] or List[Step]): The list of step names or step instances | ||||||||
the first step in the collection depends on | ||||||||
repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies | ||||||||
|
@@ -322,10 +332,13 @@ def __init__( | |||||||
transform step | ||||||||
""" | ||||||||
steps = [] | ||||||||
repack_model = False | ||||||||
|
||||||||
if "entry_point" in kwargs: | ||||||||
entry_point = kwargs.get("entry_point", None) | ||||||||
source_dir = kwargs.get("source_dir", None) | ||||||||
dependencies = kwargs.get("dependencies", None) | ||||||||
repack_model = True | ||||||||
entry_point = kwargs.pop("entry_point", None) | ||||||||
source_dir = kwargs.pop("source_dir", None) | ||||||||
dependencies = kwargs.pop("dependencies", None) | ||||||||
repack_model_step = _RepackModelStep( | ||||||||
name=f"{name}RepackModel", | ||||||||
depends_on=depends_on, | ||||||||
|
@@ -341,6 +354,8 @@ def __init__( | |||||||
security_group_ids=estimator.security_group_ids, | ||||||||
description=description, | ||||||||
display_name=display_name, | ||||||||
repack_output_path=repack_output_path, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a
Suggested change
|
||||||||
**kwargs, | ||||||||
) | ||||||||
steps.append(repack_model_step) | ||||||||
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts | ||||||||
|
@@ -371,7 +386,7 @@ def predict_wrapper(endpoint, session): | |||||||
display_name=display_name, | ||||||||
retry_policies=model_step_retry_policies, | ||||||||
) | ||||||||
if "entry_point" not in kwargs and depends_on: | ||||||||
if not repack_model and depends_on: | ||||||||
# if the CreateModelStep is the first step in the collection | ||||||||
model_step.add_depends_on(depends_on) | ||||||||
steps.append(model_step) | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.