|
18 | 18 | from sagemaker.lineage.lineage_trial_component import LineageTrialComponent
|
19 | 19 | from sagemaker.lineage.query import LineageEntityEnum, LineageSourceEnum, Vertex, LineageQuery
|
20 | 20 | import pytest
|
| 21 | +import re |
21 | 22 |
|
22 | 23 |
|
23 | 24 | def test_lineage_query(sagemaker_session):
|
@@ -524,3 +525,53 @@ def test_vertex_to_object_unconvertable(sagemaker_session):
|
524 | 525 |
|
525 | 526 | with pytest.raises(ValueError):
|
526 | 527 | vertex.to_lineage_object()
|
| 528 | + |
| 529 | + |
| 530 | +def test_get_visualization_elements(sagemaker_session): |
| 531 | + lineage_query = LineageQuery(sagemaker_session) |
| 532 | + sagemaker_session.sagemaker_client.query_lineage.return_value = { |
| 533 | + "Vertices": [ |
| 534 | + {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"}, |
| 535 | + {"Arn": "arn2", "Type": "Model", "LineageType": "Context"}, |
| 536 | + ], |
| 537 | + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], |
| 538 | + } |
| 539 | + |
| 540 | + query_response = lineage_query.query( |
| 541 | + start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"] |
| 542 | + ) |
| 543 | + |
| 544 | + elements = query_response._get_visualization_elements() |
| 545 | + |
| 546 | + assert elements["nodes"][0] == ("arn1", "Endpoint", "Artifact", False) |
| 547 | + assert elements["nodes"][1] == ("arn2", "Model", "Context", False) |
| 548 | + assert elements["edges"][0] == ("arn1", "arn2", "Produced") |
| 549 | + |
| 550 | + |
| 551 | +def test_query_lineage_result_str(sagemaker_session): |
| 552 | + lineage_query = LineageQuery(sagemaker_session) |
| 553 | + sagemaker_session.sagemaker_client.query_lineage.return_value = { |
| 554 | + "Vertices": [ |
| 555 | + {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"}, |
| 556 | + {"Arn": "arn2", "Type": "Model", "LineageType": "Context"}, |
| 557 | + ], |
| 558 | + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], |
| 559 | + } |
| 560 | + |
| 561 | + query_response = lineage_query.query( |
| 562 | + start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"] |
| 563 | + ) |
| 564 | + |
| 565 | + response_str = query_response.__str__() |
| 566 | + pattern = "Mock id='\d*'" |
| 567 | + replace = "Mock id=''" |
| 568 | + response_str = re.sub(pattern, replace, response_str) |
| 569 | + |
| 570 | + assert ( |
| 571 | + response_str |
| 572 | + == "{\n'edges': [\n\t{'source_arn': 'arn1', 'destination_arn': 'arn2', 'association_type': 'Produced'}]," |
| 573 | + + "\n\n'vertices': [\n\t{'arn': 'arn1', 'lineage_entity': 'Artifact', 'lineage_source': 'Endpoint', " |
| 574 | + + "'_session': <Mock id=''>}, \n\t{'arn': 'arn2', 'lineage_entity': 'Context', 'lineage_source': " |
| 575 | + + "'Model', '_session': <Mock id=''>}],\n\n'startarn': " |
| 576 | + + "['arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext'],\n}" |
| 577 | + ) |
0 commit comments