Skip to content

Commit 30d7e41

Browse files
authored
Merge branch 'master' into fix-na
2 parents 6dcb3bf + d28d478 commit 30d7e41

File tree

2 files changed

+71
-138
lines changed

2 files changed

+71
-138
lines changed

tests/integ/test_smdataparallel_tf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS,
3232
reason="Only allow this test to run in IAD and CMH to limit usage of p3.16xlarge",
3333
)
34+
@pytest.mark.skip("Failing due to bad DLC image release. Disable temporarily.")
3435
def test_smdataparallel_tf_mnist(
3536
sagemaker_session,
3637
tensorflow_training_latest_version,

tests/unit/sagemaker/lineage/test_visualizer.py

Lines changed: 70 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -49,64 +49,17 @@ def test_trial_component_name(viz, sagemaker_session):
4949
"TrialComponentArn": "tc-arn",
5050
}
5151

52-
sagemaker_session.sagemaker_client.list_associations.side_effect = [
53-
{
54-
"AssociationSummaries": [
55-
{
56-
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
57-
"SourceName": "source-name-1",
58-
"SourceType": "source-type-1",
59-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
60-
"DestinationName": "dest-name-1",
61-
"DestinationType": "dest-type-1",
62-
"AssociationType": "type-1",
63-
}
64-
]
65-
},
66-
{
67-
"AssociationSummaries": [
68-
{
69-
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
70-
"SourceName": "source-name-2",
71-
"SourceType": "source-type-2",
72-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
73-
"DestinationName": "dest-name-2",
74-
"DestinationType": "dest-type-2",
75-
"AssociationType": "type-2",
76-
}
77-
]
78-
},
79-
]
52+
get_list_associations_side_effect(sagemaker_session)
8053

8154
df = viz.show(trial_component_name=name)
8255

8356
sagemaker_session.sagemaker_client.describe_trial_component.assert_called_with(
8457
TrialComponentName=name,
8558
)
8659

87-
expected_calls = [
88-
unittest.mock.call(
89-
DestinationArn="tc-arn",
90-
),
91-
unittest.mock.call(
92-
SourceArn="tc-arn",
93-
),
94-
]
95-
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
60+
assert_list_associations_mock_calls(sagemaker_session)
9661

97-
expected_dataframe = pd.DataFrame.from_dict(
98-
OrderedDict(
99-
[
100-
("Name/Source", ["source-name-1", "dest-name-2"]),
101-
("Direction", ["Input", "Output"]),
102-
("Type", ["source-type-1", "dest-type-2"]),
103-
("Association Type", ["type-1", "type-2"]),
104-
("Lineage Type", ["artifact", "artifact"]),
105-
]
106-
)
107-
)
108-
109-
pd.testing.assert_frame_equal(expected_dataframe, df)
62+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
11063

11164

11265
def test_model_package_arn(viz, sagemaker_session):
@@ -116,34 +69,7 @@ def test_model_package_arn(viz, sagemaker_session):
11669
"ArtifactSummaries": [{"ArtifactArn": "artifact-arn"}]
11770
}
11871

119-
sagemaker_session.sagemaker_client.list_associations.side_effect = [
120-
{
121-
"AssociationSummaries": [
122-
{
123-
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
124-
"SourceName": "source-name-1",
125-
"SourceType": "source-type-1",
126-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
127-
"DestinationName": "dest-name-1",
128-
"DestinationType": "dest-type-1",
129-
"AssociationType": "type-1",
130-
}
131-
]
132-
},
133-
{
134-
"AssociationSummaries": [
135-
{
136-
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
137-
"SourceName": "source-name-2",
138-
"SourceType": "source-type-2",
139-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
140-
"DestinationName": "dest-name-2",
141-
"DestinationType": "dest-type-2",
142-
"AssociationType": "type-2",
143-
}
144-
]
145-
},
146-
]
72+
get_list_associations_side_effect(sagemaker_session)
14773

14874
df = viz.show(model_package_arn=name)
14975

@@ -161,19 +87,7 @@ def test_model_package_arn(viz, sagemaker_session):
16187
]
16288
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
16389

164-
expected_dataframe = pd.DataFrame.from_dict(
165-
OrderedDict(
166-
[
167-
("Name/Source", ["source-name-1", "dest-name-2"]),
168-
("Direction", ["Input", "Output"]),
169-
("Type", ["source-type-1", "dest-type-2"]),
170-
("Association Type", ["type-1", "type-2"]),
171-
("Lineage Type", ["artifact", "artifact"]),
172-
]
173-
)
174-
)
175-
176-
pd.testing.assert_frame_equal(expected_dataframe, df)
90+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
17791

17892

17993
def test_endpoint_arn(viz, sagemaker_session):
@@ -183,34 +97,7 @@ def test_endpoint_arn(viz, sagemaker_session):
18397
"ContextSummaries": [{"ContextArn": "context-arn"}]
18498
}
18599

186-
sagemaker_session.sagemaker_client.list_associations.side_effect = [
187-
{
188-
"AssociationSummaries": [
189-
{
190-
"SourceArn": "a:b:c:d:e:context/src-arn-1",
191-
"SourceName": "source-name-1",
192-
"SourceType": "source-type-1",
193-
"DestinationArn": "a:b:c:d:e:context/dest-arn-1",
194-
"DestinationName": "dest-name-1",
195-
"DestinationType": "dest-type-1",
196-
"AssociationType": "type-1",
197-
}
198-
]
199-
},
200-
{
201-
"AssociationSummaries": [
202-
{
203-
"SourceArn": "a:b:c:d:e:context/src-arn-2",
204-
"SourceName": "source-name-2",
205-
"SourceType": "source-type-2",
206-
"DestinationArn": "a:b:c:d:e:context/dest-arn-2",
207-
"DestinationName": "dest-name-2",
208-
"DestinationType": "dest-type-2",
209-
"AssociationType": "type-2",
210-
}
211-
]
212-
},
213-
]
100+
get_list_associations_side_effect(sagemaker_session)
214101

215102
df = viz.show(endpoint_arn=name)
216103

@@ -228,27 +115,74 @@ def test_endpoint_arn(viz, sagemaker_session):
228115
]
229116
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
230117

231-
expected_dataframe = pd.DataFrame.from_dict(
232-
OrderedDict(
233-
[
234-
("Name/Source", ["source-name-1", "dest-name-2"]),
235-
("Direction", ["Input", "Output"]),
236-
("Type", ["source-type-1", "dest-type-2"]),
237-
("Association Type", ["type-1", "type-2"]),
238-
("Lineage Type", ["context", "context"]),
239-
]
240-
)
118+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
119+
120+
121+
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
122+
123+
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
124+
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
125+
}
126+
127+
get_list_associations_side_effect(sagemaker_session)
128+
129+
step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
130+
131+
df = viz.show(pipeline_execution_step=step)
132+
133+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
134+
SourceArn="proc-job-arn",
241135
)
242136

243-
pd.testing.assert_frame_equal(expected_dataframe, df)
137+
assert_list_associations_mock_calls(sagemaker_session)
244138

139+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
245140

246-
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
141+
142+
def test_training_job_pipeline_execution_step(viz, sagemaker_session):
247143

248144
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
249145
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
250146
}
251147

148+
get_list_associations_side_effect(sagemaker_session)
149+
150+
step = {"Metadata": {"TrainingJob": {"Arn": "training-job-arn"}}}
151+
152+
df = viz.show(pipeline_execution_step=step)
153+
154+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
155+
SourceArn="training-job-arn",
156+
)
157+
158+
assert_list_associations_mock_calls(sagemaker_session)
159+
160+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
161+
162+
163+
def test_transform_job_pipeline_execution_step(viz, sagemaker_session):
164+
165+
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
166+
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
167+
}
168+
169+
get_list_associations_side_effect(sagemaker_session)
170+
171+
step = {"Metadata": {"TransformJob": {"Arn": "transform-job-arn"}}}
172+
173+
df = viz.show(pipeline_execution_step=step)
174+
175+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
176+
SourceArn="transform-job-arn",
177+
)
178+
179+
assert_list_associations_mock_calls(sagemaker_session)
180+
181+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
182+
183+
184+
def get_list_associations_side_effect(sagemaker_session):
185+
252186
sagemaker_session.sagemaker_client.list_associations.side_effect = [
253187
{
254188
"AssociationSummaries": [
@@ -278,13 +212,8 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
278212
},
279213
]
280214

281-
step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
282-
283-
df = viz.show(pipeline_execution_step=step)
284215

285-
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
286-
SourceArn="proc-job-arn",
287-
)
216+
def assert_list_associations_mock_calls(sagemaker_session):
288217

289218
expected_calls = [
290219
unittest.mock.call(
@@ -296,6 +225,9 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
296225
]
297226
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
298227

228+
229+
def get_expected_dataframe():
230+
299231
expected_dataframe = pd.DataFrame.from_dict(
300232
OrderedDict(
301233
[
@@ -308,4 +240,4 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
308240
)
309241
)
310242

311-
pd.testing.assert_frame_equal(expected_dataframe, df)
243+
return expected_dataframe

0 commit comments

Comments
 (0)