Skip to content

Commit 75fb75b

Browse files
author
Dan
authored
Merge branch 'master' into loc-config-file
2 parents 37e954a + 0da3339 commit 75fb75b

File tree

4 files changed

+77
-138
lines changed

4 files changed

+77
-138
lines changed

src/sagemaker/feature_store/feature_group.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def _ingest_single_batch(
185185
feature_name=data_frame.columns[index], value_as_string=str(row[index])
186186
)
187187
for index in range(len(row))
188+
if pd.notna(row[index])
188189
]
189190
sagemaker_session.put_record(
190191
feature_group_name=feature_group_name, record=[value.to_dict() for value in record]

tests/integ/test_feature_store.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def pandas_data_frame():
100100
"feature1": pd.Series(np.arange(10.0), dtype="float64"),
101101
"feature2": pd.Series(np.arange(10), dtype="int64"),
102102
"feature3": pd.Series(["2020-10-30T03:43:21Z"] * 10, dtype="string"),
103+
"feature4": pd.Series(np.arange(5.0), dtype="float64"), # contains nan
103104
}
104105
)
105106
return df
@@ -132,6 +133,7 @@ def create_table_ddl():
132133
" feature1 FLOAT\n"
133134
" feature2 INT\n"
134135
" feature3 STRING\n"
136+
" feature4 FLOAT\n"
135137
" write_time TIMESTAMP\n"
136138
" event_time TIMESTAMP\n"
137139
" is_deleted BOOLEAN\n"
@@ -214,6 +216,9 @@ def test_create_feature_store(
214216
time.sleep(60)
215217

216218
assert df.shape[0] == 11
219+
nans = pd.isna(df.loc[df["feature1"].isin([5, 6, 7, 8, 9])]["feature4"])
220+
for is_na in nans.items():
221+
assert is_na
217222
assert (
218223
create_table_ddl.format(
219224
feature_group_name=feature_group_name,

tests/integ/test_smdataparallel_tf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS,
3232
reason="Only allow this test to run in IAD and CMH to limit usage of p3.16xlarge",
3333
)
34+
@pytest.mark.skip("Failing due to bad DLC image release. Disable temporarily.")
3435
def test_smdataparallel_tf_mnist(
3536
sagemaker_session,
3637
tensorflow_training_latest_version,

tests/unit/sagemaker/lineage/test_visualizer.py

Lines changed: 70 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -49,64 +49,17 @@ def test_trial_component_name(viz, sagemaker_session):
4949
"TrialComponentArn": "tc-arn",
5050
}
5151

52-
sagemaker_session.sagemaker_client.list_associations.side_effect = [
53-
{
54-
"AssociationSummaries": [
55-
{
56-
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
57-
"SourceName": "source-name-1",
58-
"SourceType": "source-type-1",
59-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
60-
"DestinationName": "dest-name-1",
61-
"DestinationType": "dest-type-1",
62-
"AssociationType": "type-1",
63-
}
64-
]
65-
},
66-
{
67-
"AssociationSummaries": [
68-
{
69-
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
70-
"SourceName": "source-name-2",
71-
"SourceType": "source-type-2",
72-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
73-
"DestinationName": "dest-name-2",
74-
"DestinationType": "dest-type-2",
75-
"AssociationType": "type-2",
76-
}
77-
]
78-
},
79-
]
52+
get_list_associations_side_effect(sagemaker_session)
8053

8154
df = viz.show(trial_component_name=name)
8255

8356
sagemaker_session.sagemaker_client.describe_trial_component.assert_called_with(
8457
TrialComponentName=name,
8558
)
8659

87-
expected_calls = [
88-
unittest.mock.call(
89-
DestinationArn="tc-arn",
90-
),
91-
unittest.mock.call(
92-
SourceArn="tc-arn",
93-
),
94-
]
95-
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
60+
assert_list_associations_mock_calls(sagemaker_session)
9661

97-
expected_dataframe = pd.DataFrame.from_dict(
98-
OrderedDict(
99-
[
100-
("Name/Source", ["source-name-1", "dest-name-2"]),
101-
("Direction", ["Input", "Output"]),
102-
("Type", ["source-type-1", "dest-type-2"]),
103-
("Association Type", ["type-1", "type-2"]),
104-
("Lineage Type", ["artifact", "artifact"]),
105-
]
106-
)
107-
)
108-
109-
pd.testing.assert_frame_equal(expected_dataframe, df)
62+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
11063

11164

11265
def test_model_package_arn(viz, sagemaker_session):
@@ -116,34 +69,7 @@ def test_model_package_arn(viz, sagemaker_session):
11669
"ArtifactSummaries": [{"ArtifactArn": "artifact-arn"}]
11770
}
11871

119-
sagemaker_session.sagemaker_client.list_associations.side_effect = [
120-
{
121-
"AssociationSummaries": [
122-
{
123-
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
124-
"SourceName": "source-name-1",
125-
"SourceType": "source-type-1",
126-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
127-
"DestinationName": "dest-name-1",
128-
"DestinationType": "dest-type-1",
129-
"AssociationType": "type-1",
130-
}
131-
]
132-
},
133-
{
134-
"AssociationSummaries": [
135-
{
136-
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
137-
"SourceName": "source-name-2",
138-
"SourceType": "source-type-2",
139-
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
140-
"DestinationName": "dest-name-2",
141-
"DestinationType": "dest-type-2",
142-
"AssociationType": "type-2",
143-
}
144-
]
145-
},
146-
]
72+
get_list_associations_side_effect(sagemaker_session)
14773

14874
df = viz.show(model_package_arn=name)
14975

@@ -161,19 +87,7 @@ def test_model_package_arn(viz, sagemaker_session):
16187
]
16288
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
16389

164-
expected_dataframe = pd.DataFrame.from_dict(
165-
OrderedDict(
166-
[
167-
("Name/Source", ["source-name-1", "dest-name-2"]),
168-
("Direction", ["Input", "Output"]),
169-
("Type", ["source-type-1", "dest-type-2"]),
170-
("Association Type", ["type-1", "type-2"]),
171-
("Lineage Type", ["artifact", "artifact"]),
172-
]
173-
)
174-
)
175-
176-
pd.testing.assert_frame_equal(expected_dataframe, df)
90+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
17791

17892

17993
def test_endpoint_arn(viz, sagemaker_session):
@@ -183,34 +97,7 @@ def test_endpoint_arn(viz, sagemaker_session):
18397
"ContextSummaries": [{"ContextArn": "context-arn"}]
18498
}
18599

186-
sagemaker_session.sagemaker_client.list_associations.side_effect = [
187-
{
188-
"AssociationSummaries": [
189-
{
190-
"SourceArn": "a:b:c:d:e:context/src-arn-1",
191-
"SourceName": "source-name-1",
192-
"SourceType": "source-type-1",
193-
"DestinationArn": "a:b:c:d:e:context/dest-arn-1",
194-
"DestinationName": "dest-name-1",
195-
"DestinationType": "dest-type-1",
196-
"AssociationType": "type-1",
197-
}
198-
]
199-
},
200-
{
201-
"AssociationSummaries": [
202-
{
203-
"SourceArn": "a:b:c:d:e:context/src-arn-2",
204-
"SourceName": "source-name-2",
205-
"SourceType": "source-type-2",
206-
"DestinationArn": "a:b:c:d:e:context/dest-arn-2",
207-
"DestinationName": "dest-name-2",
208-
"DestinationType": "dest-type-2",
209-
"AssociationType": "type-2",
210-
}
211-
]
212-
},
213-
]
100+
get_list_associations_side_effect(sagemaker_session)
214101

215102
df = viz.show(endpoint_arn=name)
216103

@@ -228,27 +115,74 @@ def test_endpoint_arn(viz, sagemaker_session):
228115
]
229116
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
230117

231-
expected_dataframe = pd.DataFrame.from_dict(
232-
OrderedDict(
233-
[
234-
("Name/Source", ["source-name-1", "dest-name-2"]),
235-
("Direction", ["Input", "Output"]),
236-
("Type", ["source-type-1", "dest-type-2"]),
237-
("Association Type", ["type-1", "type-2"]),
238-
("Lineage Type", ["context", "context"]),
239-
]
240-
)
118+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
119+
120+
121+
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
122+
123+
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
124+
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
125+
}
126+
127+
get_list_associations_side_effect(sagemaker_session)
128+
129+
step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
130+
131+
df = viz.show(pipeline_execution_step=step)
132+
133+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
134+
SourceArn="proc-job-arn",
241135
)
242136

243-
pd.testing.assert_frame_equal(expected_dataframe, df)
137+
assert_list_associations_mock_calls(sagemaker_session)
244138

139+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
245140

246-
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
141+
142+
def test_training_job_pipeline_execution_step(viz, sagemaker_session):
247143

248144
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
249145
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
250146
}
251147

148+
get_list_associations_side_effect(sagemaker_session)
149+
150+
step = {"Metadata": {"TrainingJob": {"Arn": "training-job-arn"}}}
151+
152+
df = viz.show(pipeline_execution_step=step)
153+
154+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
155+
SourceArn="training-job-arn",
156+
)
157+
158+
assert_list_associations_mock_calls(sagemaker_session)
159+
160+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
161+
162+
163+
def test_transform_job_pipeline_execution_step(viz, sagemaker_session):
164+
165+
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
166+
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
167+
}
168+
169+
get_list_associations_side_effect(sagemaker_session)
170+
171+
step = {"Metadata": {"TransformJob": {"Arn": "transform-job-arn"}}}
172+
173+
df = viz.show(pipeline_execution_step=step)
174+
175+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
176+
SourceArn="transform-job-arn",
177+
)
178+
179+
assert_list_associations_mock_calls(sagemaker_session)
180+
181+
pd.testing.assert_frame_equal(get_expected_dataframe(), df)
182+
183+
184+
def get_list_associations_side_effect(sagemaker_session):
185+
252186
sagemaker_session.sagemaker_client.list_associations.side_effect = [
253187
{
254188
"AssociationSummaries": [
@@ -278,13 +212,8 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
278212
},
279213
]
280214

281-
step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
282-
283-
df = viz.show(pipeline_execution_step=step)
284215

285-
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
286-
SourceArn="proc-job-arn",
287-
)
216+
def assert_list_associations_mock_calls(sagemaker_session):
288217

289218
expected_calls = [
290219
unittest.mock.call(
@@ -296,6 +225,9 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
296225
]
297226
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
298227

228+
229+
def get_expected_dataframe():
230+
299231
expected_dataframe = pd.DataFrame.from_dict(
300232
OrderedDict(
301233
[
@@ -308,4 +240,4 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
308240
)
309241
)
310242

311-
pd.testing.assert_frame_equal(expected_dataframe, df)
243+
return expected_dataframe

0 commit comments

Comments
 (0)