Skip to content

Commit bd8f023

Browse files
stisacahsan-z-khan
andauthored
Add tests for visualizer to improve test coverage (#2161)
Co-authored-by: Ahsan Khan <[email protected]>
1 parent 536ba56 commit bd8f023

File tree

1 file changed

+134
-0
lines changed

1 file changed

+134
-0
lines changed

tests/unit/sagemaker/lineage/test_visualizer.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,137 @@ def test_trial_component_name(sagemaker_session, vizualizer):
119119
)
120120

121121
pd.testing.assert_frame_equal(expected_dataframe, df)
122+
123+
124+
def test_model_package_arn(sagemaker_session, vizualizer):
125+
name = "model_package_arn"
126+
127+
sagemaker_session.sagemaker_client.list_artifacts.return_value = {
128+
"ArtifactSummaries": [{"ArtifactArn": "artifact-arn"}]
129+
}
130+
131+
sagemaker_session.sagemaker_client.list_associations.side_effect = [
132+
{
133+
"AssociationSummaries": [
134+
{
135+
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
136+
"SourceName": "source-name-1",
137+
"SourceType": "source-type-1",
138+
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
139+
"DestinationName": "dest-name-1",
140+
"DestinationType": "dest-type-1",
141+
"AssociationType": "type-1",
142+
}
143+
]
144+
},
145+
{
146+
"AssociationSummaries": [
147+
{
148+
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
149+
"SourceName": "source-name-2",
150+
"SourceType": "source-type-2",
151+
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
152+
"DestinationName": "dest-name-2",
153+
"DestinationType": "dest-type-2",
154+
"AssociationType": "type-2",
155+
}
156+
]
157+
},
158+
]
159+
160+
df = vizualizer.show(model_package_arn=name)
161+
162+
sagemaker_session.sagemaker_client.list_artifacts.assert_called_with(
163+
SourceUri=name,
164+
)
165+
166+
expected_calls = [
167+
unittest.mock.call(
168+
DestinationArn="artifact-arn",
169+
),
170+
unittest.mock.call(
171+
SourceArn="artifact-arn",
172+
),
173+
]
174+
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
175+
176+
expected_dataframe = pd.DataFrame.from_dict(
177+
OrderedDict(
178+
[
179+
("Name/Source", ["source-name-1", "dest-name-2"]),
180+
("Direction", ["Input", "Output"]),
181+
("Type", ["source-type-1", "dest-type-2"]),
182+
("Association Type", ["type-1", "type-2"]),
183+
("Lineage Type", ["artifact", "artifact"]),
184+
]
185+
)
186+
)
187+
188+
pd.testing.assert_frame_equal(expected_dataframe, df)
189+
190+
191+
def test_endpoint_arn(sagemaker_session, vizualizer):
192+
name = "endpoint_arn"
193+
194+
sagemaker_session.sagemaker_client.list_contexts.return_value = {
195+
"ContextSummaries": [{"ContextArn": "context-arn"}]
196+
}
197+
198+
sagemaker_session.sagemaker_client.list_associations.side_effect = [
199+
{
200+
"AssociationSummaries": [
201+
{
202+
"SourceArn": "a:b:c:d:e:context/src-arn-1",
203+
"SourceName": "source-name-1",
204+
"SourceType": "source-type-1",
205+
"DestinationArn": "a:b:c:d:e:context/dest-arn-1",
206+
"DestinationName": "dest-name-1",
207+
"DestinationType": "dest-type-1",
208+
"AssociationType": "type-1",
209+
}
210+
]
211+
},
212+
{
213+
"AssociationSummaries": [
214+
{
215+
"SourceArn": "a:b:c:d:e:context/src-arn-2",
216+
"SourceName": "source-name-2",
217+
"SourceType": "source-type-2",
218+
"DestinationArn": "a:b:c:d:e:context/dest-arn-2",
219+
"DestinationName": "dest-name-2",
220+
"DestinationType": "dest-type-2",
221+
"AssociationType": "type-2",
222+
}
223+
]
224+
},
225+
]
226+
227+
df = vizualizer.show(endpoint_arn=name)
228+
229+
sagemaker_session.sagemaker_client.list_contexts.assert_called_with(
230+
SourceUri=name,
231+
)
232+
233+
expected_calls = [
234+
unittest.mock.call(
235+
DestinationArn="context-arn",
236+
),
237+
unittest.mock.call(
238+
SourceArn="context-arn",
239+
),
240+
]
241+
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
242+
243+
expected_dataframe = pd.DataFrame.from_dict(
244+
OrderedDict(
245+
[
246+
("Name/Source", ["source-name-1", "dest-name-2"]),
247+
("Direction", ["Input", "Output"]),
248+
("Type", ["source-type-1", "dest-type-2"]),
249+
("Association Type", ["type-1", "type-2"]),
250+
("Lineage Type", ["context", "context"]),
251+
]
252+
)
253+
)
254+
255+
pd.testing.assert_frame_equal(expected_dataframe, df)

0 commit comments

Comments
 (0)