Skip to content

Commit 668359f

Browse files
yzhu0shreyapandit
authored andcommitted
fix: update lineage_trial_compoment get pipeline execution arn (#2944)
Co-authored-by: Shreya Pandit <[email protected]>
1 parent a928c0a commit 668359f

File tree

3 files changed

+104
-19
lines changed

3 files changed

+104
-19
lines changed

src/sagemaker/lineage/lineage_trial_component.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,15 @@ def pipeline_execution_arn(self) -> str:
130130
Returns:
131131
str: A pipeline execution ARN.
132132
"""
133+
trial_component = self.load(
134+
trial_component_name=self.trial_component_name, sagemaker_session=self.sagemaker_session
135+
)
136+
137+
if trial_component.source is None or trial_component.source["SourceArn"] is None:
138+
return None
139+
133140
tags = self.sagemaker_session.sagemaker_client.list_tags(
134-
ResourceArn=self.trial_component_arn
141+
ResourceArn=trial_component.source["SourceArn"]
135142
)["Tags"]
136143
for tag in tags:
137144
if tag["Key"] == "sagemaker:pipeline-execution-arn":

tests/integ/sagemaker/lineage/conftest.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def upstream_trial_associated_artifact(
233233
sagemaker_session=sagemaker_session,
234234
)
235235
trial_obj.add_trial_component(trial_component_obj)
236-
time.sleep(3)
236+
time.sleep(4)
237237
yield artifact_obj
238238
trial_obj.remove_trial_component(trial_component_obj)
239239
assntn.delete()
@@ -561,14 +561,14 @@ def static_approval_action(
561561

562562

563563
@pytest.fixture
564-
def static_model_deployment_action(sagemaker_session, static_endpoint_context):
564+
def static_model_deployment_action(sagemaker_session, static_processing_job_trial_component):
565565
query_filter = LineageFilter(
566566
entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.MODEL_DEPLOYMENT]
567567
)
568568
query_result = LineageQuery(sagemaker_session).query(
569-
start_arns=[static_endpoint_context.context_arn],
569+
start_arns=[static_processing_job_trial_component.trial_component_arn],
570570
query_filter=query_filter,
571-
direction=LineageQueryDirectionEnum.ASCENDANTS,
571+
direction=LineageQueryDirectionEnum.DESCENDANTS,
572572
include_edges=False,
573573
)
574574
model_approval_actions = []
@@ -579,14 +579,14 @@ def static_model_deployment_action(sagemaker_session, static_endpoint_context):
579579

580580
@pytest.fixture
581581
def static_processing_job_trial_component(
582-
sagemaker_session, static_endpoint_context
582+
sagemaker_session, static_dataset_artifact
583583
) -> LineageTrialComponent:
584584
query_filter = LineageFilter(
585585
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.PROCESSING_JOB]
586586
)
587587

588588
query_result = LineageQuery(sagemaker_session).query(
589-
start_arns=[static_endpoint_context.context_arn],
589+
start_arns=[static_dataset_artifact.artifact_arn],
590590
query_filter=query_filter,
591591
direction=LineageQueryDirectionEnum.ASCENDANTS,
592592
include_edges=False,
@@ -600,14 +600,14 @@ def static_processing_job_trial_component(
600600

601601
@pytest.fixture
602602
def static_training_job_trial_component(
603-
sagemaker_session, static_endpoint_context
603+
sagemaker_session, static_model_artifact
604604
) -> LineageTrialComponent:
605605
query_filter = LineageFilter(
606606
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB]
607607
)
608608

609609
query_result = LineageQuery(sagemaker_session).query(
610-
start_arns=[static_endpoint_context.context_arn],
610+
start_arns=[static_model_artifact.artifact_arn],
611611
query_filter=query_filter,
612612
direction=LineageQueryDirectionEnum.ASCENDANTS,
613613
include_edges=False,
@@ -738,12 +738,12 @@ def static_dataset_artifact(static_model_artifact, sagemaker_session):
738738

739739

740740
@pytest.fixture
741-
def static_image_artifact(static_model_artifact, sagemaker_session):
741+
def static_image_artifact(static_dataset_artifact, sagemaker_session):
742742
query_filter = LineageFilter(
743743
entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.IMAGE]
744744
)
745745
query_result = LineageQuery(sagemaker_session).query(
746-
start_arns=[static_model_artifact.artifact_arn],
746+
start_arns=[static_dataset_artifact.artifact_arn],
747747
query_filter=query_filter,
748748
direction=LineageQueryDirectionEnum.ASCENDANTS,
749749
include_edges=False,

tests/unit/sagemaker/lineage/test_lineage_trial_component.py

+86-8
Original file line numberDiff line numberDiff line change
@@ -114,19 +114,39 @@ def test_pipeline_execution_arn(sagemaker_session):
114114
trial_component_arn = (
115115
"arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37"
116116
)
117-
obj = lineage_trial_component.LineageTrialComponent(
118-
sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn
117+
training_job_arn = (
118+
"arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain"
119119
)
120+
context = lineage_trial_component.LineageTrialComponent(
121+
sagemaker_session,
122+
trial_component_name="foo",
123+
trial_component_arn=trial_component_arn,
124+
source={
125+
"SourceArn": training_job_arn,
126+
"SourceType": "SageMakerTrainingJob",
127+
},
128+
)
129+
obj = {
130+
"TrialComponentName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job",
131+
"TrialComponentArn": trial_component_arn,
132+
"DisplayName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job",
133+
"Source": {
134+
"SourceArn": training_job_arn,
135+
"SourceType": "SageMakerTrainingJob",
136+
},
137+
}
138+
sagemaker_session.sagemaker_client.describe_trial_component.return_value = obj
120139

121140
sagemaker_session.sagemaker_client.list_tags.return_value = {
122141
"Tags": [
123142
{"Key": "sagemaker:pipeline-execution-arn", "Value": "tag1"},
124143
],
125144
}
126145
expected_calls = [
127-
unittest.mock.call(ResourceArn=trial_component_arn),
146+
unittest.mock.call(ResourceArn=training_job_arn),
128147
]
129-
pipeline_execution_arn_result = obj.pipeline_execution_arn()
148+
pipeline_execution_arn_result = context.pipeline_execution_arn()
149+
130150
assert pipeline_execution_arn_result == "tag1"
131151
assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls
132152

@@ -135,19 +155,77 @@ def test_no_pipeline_execution_arn(sagemaker_session):
135155
trial_component_arn = (
136156
"arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37"
137157
)
138-
obj = lineage_trial_component.LineageTrialComponent(
139-
sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn
158+
training_job_arn = (
159+
"arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain"
140160
)
161+
context = lineage_trial_component.LineageTrialComponent(
162+
sagemaker_session,
163+
trial_component_name="foo",
164+
trial_component_arn=trial_component_arn,
165+
source={
166+
"SourceArn": training_job_arn,
167+
"SourceType": "SageMakerTrainingJob",
168+
},
169+
)
170+
obj = {
171+
"TrialComponentName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job",
172+
"TrialComponentArn": trial_component_arn,
173+
"DisplayName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job",
174+
"Source": {
175+
"SourceArn": training_job_arn,
176+
"SourceType": "SageMakerTrainingJob",
177+
},
178+
}
179+
sagemaker_session.sagemaker_client.describe_trial_component.return_value = obj
141180

142181
sagemaker_session.sagemaker_client.list_tags.return_value = {
143182
"Tags": [
144183
{"Key": "abcd", "Value": "efg"},
145184
],
146185
}
147186
expected_calls = [
148-
unittest.mock.call(ResourceArn=trial_component_arn),
187+
unittest.mock.call(ResourceArn=training_job_arn),
149188
]
150-
pipeline_execution_arn_result = obj.pipeline_execution_arn()
189+
pipeline_execution_arn_result = context.pipeline_execution_arn()
190+
expected_result = None
191+
assert pipeline_execution_arn_result == expected_result
192+
assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls
193+
194+
195+
def test_no_source_arn_pipeline_execution_arn(sagemaker_session):
196+
trial_component_arn = (
197+
"arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37"
198+
)
199+
training_job_arn = (
200+
"arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain"
201+
)
202+
context = lineage_trial_component.LineageTrialComponent(
203+
sagemaker_session,
204+
trial_component_name="foo",
205+
trial_component_arn=trial_component_arn,
206+
source={
207+
"SourceArn": training_job_arn,
208+
"SourceType": "SageMakerTrainingJob",
209+
},
210+
)
211+
obj = {
212+
"TrialComponentName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job",
213+
"TrialComponentArn": trial_component_arn,
214+
"DisplayName": "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job",
215+
"Source": {
216+
"SourceArn": None,
217+
"SourceType": None,
218+
},
219+
}
220+
sagemaker_session.sagemaker_client.describe_trial_component.return_value = obj
221+
222+
sagemaker_session.sagemaker_client.list_tags.return_value = {
223+
"Tags": [
224+
{"Key": "abcd", "Value": "efg"},
225+
],
226+
}
227+
expected_calls = []
228+
pipeline_execution_arn_result = context.pipeline_execution_arn()
151229
expected_result = None
152230
assert pipeline_execution_arn_result == expected_result
153231
assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls

0 commit comments

Comments
 (0)