Commit 5255a75

reformatting

1 parent b84a818

2 files changed: 146 additions & 52 deletions

src/sagemaker/spark/processing.py

Lines changed: 10 additions & 4 deletions

@@ -500,13 +500,19 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
                 destination=f"{self._conf_container_base_path}{input_channel_name}",
                 input_name=input_channel_name,
             )
-            spark_opt = Join(on=",", values=spark_opt_s3_uris + [input_channel.destination]) \
-                if spark_opt_s3_uris_has_pipeline_var else ",".join(spark_opt_s3_uris + [input_channel.destination])
+            spark_opt = (
+                Join(on=",", values=spark_opt_s3_uris + [input_channel.destination])
+                if spark_opt_s3_uris_has_pipeline_var
+                else ",".join(spark_opt_s3_uris + [input_channel.destination])
+            )
         # If no local files were uploaded, form the spark-submit option from a list of S3 URIs
         else:
             input_channel = None
-            spark_opt = Join(on=",", values=spark_opt_s3_uris) if spark_opt_s3_uris_has_pipeline_var \
+            spark_opt = (
+                Join(on=",", values=spark_opt_s3_uris)
+                if spark_opt_s3_uris_has_pipeline_var
                 else ",".join(spark_opt_s3_uris)
+            )

         return input_channel, spark_opt

@@ -1081,7 +1087,7 @@ def run(
         submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
         inputs: Optional[List[ProcessingInput]] = None,
         outputs: Optional[List[ProcessingOutput]] = None,
-        arguments:Optional[List[Union[str, PipelineVariable]]] = None,
+        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
         wait: bool = True,
         logs: bool = True,
         job_name: Optional[str] = None,
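
The conditional that the first hunk reformats is worth a gloss: a pipeline variable such as a ParameterString has no concrete value at pipeline-definition time, so Python's ",".join() cannot consume it, while sagemaker.workflow.functions.Join emits a deferred join that SageMaker resolves at execution. Below is a minimal sketch of the same branch outside the SDK internals; the helper name build_spark_opt is hypothetical, and the imports assume a recent sagemaker SDK.

from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.functions import Join
from sagemaker.workflow.parameters import ParameterString


def build_spark_opt(uris):
    # Hypothetical helper mirroring the reformatted branch in _stage_submit_deps:
    # defer the comma-join when any URI is a pipeline variable, join eagerly otherwise.
    if any(isinstance(u, PipelineVariable) for u in uris):
        return Join(on=",", values=uris)
    return ",".join(uris)


print(build_spark_opt(["s3://my-jar", "s3://her-jar"]))
# -> "s3://my-jar,s3://her-jar"
print(build_spark_opt(["s3://my-jar", ParameterString("MyJars")]).expr)
# -> {"Std:Join": {"On": ",", "Values": ["s3://my-jar", {"Get": "Parameters.MyJars"}]}}

The second print matches the {"Std:Join": ...} structures asserted in the test file below.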

tests/unit/sagemaker/workflow/test_processing_step.py

Lines changed: 136 additions & 48 deletions

@@ -23,8 +23,6 @@
 from sagemaker.transformer import Transformer
 from sagemaker.tuner import HyperparameterTuner
 from sagemaker.workflow.pipeline_context import PipelineSession
-from sagemaker.workflow.functions import Join
-from sagemaker.workflow.parameters import ParameterString

 from sagemaker.processing import (
     Processor,

@@ -603,8 +601,10 @@ def test_insert_wrong_step_args_into_processing_step(inputs, pipeline_session):
 )


-@pytest.mark.parametrize("spark_processor", [
-    (
+@pytest.mark.parametrize(
+    "spark_processor",
+    [
+        (
             SparkJarProcessor(
                 role=ROLE,
                 framework_version="2.4",

@@ -614,14 +614,28 @@ def test_insert_wrong_step_args_into_processing_step(inputs, pipeline_session):
             {
                 "submit_app": "s3://my-jar",
                 "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-                "arguments": ["--input", "input-data-uri", "--output", ParameterString("MyArgOutput")],
-                "submit_jars": ["s3://my-jar", ParameterString("MyJars"), "s3://her-jar", ParameterString("OurJar")],
-                "submit_files": ["s3://my-files", ParameterString("MyFiles"), "s3://her-files",
-                                 ParameterString("OurFiles")],
-                "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri")
+                "arguments": [
+                    "--input",
+                    "input-data-uri",
+                    "--output",
+                    ParameterString("MyArgOutput"),
+                ],
+                "submit_jars": [
+                    "s3://my-jar",
+                    ParameterString("MyJars"),
+                    "s3://her-jar",
+                    ParameterString("OurJar"),
+                ],
+                "submit_files": [
+                    "s3://my-files",
+                    ParameterString("MyFiles"),
+                    "s3://her-files",
+                    ParameterString("OurFiles"),
+                ],
+                "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
             },
-    ),
-    (
+        ),
+        (
             PySparkProcessor(
                 role=ROLE,
                 framework_version="2.4",

@@ -630,16 +644,35 @@ def test_insert_wrong_step_args_into_processing_step(inputs, pipeline_session):
             ),
             {
                 "submit_app": "s3://my-jar",
-                "arguments": ["--input", "input-data-uri", "--output", ParameterString("MyArgOutput")],
-                "submit_py_files": ["s3://my-py-files", ParameterString("MyPyFiles"), "s3://her-pyfiles",
-                                    ParameterString("OurPyFiles")],
-                "submit_jars": ["s3://my-jar", ParameterString("MyJars"), "s3://her-jar", ParameterString("OurJar")],
-                "submit_files": ["s3://my-files", ParameterString("MyFiles"), "s3://her-files",
-                                 ParameterString("OurFiles")],
+                "arguments": [
+                    "--input",
+                    "input-data-uri",
+                    "--output",
+                    ParameterString("MyArgOutput"),
+                ],
+                "submit_py_files": [
+                    "s3://my-py-files",
+                    ParameterString("MyPyFiles"),
+                    "s3://her-pyfiles",
+                    ParameterString("OurPyFiles"),
+                ],
+                "submit_jars": [
+                    "s3://my-jar",
+                    ParameterString("MyJars"),
+                    "s3://her-jar",
+                    ParameterString("OurJar"),
+                ],
+                "submit_files": [
+                    "s3://my-files",
+                    ParameterString("MyFiles"),
+                    "s3://her-files",
+                    ParameterString("OurFiles"),
+                ],
                 "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
             },
-    ),
-])
+        ),
+    ],
+)
 def test_spark_processor(spark_processor, processing_input, pipeline_session):

     processor, run_inputs = spark_processor

@@ -656,52 +689,108 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):

     step_args = step_args.args

-    assert step_args['AppSpecification']['ContainerArguments'] == run_inputs['arguments']
+    assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]

-    entry_points = step_args['AppSpecification']['ContainerEntrypoint']
+    entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
     entry_points_expr = []
     for entry_point in entry_points:
         if is_pipeline_variable(entry_point):
             entry_points_expr.append(entry_point.expr)
         else:
             entry_points_expr.append(entry_point)

-    if 'submit_py_files' in run_inputs:
+    if "submit_py_files" in run_inputs:
         expected = [
-            'smspark-submit',
-            '--py-files', {'Std:Join': {'On': ',', 'Values': ['s3://my-py-files', {'Get': 'Parameters.MyPyFiles'},
-                                                              's3://her-pyfiles', {'Get': 'Parameters.OurPyFiles'}]}},
-            '--jars', {'Std:Join': {'On': ',', 'Values': ['s3://my-jar', {'Get': 'Parameters.MyJars'}, 's3://her-jar', {'Get': 'Parameters.OurJar'}]}},
-            '--files', {'Std:Join': {'On': ',', 'Values': ['s3://my-files', {'Get': 'Parameters.MyFiles'}, 's3://her-files', {'Get': 'Parameters.OurFiles'}]}},
-            '--local-spark-event-logs-dir', '/opt/ml/processing/spark-events/', '/opt/ml/processing/input/code'
+            "smspark-submit",
+            "--py-files",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-py-files",
+                        {"Get": "Parameters.MyPyFiles"},
+                        "s3://her-pyfiles",
+                        {"Get": "Parameters.OurPyFiles"},
+                    ],
+                }
+            },
+            "--jars",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-jar",
+                        {"Get": "Parameters.MyJars"},
+                        "s3://her-jar",
+                        {"Get": "Parameters.OurJar"},
+                    ],
+                }
+            },
+            "--files",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-files",
+                        {"Get": "Parameters.MyFiles"},
+                        "s3://her-files",
+                        {"Get": "Parameters.OurFiles"},
+                    ],
+                }
+            },
+            "--local-spark-event-logs-dir",
+            "/opt/ml/processing/spark-events/",
+            "/opt/ml/processing/input/code",
         ]
     # py spark
     else:
         expected = [
-            'smspark-submit',
-            '--class',
-            'com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp',
-            '--jars', {'Std:Join': {'On': ',', 'Values': ['s3://my-jar', {'Get': 'Parameters.MyJars'},
-                                                          's3://her-jar', {'Get': 'Parameters.OurJar'}]}},
-            '--files', {'Std:Join': {'On': ',', 'Values': ['s3://my-files', {'Get': 'Parameters.MyFiles'},
-                                                           's3://her-files', {'Get': 'Parameters.OurFiles'}]}},
-            '--local-spark-event-logs-dir', '/opt/ml/processing/spark-events/', '/opt/ml/processing/input/code'
+            "smspark-submit",
+            "--class",
+            "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+            "--jars",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-jar",
+                        {"Get": "Parameters.MyJars"},
+                        "s3://her-jar",
+                        {"Get": "Parameters.OurJar"},
+                    ],
+                }
+            },
+            "--files",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-files",
+                        {"Get": "Parameters.MyFiles"},
+                        "s3://her-files",
+                        {"Get": "Parameters.OurFiles"},
+                    ],
+                }
+            },
+            "--local-spark-event-logs-dir",
+            "/opt/ml/processing/spark-events/",
+            "/opt/ml/processing/input/code",
         ]

     assert entry_points_expr == expected
-    for output in step_args['ProcessingOutputConfig']['Outputs']:
-        if is_pipeline_variable(output['S3Output']['S3Uri']):
-            output['S3Output']['S3Uri'] = output['S3Output']['S3Uri'].expr
+    for output in step_args["ProcessingOutputConfig"]["Outputs"]:
+        if is_pipeline_variable(output["S3Output"]["S3Uri"]):
+            output["S3Output"]["S3Uri"] = output["S3Output"]["S3Uri"].expr

-    assert step_args['ProcessingOutputConfig']['Outputs'] == [
+    assert step_args["ProcessingOutputConfig"]["Outputs"] == [
         {
-            'OutputName': 'output-1',
-            'AppManaged': False,
-            'S3Output': {
-                'S3Uri': {'Get': 'Parameters.MySparkEventLogS3Uri'},
-                'LocalPath': '/opt/ml/processing/spark-events/',
-                'S3UploadMode': 'Continuous'
-            }
+            "OutputName": "output-1",
+            "AppManaged": False,
+            "S3Output": {
+                "S3Uri": {"Get": "Parameters.MySparkEventLogS3Uri"},
+                "LocalPath": "/opt/ml/processing/spark-events/",
+                "S3UploadMode": "Continuous",
+            },
         }
     ]

@@ -711,4 +800,3 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
         sagemaker_session=pipeline_session,
     )
     pipeline.definition()
-
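
For orientation, the pattern this test exercises end to end looks roughly as follows. This is a minimal sketch assuming a recent sagemaker SDK; the pipeline and step names ("MyPipeline", "MySparkStep") and the role ARN are placeholders, not values from the commit.

from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import ProcessingStep

pipeline_session = PipelineSession()

processor = PySparkProcessor(
    role="arn:aws:iam::123456789012:role/MyRole",  # placeholder role ARN
    framework_version="2.4",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    sagemaker_session=pipeline_session,
)

my_jars = ParameterString("MyJars")
event_logs_uri = ParameterString("MySparkEventLogS3Uri")

# Under a PipelineSession, run() captures the job arguments instead of
# starting a job; the mix of plain strings and parameters flows through
# the Join branch patched in processing.py above.
step_args = processor.run(
    submit_app="s3://my-jar",
    submit_jars=["s3://my-jar", my_jars],
    spark_event_logs_s3_uri=event_logs_uri,
)

step = ProcessingStep(name="MySparkStep", step_args=step_args)
pipeline = Pipeline(
    name="MyPipeline",
    parameters=[my_jars, event_logs_uri],
    steps=[step],
    sagemaker_session=pipeline_session,
)
print(pipeline.definition())  # Join/Get expressions appear in the JSON definition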
