|
14 | 14 | from __future__ import print_function, absolute_import
|
15 | 15 |
|
16 | 16 | from abc import ABCMeta, abstractmethod
|
17 |
| -from collections import defaultdict |
| 17 | +from collections import defaultdict, OrderedDict |
18 | 18 | import datetime
|
19 | 19 | import logging
|
20 | 20 |
|
|
23 | 23 | from sagemaker.session import Session
|
24 | 24 | from sagemaker.utils import DeferredError
|
25 | 25 |
|
| 26 | + |
26 | 27 | try:
|
27 | 28 | import pandas as pd
|
28 | 29 | except ImportError as e:
|
@@ -413,3 +414,197 @@ def _metric_names_for_training_job(self):
|
413 | 414 | metric_names = [md["Name"] for md in metric_definitions]
|
414 | 415 |
|
415 | 416 | return metric_names
|
| 417 | + |
| 418 | + |
class ExperimentAnalytics(AnalyticsMetricsBase):
    """Fetch trial component data and make them accessible for analytics."""

    # Paging in ``_search`` stops once at least this many trial components have
    # been collected. The check runs between pages, so the final page may push
    # the total slightly above this value.
    MAX_TRIAL_COMPONENTS = 10000

    def __init__(
        self,
        experiment_name=None,
        search_expression=None,
        sort_by=None,
        sort_order=None,
        metric_names=None,
        parameter_names=None,
        sagemaker_session=None,
    ):
        """Initialize a ``ExperimentAnalytics`` instance.

        Args:
            experiment_name (str, optional): Name of the experiment if you want to constrain the
                search to only trial components belonging to an experiment.
            search_expression (dict, optional): The search query to find the set of trial components
                to use to populate the data frame.
            sort_by (str, optional): The name of the resource property used to sort
                the set of trial components.
            sort_order (str, optional): How trial components are ordered, valid values are Ascending
                and Descending. The default is Descending.
            metric_names (list, optional): string names of all the metrics to be shown in the
                data frame. If not specified, all metrics will be shown of all trials.
            parameter_names (list, optional): string names of the parameters to be shown in the
                data frame. If not specified, all parameters will be shown of all trials.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions
                with Amazon SageMaker APIs and any other AWS services needed. If not specified,
                one is created using the default AWS configuration chain.

        Raises:
            ValueError: If neither ``experiment_name`` nor ``search_expression`` is supplied.
        """
        sagemaker_session = sagemaker_session or Session()
        self._sage_client = sagemaker_session.sagemaker_client

        if not experiment_name and not search_expression:
            raise ValueError("Either experiment_name or search_expression must be supplied.")

        self._experiment_name = experiment_name
        self._search_expression = search_expression
        self._sort_by = sort_by
        self._sort_order = sort_order
        self._metric_names = metric_names
        self._parameter_names = parameter_names
        self._trial_components = None
        super(ExperimentAnalytics, self).__init__()
        self.clear_cache()

    @property
    def name(self):
        """Name of the Experiment being analyzed (``None`` if only a search expression was given)."""
        return self._experiment_name

    def __repr__(self):
        return "<sagemaker.ExperimentAnalytics for %s>" % self.name

    def clear_cache(self):
        """Clear the object of all local caches of API methods."""
        super(ExperimentAnalytics, self).clear_cache()
        self._trial_components = None

    def _reshape_parameters(self, parameters):
        """Reshape trial component parameters to a pandas column.

        Args:
            parameters (dict): trial component parameters, keyed by parameter name. Each
                value is a dict that holds the parameter under ``NumberValue`` or
                ``StringValue``.

        Returns:
            dict: Key: Parameter name, Value: Parameter value. Keys are in sorted order and
                limited to ``parameter_names`` when that filter was supplied.
        """
        out = OrderedDict()
        for name, value in sorted(parameters.items()):
            if self._parameter_names and name not in self._parameter_names:
                continue
            # Prefer the numeric representation; fall back to the string one.
            out[name] = value.get("NumberValue", value.get("StringValue"))
        return out

    def _reshape_metrics(self, metrics):
        """Reshape trial component metrics to a pandas column.

        Args:
            metrics (list): trial component metric summaries. Each is a dict with a
                ``MetricName`` key and optional statistic keys.

        Returns:
            dict: Key: ``"<metric name> - <statistic>"``, Value: statistic value. Only
                statistics present (not ``None``) in the summary are included, and metrics
                are limited to ``metric_names`` when that filter was supplied.
        """
        statistic_types = ["Min", "Max", "Avg", "StdDev", "Last", "Count"]
        out = OrderedDict()
        for metric_summary in metrics:
            metric_name = metric_summary["MetricName"]
            if self._metric_names and metric_name not in self._metric_names:
                continue

            for stat_type in statistic_types:
                stat_value = metric_summary.get(stat_type)
                if stat_value is not None:
                    out["{} - {}".format(metric_name, stat_type)] = stat_value
        return out

    def _reshape(self, trial_component):
        """Reshape trial component data to pandas columns.

        Args:
            trial_component (dict): dict representing a trial component.

        Returns:
            dict: Key-Value pair representing the data in the pandas dataframe
        """
        out = OrderedDict()
        for attribute in ["TrialComponentName", "DisplayName"]:
            out[attribute] = trial_component.get(attribute, "")

        source = trial_component.get("Source", "")
        if source:
            out["SourceArn"] = source["SourceArn"]

        # Parameters is a mapping (it is iterated with ``.items()``), so a missing or
        # empty entry must default to a dict; Metrics is a sequence of summaries.
        out.update(self._reshape_parameters(trial_component.get("Parameters") or {}))
        out.update(self._reshape_metrics(trial_component.get("Metrics") or []))
        return out

    def _fetch_dataframe(self):
        """Return a pandas dataframe with all the trial_components,
        along with their parameters and metrics.
        """
        rows = [self._reshape(component) for component in self._get_trial_components()]
        return pd.DataFrame(rows)

    def _get_trial_components(self, force_refresh=False):
        """Get all trial components matching the given search query expression.

        The result is cached on the instance until ``clear_cache`` is called.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from SageMaker API.

        Returns:
            list: List of dicts representing the trial components
        """
        if force_refresh:
            self.clear_cache()
        if self._trial_components is not None:
            return self._trial_components

        # Build the effective expression on a copy so that neither the caller's dict
        # nor the stored expression accumulates a duplicate experiment filter each
        # time the components are fetched again.
        search_expression = dict(self._search_expression or {})

        if self._experiment_name:
            filters = list(search_expression.get("Filters") or [])
            filters.append(
                {
                    "Name": "Parents.ExperimentName",
                    "Operator": "Equals",
                    "Value": self._experiment_name,
                }
            )
            search_expression["Filters"] = filters

        self._trial_components = self._search(search_expression, self._sort_by, self._sort_order)
        return self._trial_components

    def _search(self, search_expression, sort_by, sort_order):
        """Perform a search query using SageMaker Search and return the matching trial components.

        Args:
            search_expression: Search expression to filter trial components.
            sort_by: The name of the resource property used to sort the trial components.
            sort_order: How trial components are ordered, valid values are Ascending
                and Descending. The default is Descending.

        Returns:
            list: List of dict representing trial components.
        """
        trial_components = []

        search_args = {
            "Resource": "ExperimentTrialComponent",
            "SearchExpression": search_expression,
        }

        # Only forward the optional sort arguments when the caller supplied them.
        if sort_by:
            search_args["SortBy"] = sort_by

        if sort_order:
            search_args["SortOrder"] = sort_order

        while len(trial_components) < self.MAX_TRIAL_COMPONENTS:
            search_response = self._sage_client.search(**search_args)
            components = [result["TrialComponent"] for result in search_response["Results"]]
            trial_components.extend(components)
            # Stop when there is no further page, or when a page came back empty
            # even though a token was returned.
            if "NextToken" in search_response and len(components) > 0:
                search_args["NextToken"] = search_response["NextToken"]
            else:
                break

        return trial_components
0 commit comments