Skip to content

Commit 1a5a448

Browse files
authored
Merge branch 'master' into feat/enhance-bucket-override-support
2 parents fc1f55b + a11e299 commit 1a5a448

File tree

9 files changed

+513
-182
lines changed

9 files changed

+513
-182
lines changed

src/sagemaker/spark/processing.py

Lines changed: 91 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,20 @@
3131
from io import BytesIO
3232
from urllib.parse import urlparse
3333

34+
from typing import Union, List, Dict, Optional
35+
3436
from sagemaker import image_uris
3537
from sagemaker.local.image import _ecr_login_if_needed, _pull_image
3638
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
3739
from sagemaker.s3 import S3Uploader
3840
from sagemaker.session import Session
41+
from sagemaker.network import NetworkConfig
3942
from sagemaker.spark import defaults
4043

44+
from sagemaker.workflow import is_pipeline_variable
45+
from sagemaker.workflow.entities import PipelineVariable
46+
from sagemaker.workflow.functions import Join
47+
4148
logger = logging.getLogger(__name__)
4249

4350

@@ -249,6 +256,12 @@ def run(
249256
"""
250257
self._current_job_name = self._generate_current_job_name(job_name=job_name)
251258

259+
if is_pipeline_variable(submit_app):
260+
raise ValueError(
261+
"submit_app argument has to be a valid S3 URI or local file path "
262+
+ "rather than a pipeline variable"
263+
)
264+
252265
return super().run(
253266
submit_app,
254267
inputs,
@@ -437,9 +450,14 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
437450

438451
use_input_channel = False
439452
spark_opt_s3_uris = []
453+
spark_opt_s3_uris_has_pipeline_var = False
440454

441455
with tempfile.TemporaryDirectory() as tmpdir:
442456
for dep_path in submit_deps:
457+
if is_pipeline_variable(dep_path):
458+
spark_opt_s3_uris.append(dep_path)
459+
spark_opt_s3_uris_has_pipeline_var = True
460+
continue
443461
dep_url = urlparse(dep_path)
444462
# S3 URIs are included as-is in the spark-submit argument
445463
if dep_url.scheme in ["s3", "s3a"]:
@@ -482,11 +500,19 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
482500
destination=f"{self._conf_container_base_path}{input_channel_name}",
483501
input_name=input_channel_name,
484502
)
485-
spark_opt = ",".join(spark_opt_s3_uris + [input_channel.destination])
503+
spark_opt = (
504+
Join(on=",", values=spark_opt_s3_uris + [input_channel.destination])
505+
if spark_opt_s3_uris_has_pipeline_var
506+
else ",".join(spark_opt_s3_uris + [input_channel.destination])
507+
)
486508
# If no local files were uploaded, form the spark-submit option from a list of S3 URIs
487509
else:
488510
input_channel = None
489-
spark_opt = ",".join(spark_opt_s3_uris)
511+
spark_opt = (
512+
Join(on=",", values=spark_opt_s3_uris)
513+
if spark_opt_s3_uris_has_pipeline_var
514+
else ",".join(spark_opt_s3_uris)
515+
)
490516

491517
return input_channel, spark_opt
492518

@@ -592,6 +618,9 @@ def _validate_s3_uri(self, spark_output_s3_path):
592618
Args:
593619
spark_output_s3_path (str): The URI of the Spark output S3 Path.
594620
"""
621+
if is_pipeline_variable(spark_output_s3_path):
622+
return
623+
595624
if urlparse(spark_output_s3_path).scheme != "s3":
596625
raise ValueError(
597626
f"Invalid s3 path: {spark_output_s3_path}. Please enter something like "
@@ -650,22 +679,22 @@ class PySparkProcessor(_SparkProcessorBase):
650679

651680
def __init__(
652681
self,
653-
role,
654-
instance_type,
655-
instance_count,
656-
framework_version=None,
657-
py_version=None,
658-
container_version=None,
659-
image_uri=None,
660-
volume_size_in_gb=30,
661-
volume_kms_key=None,
662-
output_kms_key=None,
663-
max_runtime_in_seconds=None,
664-
base_job_name=None,
665-
sagemaker_session=None,
666-
env=None,
667-
tags=None,
668-
network_config=None,
682+
role: str,
683+
instance_type: Union[str, PipelineVariable],
684+
instance_count: Union[int, PipelineVariable],
685+
framework_version: Optional[str] = None,
686+
py_version: Optional[str] = None,
687+
container_version: Optional[str] = None,
688+
image_uri: Optional[Union[str, PipelineVariable]] = None,
689+
volume_size_in_gb: Union[int, PipelineVariable] = 30,
690+
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
691+
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
692+
max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
693+
base_job_name: Optional[str] = None,
694+
sagemaker_session: Optional[Session] = None,
695+
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
696+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
697+
network_config: Optional[NetworkConfig] = None,
669698
):
670699
"""Initialize an ``PySparkProcessor`` instance.
671700
@@ -795,20 +824,20 @@ def get_run_args(
795824

796825
def run(
797826
self,
798-
submit_app,
799-
submit_py_files=None,
800-
submit_jars=None,
801-
submit_files=None,
802-
inputs=None,
803-
outputs=None,
804-
arguments=None,
805-
wait=True,
806-
logs=True,
807-
job_name=None,
808-
experiment_config=None,
809-
configuration=None,
810-
spark_event_logs_s3_uri=None,
811-
kms_key=None,
827+
submit_app: str,
828+
submit_py_files: Optional[List[Union[str, PipelineVariable]]] = None,
829+
submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
830+
submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
831+
inputs: Optional[List[ProcessingInput]] = None,
832+
outputs: Optional[List[ProcessingOutput]] = None,
833+
arguments: Optional[List[Union[str, PipelineVariable]]] = None,
834+
wait: bool = True,
835+
logs: bool = True,
836+
job_name: Optional[str] = None,
837+
experiment_config: Optional[Dict[str, str]] = None,
838+
configuration: Optional[Union[List[Dict], Dict]] = None,
839+
spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
840+
kms_key: Optional[str] = None,
812841
):
813842
"""Runs a processing job.
814843
@@ -907,22 +936,22 @@ class SparkJarProcessor(_SparkProcessorBase):
907936

908937
def __init__(
909938
self,
910-
role,
911-
instance_type,
912-
instance_count,
913-
framework_version=None,
914-
py_version=None,
915-
container_version=None,
916-
image_uri=None,
917-
volume_size_in_gb=30,
918-
volume_kms_key=None,
919-
output_kms_key=None,
920-
max_runtime_in_seconds=None,
921-
base_job_name=None,
922-
sagemaker_session=None,
923-
env=None,
924-
tags=None,
925-
network_config=None,
939+
role: str,
940+
instance_type: Union[str, PipelineVariable],
941+
instance_count: Union[int, PipelineVariable],
942+
framework_version: Optional[str] = None,
943+
py_version: Optional[str] = None,
944+
container_version: Optional[str] = None,
945+
image_uri: Optional[Union[str, PipelineVariable]] = None,
946+
volume_size_in_gb: Union[int, PipelineVariable] = 30,
947+
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
948+
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
949+
max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
950+
base_job_name: Optional[str] = None,
951+
sagemaker_session: Optional[Session] = None,
952+
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
953+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
954+
network_config: Optional[NetworkConfig] = None,
926955
):
927956
"""Initialize a ``SparkJarProcessor`` instance.
928957
@@ -1052,20 +1081,20 @@ def get_run_args(
10521081

10531082
def run(
10541083
self,
1055-
submit_app,
1056-
submit_class=None,
1057-
submit_jars=None,
1058-
submit_files=None,
1059-
inputs=None,
1060-
outputs=None,
1061-
arguments=None,
1062-
wait=True,
1063-
logs=True,
1064-
job_name=None,
1065-
experiment_config=None,
1066-
configuration=None,
1067-
spark_event_logs_s3_uri=None,
1068-
kms_key=None,
1084+
submit_app: str,
1085+
submit_class: Union[str, PipelineVariable],
1086+
submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
1087+
submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
1088+
inputs: Optional[List[ProcessingInput]] = None,
1089+
outputs: Optional[List[ProcessingOutput]] = None,
1090+
arguments: Optional[List[Union[str, PipelineVariable]]] = None,
1091+
wait: bool = True,
1092+
logs: bool = True,
1093+
job_name: Optional[str] = None,
1094+
experiment_config: Optional[Dict[str, str]] = None,
1095+
configuration: Optional[Union[List[Dict], Dict]] = None,
1096+
spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
1097+
kms_key: Optional[str] = None,
10691098
):
10701099
"""Runs a processing job.
10711100

src/sagemaker/workflow/condition_step.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,17 @@
1515

1616
from typing import List, Union, Optional
1717

18-
import attr
1918

2019
from sagemaker.deprecations import deprecated_class
2120
from sagemaker.workflow.conditions import Condition
2221
from sagemaker.workflow.step_collections import StepCollection
22+
from sagemaker.workflow.functions import JsonGet as NewJsonGet
2323
from sagemaker.workflow.steps import (
2424
Step,
2525
StepTypeEnum,
2626
)
2727
from sagemaker.workflow.utilities import list_to_request
28-
from sagemaker.workflow.entities import (
29-
RequestType,
30-
PipelineVariable,
31-
)
28+
from sagemaker.workflow.entities import RequestType
3229
from sagemaker.workflow.properties import (
3330
Properties,
3431
PropertyFile,
@@ -93,16 +90,15 @@ def arguments(self) -> RequestType:
9390
@property
9491
def step_only_arguments(self):
9592
"""Argument dict pertaining to the step only, and not the `if_steps` or `else_steps`."""
96-
return self.conditions
93+
return [condition.to_request() for condition in self.conditions]
9794

9895
@property
9996
def properties(self):
10097
"""A simple Properties object with `Outcome` as the only property"""
10198
return self._properties
10299

103100

104-
@attr.s
105-
class JsonGet(PipelineVariable): # pragma: no cover
101+
class JsonGet(NewJsonGet): # pragma: no cover
106102
"""Get JSON properties from PropertyFiles.
107103
108104
Attributes:
@@ -112,28 +108,8 @@ class JsonGet(PipelineVariable): # pragma: no cover
112108
json_path (str): The JSON path expression to the requested value.
113109
"""
114110

115-
step: Step = attr.ib()
116-
property_file: Union[PropertyFile, str] = attr.ib()
117-
json_path: str = attr.ib()
118-
119-
@property
120-
def expr(self):
121-
"""The expression dict for a `JsonGet` function."""
122-
if isinstance(self.property_file, PropertyFile):
123-
name = self.property_file.name
124-
else:
125-
name = self.property_file
126-
return {
127-
"Std:JsonGet": {
128-
"PropertyFile": {"Get": f"Steps.{self.step.name}.PropertyFiles.{name}"},
129-
"Path": self.json_path,
130-
}
131-
}
132-
133-
@property
134-
def _referenced_steps(self) -> List[str]:
135-
"""List of step names that this function depends on."""
136-
return [self.step.name]
111+
def __init__(self, step: Step, property_file: Union[PropertyFile, str], json_path: str):
112+
super().__init__(step_name=step.name, property_file=property_file, json_path=json_path)
137113

138114

139115
JsonGet = deprecated_class(JsonGet, "JsonGet")

src/sagemaker/workflow/conditions.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
import abc
2121

2222
from enum import Enum
23-
from typing import Dict, List, Union
23+
from typing import List, Union
2424

2525
import attr
2626

27-
from sagemaker.workflow import is_pipeline_variable
2827
from sagemaker.workflow.entities import (
2928
DefaultEnumMeta,
3029
Entity,
31-
Expression,
3230
PrimitiveType,
3331
RequestType,
3432
)
@@ -88,8 +86,8 @@ def to_request(self) -> RequestType:
8886
"""Get the request structure for workflow service calls."""
8987
return {
9088
"Type": self.condition_type.value,
91-
"LeftValue": primitive_or_expr(self.left),
92-
"RightValue": primitive_or_expr(self.right),
89+
"LeftValue": self.left,
90+
"RightValue": self.right,
9391
}
9492

9593
@property
@@ -227,8 +225,8 @@ def to_request(self) -> RequestType:
227225
"""Get the request structure for workflow service calls."""
228226
return {
229227
"Type": self.condition_type.value,
230-
"QueryValue": self.value.expr,
231-
"Values": [primitive_or_expr(in_value) for in_value in self.in_values],
228+
"QueryValue": self.value,
229+
"Values": self.in_values,
232230
}
233231

234232
@property
@@ -291,19 +289,3 @@ def _referenced_steps(self) -> List[str]:
291289
for condition in self.conditions:
292290
steps.extend(condition._referenced_steps)
293291
return steps
294-
295-
296-
def primitive_or_expr(
297-
value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties]
298-
) -> Union[Dict[str, str], PrimitiveType]:
299-
"""Provide the expression of the value or return value if it is a primitive.
300-
301-
Args:
302-
value (Union[ConditionValueType, PrimitiveType]): The value to evaluate.
303-
304-
Returns:
305-
Either the expression of the value or the primitive value.
306-
"""
307-
if is_pipeline_variable(value):
308-
return value.expr
309-
return value

src/sagemaker/workflow/step_collections.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def properties(self):
5757
class RegisterModel(StepCollection): # pragma: no cover
5858
"""Register Model step collection for workflow."""
5959

60+
_REGISTER_MODEL_NAME_BASE = "RegisterModel"
61+
_REPACK_MODEL_NAME_BASE = "RepackModel"
62+
6063
def __init__(
6164
self,
6265
name: str,
@@ -168,7 +171,7 @@ def __init__(
168171
kwargs = dict(**kwargs, output_kms_key=kwargs.pop("model_kms_key", None))
169172

170173
repack_model_step = _RepackModelStep(
171-
name=f"{name}RepackModel",
174+
name="{}-{}".format(self.name, self._REPACK_MODEL_NAME_BASE),
172175
depends_on=depends_on,
173176
retry_policies=repack_model_step_retry_policies,
174177
sagemaker_session=estimator.sagemaker_session,
@@ -212,7 +215,7 @@ def __init__(
212215
model_name = model_entity.name or model_entity._framework_name
213216

214217
repack_model_step = _RepackModelStep(
215-
name=f"{model_name}RepackModel",
218+
name="{}-{}".format(model_name, self._REPACK_MODEL_NAME_BASE),
216219
depends_on=depends_on,
217220
retry_policies=repack_model_step_retry_policies,
218221
sagemaker_session=sagemaker_session,
@@ -256,7 +259,7 @@ def __init__(
256259
)
257260

258261
register_model_step = _RegisterModelStep(
259-
name=name,
262+
name="{}-{}".format(self.name, self._REGISTER_MODEL_NAME_BASE),
260263
estimator=estimator,
261264
model_data=model_data,
262265
content_types=content_types,

0 commit comments

Comments
 (0)