Skip to content

Commit 3015410

Browse files
authored
change: reshape Artifacts into data frame in ExperimentsAnalytics
1 parent c9d89da commit 3015410

File tree

3 files changed

+133
-0
lines changed

3 files changed

+133
-0
lines changed

src/sagemaker/analytics.py

+34
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,8 @@ def __init__(
431431
metric_names=None,
432432
parameter_names=None,
433433
sagemaker_session=None,
434+
input_artifact_names=None,
435+
output_artifact_names=None,
434436
):
435437
"""Initialize a ``ExperimentAnalytics`` instance.
436438
@@ -450,6 +452,11 @@ def __init__(
450452
sagemaker_session (sagemaker.session.Session): Session object which manages interactions
451453
with Amazon SageMaker APIs and any other AWS services needed. If not specified,
452454
one is created using the default AWS configuration chain.
455+
input_artifact_names(dict optional):The input artifacts for the experiment. Examples of
456+
input artifacts are datasets, algorithms, hyperparameters, source code, and instance
457+
types.
458+
output_artifact_names(dict optional): The output artifacts for the experiment. Examples
459+
of output artifacts are metrics, snapshots, logs, and images.
453460
"""
454461
sagemaker_session = sagemaker_session or Session()
455462
self._sage_client = sagemaker_session.sagemaker_client
@@ -463,6 +470,8 @@ def __init__(
463470
self._sort_order = sort_order
464471
self._metric_names = metric_names
465472
self._parameter_names = parameter_names
473+
self._input_artifact_names = input_artifact_names
474+
self._output_artifact_names = output_artifact_names
466475
self._trial_components = None
467476
super(ExperimentAnalytics, self).__init__()
468477
self.clear_cache()
@@ -516,6 +525,21 @@ def _reshape_metrics(self, metrics):
516525
out["{} - {}".format(metric_name, stat_type)] = stat_value
517526
return out
518527

528+
def _reshape_artifacts(self, artifacts, _artifact_names):
529+
"""Reshape trial component input/output artifacts to a pandas column
530+
Args:
531+
artifacts: trial component input/output artifacts
532+
Returns:
533+
dict: Key: artifacts name, Value: artifacts value
534+
"""
535+
out = OrderedDict()
536+
for name, value in sorted(artifacts.items()):
537+
if _artifact_names and (name not in _artifact_names):
538+
continue
539+
out["{} - {}".format(name, "MediaType")] = value.get("MediaType")
540+
out["{} - {}".format(name, "Value")] = value.get("Value")
541+
return out
542+
519543
def _reshape(self, trial_component):
520544
"""Reshape trial component data to pandas columns
521545
Args:
@@ -533,6 +557,16 @@ def _reshape(self, trial_component):
533557

534558
out.update(self._reshape_parameters(trial_component.get("Parameters", [])))
535559
out.update(self._reshape_metrics(trial_component.get("Metrics", [])))
560+
out.update(
561+
self._reshape_artifacts(
562+
trial_component.get("InputArtifacts", []), self._input_artifact_names
563+
)
564+
)
565+
out.update(
566+
self._reshape_artifacts(
567+
trial_component.get("OutputArtifacts", []), self._output_artifact_names
568+
)
569+
)
536570
return out
537571

538572
def _fetch_dataframe(self):

tests/integ/test_experiments_analytics.py

+59
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,65 @@ def experiment(sagemaker_session):
4343
_delete_resources(sm, experiment_name, trials)
4444

4545

46+
@contextmanager
47+
def experiment_with_artifacts(sagemaker_session):
48+
sm = sagemaker_session.sagemaker_client
49+
trials = {} # for resource cleanup
50+
51+
experiment_name = "experiment-" + str(uuid.uuid4())
52+
try:
53+
sm.create_experiment(ExperimentName=experiment_name)
54+
55+
# Search returns 10 results by default. Add 20 trials to verify pagination.
56+
for i in range(20):
57+
trial_name = "trial-" + str(uuid.uuid4())
58+
sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name)
59+
60+
trial_component_name = "tc-" + str(uuid.uuid4())
61+
trials[trial_name] = trial_component_name
62+
63+
sm.create_trial_component(
64+
TrialComponentName=trial_component_name, DisplayName="Training"
65+
)
66+
sm.update_trial_component(
67+
TrialComponentName=trial_component_name,
68+
Parameters={"hp1": {"NumberValue": i}},
69+
InputArtifacts={
70+
"inputArtifacts1": {"MediaType": "text/csv", "Value": "s3:/foo/bar1"}
71+
},
72+
OutputArtifacts={
73+
"outputArtifacts1": {"MediaType": "text/plain", "Value": "s3:/foo/bar2"}
74+
},
75+
)
76+
sm.associate_trial_component(
77+
TrialComponentName=trial_component_name, TrialName=trial_name
78+
)
79+
80+
time.sleep(15) # wait for search to get updated
81+
82+
yield experiment_name
83+
finally:
84+
_delete_resources(sm, experiment_name, trials)
85+
86+
87+
@pytest.mark.canary_quick
88+
def test_experiment_analytics_artifacts(sagemaker_session):
89+
with experiment_with_artifacts(sagemaker_session) as experiment_name:
90+
analytics = ExperimentAnalytics(
91+
experiment_name=experiment_name, sagemaker_session=sagemaker_session
92+
)
93+
94+
assert list(analytics.dataframe().columns) == [
95+
"TrialComponentName",
96+
"DisplayName",
97+
"hp1",
98+
"inputArtifacts1 - MediaType",
99+
"inputArtifacts1 - Value",
100+
"outputArtifacts1 - MediaType",
101+
"outputArtifacts1 - Value",
102+
]
103+
104+
46105
@pytest.mark.canary_quick
47106
def test_experiment_analytics(sagemaker_session):
48107
with experiment(sagemaker_session) as experiment_name:

tests/unit/test_experiments_analytics.py

+40
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ def trial_component(trial_component_name):
4040
"Count": 2.0,
4141
},
4242
],
43+
"InputArtifacts": {
44+
"inputArtifacts1": {"MediaType": "text/plain", "Value": "s3:/foo/bar1"},
45+
"inputArtifacts2": {"MediaType": "text/plain", "Value": "s3:/foo/bar2"},
46+
},
47+
"OutputArtifacts": {
48+
"outputArtifacts1": {"MediaType": "text/csv", "Value": "s3:/sky/far1"},
49+
"outputArtifacts2": {"MediaType": "text/csv", "Value": "s3:/sky/far2"},
50+
},
4351
}
4452

4553

@@ -72,6 +80,14 @@ def test_trial_analytics_dataframe_all_metrics_hyperparams(mock_session):
7280
("metric2 - StdDev", [0.05, 0.05]),
7381
("metric2 - Last", [7.0, 7.0]),
7482
("metric2 - Count", [2.0, 2.0]),
83+
("inputArtifacts1 - MediaType", ["text/plain", "text/plain"]),
84+
("inputArtifacts1 - Value", ["s3:/foo/bar1", "s3:/foo/bar1"]),
85+
("inputArtifacts2 - MediaType", ["text/plain", "text/plain"]),
86+
("inputArtifacts2 - Value", ["s3:/foo/bar2", "s3:/foo/bar2"]),
87+
("outputArtifacts1 - MediaType", ["text/csv", "text/csv"]),
88+
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
89+
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
90+
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
7591
]
7692
)
7793
)
@@ -117,6 +133,14 @@ def test_trial_analytics_dataframe_selected_hyperparams(mock_session):
117133
("metric2 - StdDev", [0.05, 0.05]),
118134
("metric2 - Last", [7.0, 7.0]),
119135
("metric2 - Count", [2.0, 2.0]),
136+
("inputArtifacts1 - MediaType", ["text/plain", "text/plain"]),
137+
("inputArtifacts1 - Value", ["s3:/foo/bar1", "s3:/foo/bar1"]),
138+
("inputArtifacts2 - MediaType", ["text/plain", "text/plain"]),
139+
("inputArtifacts2 - Value", ["s3:/foo/bar2", "s3:/foo/bar2"]),
140+
("outputArtifacts1 - MediaType", ["text/csv", "text/csv"]),
141+
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
142+
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
143+
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
120144
]
121145
)
122146
)
@@ -157,6 +181,14 @@ def test_trial_analytics_dataframe_selected_metrics(mock_session):
157181
("metric1 - StdDev", [1.0, 1.0]),
158182
("metric1 - Last", [2.0, 2.0]),
159183
("metric1 - Count", [2.0, 2.0]),
184+
("inputArtifacts1 - MediaType", ["text/plain", "text/plain"]),
185+
("inputArtifacts1 - Value", ["s3:/foo/bar1", "s3:/foo/bar1"]),
186+
("inputArtifacts2 - MediaType", ["text/plain", "text/plain"]),
187+
("inputArtifacts2 - Value", ["s3:/foo/bar2", "s3:/foo/bar2"]),
188+
("outputArtifacts1 - MediaType", ["text/csv", "text/csv"]),
189+
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
190+
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
191+
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
160192
]
161193
)
162194
)
@@ -203,6 +235,14 @@ def test_trial_analytics_dataframe_search_pagination(mock_session):
203235
("metric2 - StdDev", [0.05, 0.05]),
204236
("metric2 - Last", [7.0, 7.0]),
205237
("metric2 - Count", [2.0, 2.0]),
238+
("inputArtifacts1 - MediaType", ["text/plain", "text/plain"]),
239+
("inputArtifacts1 - Value", ["s3:/foo/bar1", "s3:/foo/bar1"]),
240+
("inputArtifacts2 - MediaType", ["text/plain", "text/plain"]),
241+
("inputArtifacts2 - Value", ["s3:/foo/bar2", "s3:/foo/bar2"]),
242+
("outputArtifacts1 - MediaType", ["text/csv", "text/csv"]),
243+
("outputArtifacts1 - Value", ["s3:/sky/far1", "s3:/sky/far1"]),
244+
("outputArtifacts2 - MediaType", ["text/csv", "text/csv"]),
245+
("outputArtifacts2 - Value", ["s3:/sky/far2", "s3:/sky/far2"]),
206246
]
207247
)
208248
)

0 commit comments

Comments
 (0)