
Commit 8b5a3c0

owen-tknakad authored and committed
Eureka master (#273)
* Set eureka VERSION file
* Eureka master (#145)
* Eureka trial tracking interface
* Add experiments developer guides (#147)
* Add experiments developer guides
* Eureka master (#148)
* Add experiments developer guides
* Move to new experiment / trial / trial run data model. Add TrialRun class.
* Eureka master (#149)
* Add Trial class
* Experiment class (#151)
* Introduce active-record design, include first concrete type - Experiment.
* Add Trial and TrialRun active record classes (#152)
* Add Trial and TrialRun active records. Add common created_time / last_modified_time to Record
* List method (#153)
* Add general list classmethod to Record. Add concrete impl to experiment.Experiment
* Eureka master (#154)
* Use general list method in all list* methods in experiment, trial, trial_run
* Add Trial Run Tracker (#156)
* Add Trial Run Tracker
* Add helper methods to Experiment and Trial for fast-creating Trials and TrialRunTrackers
* Set Eureka build to just run linters, doc build, and unit tests
* Add integration tests (#159)
* Add experiment and trial integration tests
* TrialRun bug fixes (#160). Introduce UpdatedData property on TrialRun
* Adapt Python SDK to Experiments Api changes (#164)
* Move experiment to new api
* Pin pytest version to 4.4.1
* Fix list trials by experiment without passing experiment name. (#166)
* Pass experiment name for list trial api call.
* Eureka master (#167). Flatten trial name input for create_trial
* Make providing the step name optional when creating a trial tracker (#169)
* Make step name optional when creating tracker
* Make create tracker obtain TRAINING_JOB_ARN from the environment. (#186)
* Make create tracker obtain TRAINING_JOB_ARN from the environment.
* when trial creates a tracker, the training_job_arn can be automatically set as source arn when creating the trial stage.
* corresponding unit test.
* Add source_arn back as optional param for create_tracker
* make source_arn an optional param for create_tracker function.
* Resolving source arn skeleton for jobs.
* Changing _resolve_job_arn to _resolve_source_arn.
* Using generator to resolve source arn from the environment.
* TrialAnalytics class to convert trial step data to pandas dataframe (#188)
* TrialAnalytics class to convert trial stage data to pandas data frame
* pin version flake8-future-import to 0.4.5 to avoid build failure on import annotations missing
* pandas column ordering is different in py27 and py37. Sorting columns to make the order deterministic
* Rename step to component (#194)
* Use new TrialComponent API structure for metrics, artifacts and parameters
* Improve documentation (#199)
* Improve documentation and remove unsupported parameters to list_trial_components
* minor doc update
* remove hardcoded alpha endpoint for experiments (#201)
* Generate sphinx docs for experiment classes (#204)
* Update Sphinx RST files to generate documentation for experiment classes
* Merging master branch in to eureka-master (#206)
* prepare release v1.18.16
* update development version to v1.18.17.dev0
* fix: use unique names for test training jobs (#765)
* prepare release v1.18.17
* update development version to v1.18.18.dev0
* change: add automatic model tuning integ test for TF script mode (#766)
* prepare release v1.18.18
* update development version to v1.18.19.dev0
* change: skip p2/p3 tests in eu-central-1 (#769)
* prepare release v1.18.19
* update development version to v1.18.20.dev0
* feature: add document embedding support to Object2Vec algorithm (#772)
* prepare release v1.19.0
* update development version to v1.19.1.dev0
* change: add py2 deprecation message for the deep learning framework images (#768)
* prepare release v1.19.1
* update development version to v1.19.2.dev0
* feature: add RL Ray 0.6.5 support (#779)
* fix: adjust Ray test script for Ray 0.6.5 (#781)
* fix: prevent false positive PR test results (#783)
* prepare release v1.20.0
* update development version to v1.20.1.dev0
* fix: update TrainingInputMode with s3_input InputMode (#776)
* prepare release v1.20.1
* update development version to v1.20.2.dev0
* fix: pin pytest version to 4.4.1 to avoid pluggy version conflict (#788)
* prepare release v1.20.2
* update development version to v1.20.3.dev0
* documentation: fix docs in regards to transform_fn for mxnet (#790)
* fix: skip local file check for TF requirements file when source_dir is an S3 URI (#798)
* fix: run tests if buildspec.yml has been modified (#786)
* prepare release v1.20.3
* update development version to v1.20.4.dev0
* feature: Support for TFS preprocessing (#797)
* prepare release v1.21.0
* update development version to v1.21.1.dev0
* fix: repack model function works without source directory (#804)
* prepare release v1.21.1
* update development version to v1.21.2.dev0
* fix: emit training jobs tags to estimator (#803)
* fix: set _current_job_name in attach() (#808)
* prepare release v1.21.2
* update development version to v1.21.3.dev0
* fix: honor source_dir from S3 (#811)
* feature: add encryption option to "record_set" (#794)
* feature: add encryption option to "record_set"
* prepare release v1.22.0
* update development version to v1.22.1.dev0
* documentation: update using_sklearn.rst parameter name (#814). Incorrect parameter name in docs. Updated to match what is implemented in the method and what is used in other estimators.
* feature: support MXNet 1.4 with MMS (#812)
* prepare release v1.23.0
* update development version to v1.23.1.dev0
* feature: add region check for Neo service (#806)
* prepare release v1.24.0
* update development version to v1.24.1.dev0
* fix: add better default transform job name handling within Transformer (#822)
* feature: repack_model support dependencies and code location (#821)
* documentation: TFS support for pre/processing functions (#807)
* change: skip p2 tests in ap-south-east (#823)
* prepare release v1.25.0
* update development version to v1.25.1.dev0
* fix: use unique job name in hyperparameter tuning test (#829)
* prepare release v1.25.1
* update development version to v1.25.2.dev0
* feature: Add extra_args to enable encrypted objects upload (#836)
* change: downgrade c5 in integ tests and test all TF Script Mode images (#840)
* feature: emit estimator transformer tags to model (#815)
* doc: include FrameworkModel and ModelPackage in API docs (#833)
* prepare release v1.26.0
* update development version to v1.26.1.dev0
* fix: fix logger creation in Chainer integ test script (#843). Only one test failed due to a timeout. (The corresponding test failed with the other Python version.) Talked to Rui offline.
* feature: add wait argument to estimator deploy (#842)
* prepare release v1.27.0
* update development version to v1.27.1.dev0
* feature: Add DataProcessing Fields for Batch Transform (#827)
* prepare release v1.28.0
* update development version to v1.28.1.dev0
* Update setup.py (#859)
* prepare release v1.28.1
* update development version to v1.28.2.dev0
* fix: prevent race condition in vpc tests (#863)
* prepare release v1.28.2
* update development version to v1.28.3.dev0
* doc: clean up MXNet and TF documentation (#865)
* doc: fix punctuation in MXNet version list (#866)
* change: update Sagemaker Neo regions and instance families (#862)
* prepare release v1.28.3
* update development version to v1.28.4.dev0
* feature: network isolation mode in training (#791)
* feature: network isolation mode in training
* feature: network isolation mode in tar support training
* change: documentation and check describe training job network isolation
* doc update
* doc update, remove inference section
* sourcedir
* type error fix
* change: moving not canary TFS tests to local mode (#870)
* Integrate black into development process (#873)
* change: Add Black formatting tool as dependency. As of this commit, Black formatting tool can be run with 'tox -e black-format'. Black does not run as part of any automated process, yet. Black is pulled in as a test dependency only if the Python version is greater than 3.6, as the tool is not vended as part of any earlier Python version.
* change: Resolve Black formatting failures. Black is unable to handle trailing 'L' or 'l', which is no longer supported as of Python 3.8. This commit removes those unnecessary 'long' identifiers. https://www.python.org/dev/peps/pep-0237/
* change: Format all files using Black. This commit contains no functional changes.
* change: Manually resolve flake8 violations after formatting
* change: Manually resolve pylint violations after formatting
* change: Enable black locally and in automated build. This commit enables black-format as part of "tox tests/unit", in order to format all files. It also enables black-check as part of the remote builds, in order to verify that all files are properly formatted.
* prepare release v1.29.0
* update development version to v1.29.1.dev0
* feature: add git_config and git_clone, validate method (#832)
* fix: add pytest.mark.local_mode annotation to broken tests (#876)
* fix: add pytest.mark.local_mode annotation to tests
* feature: add TensorFlow 1.13 support (#860)
* prepare release v1.30.0
* update development version to v1.30.1.dev0
* fix: add pytest.mark.local_mode annotation to broken tests (#884)
* change: remove unnecessary P3 tests from TFS integration tests (#885)
* change: allow only one integration test run per time (#880)
* change: Update buildspec.yml (#887)
* feature: use deep learning images (#883)
* prepare release v1.31.0
* update development version to v1.31.1.dev0
* change: build spec improvements. (#888)
* fix: remove unnecessary failure case tests (#892)
* change: print build execution time (#890)
* prepare release v1.31.1
* update development version to v1.31.2.dev0
* fix git test in test_estimator.py (#894)
* feature: support Endpoint_type for TF transform (#881)
* prepare release v1.32.0
* update development version to v1.32.1.dev0
* change: separate unit, local mode, and notebook tests in different buildspecs (#898)
* change: fix notebook tests (#900)
* Update displaytime.sh (#901)
* doc: refactor the overview topic in the sphinx project (#877)
* change: tighten pylint config and expand C and R exceptions (#899). This commit tightens the pylint config with inspiration from several of Google's pylint configs. This commit also expands the C and R exceptions and disables the specific rules that cause issues in this package.
* change: correct code per len-as-condition Pylint check (#902). The Pylint check is not actually enabled in this commit, as it conflicts directly with NumPy. Pylint has corrected this, and it will be included in their next release (2.4.0): pylint-dev/pylint#2684. Once Pylint 2.4.0 is released, we can consume it and remove this check. A summary of this information is included in a TODO near the relevant Pylint disable rule (len-as-condition).
* prepare release v1.32.1
* update development version to v1.32.2.dev0
* change: remove superfluous parens per Pylint rule (#903)
* change: enable logging-format-interpolation pylint check (#904)
* documentation: add pypi, rtd, black badges to readme (#910)
* prepare release v1.32.2
* update development version to v1.32.3.dev0
* feature: allow custom model name during deploy (#792)
* feature: allow custom model name during deploy
* black check
* feature: git support for hosting models (#878)
* git integration for serving
* fix: Add ap-northeast-1 to Neo algorithms region map (#897)
* fix: reset default output path in Transformer.transform (#905)
* fix: reset default output path on create transform job
* Unit and integration tests
* change: enable logging-not-lazy pylint check (#909)
* change: enable wrong-import-position pylint check (#907)
* change: enable wrong-import-position pylint check
* change: updating import pattern for sagemaker.tensorflow
* fix: fixing integration tests
* change: reformatting
* change: enable signature-differs pylint check (#915)
* Revert "change: enable wrong-import-position pylint check (#907)" (#916). This reverts commit 8489f86.
* change: enable wrong-import-position pylint check (#917)
* change: remove TODO comment on import-error Pylint check (#918). By running Pylint before any of the unit tests (and dependency installs), the import-error check will always fail since the dependencies are not yet installed. We could move Pylint to a later stage to resolve this, but there's value in this quick check occurring before the unit tests. As a result, this Pylint check is being disabled.
* prepare release v1.33.0
* update development version to v1.33.1.dev0
* change: enable unidiomatic-typecheck pylint check (#921)
* change: enable no-else-return and no-else-raise pylint checks (#925)
* change: fix list serialization for 1P algos (#922)
* change: enable simplifiable-if-expression pylint checks (#926)
* feature: deal with credentials for Git support for GitHub (#914). Add authentication info.
* feature: Git integration for CodeCommit (#927)
* add functions, tests and doc for CodeCommit
* change: enable inconsistent-return-statements Pylint check (#930). Note that this commit also raises ValueErrors in situations that would previously have returned None. Per PEP8: Be consistent in return statements. Either all return statements in a function should return an expression, or none of them should. If any return statement returns an expression, any return statements where no value is returned should explicitly state this as return None, and an explicit return statement should be present at the end of the function (if reachable).
* change: enable consider-merging-isinstance Pylint check (#932). Note that this commit will also enable simplifiable-if-statement, as there are no code changes needed for it.
* change: enable attribute-defined-outside-init Pylint check (#933). The logic behind this rule is to improve readability by defining all the attributes of a class inside the init function, even if it simply sets them to None.
* change: enable wrong-import-order Pylint check (#935). Per PEP8: Imports should be grouped in the following order: 1. Standard library imports. 2. Related third party imports. 3. Local application/library specific imports.
* change: enable ungrouped-imports Pylint check (#936)
* change: enable wrong-import-order Pylint check. Per PEP8: Imports should be grouped in the following order: 1. Standard library imports. 2. Related third party imports. 3. Local application/library specific imports.
* change: fix attach for 1P algorithm estimators (#931)
* change: set num_processes_per_host only if provided by user (#928)
* change: enable consider-using-in Pylint check (#938)
* change: enable consider-using-in Pylint check
* change: enable too-many-public-methods Pylint check (#939)
* change: enable too-many-public-methods Pylint check. This is a useful check to have, but is a lot of work to retroactively enforce. Enabling it while ignoring the single violation allows the validation to run for future code.
* change: enable chained-comparison Pylint check (#940)
* change: enable consider-using-ternary Pylint check (#942). This commit will add an exclusion for all auto-generated files. I chose to ignore the single violation, because the alternative is confusingly convoluted: `(hasattr(obj, '__getitem__') if hasattr(obj, '__iter__') else isinstance(obj, str))`
* change: modify TODO on disabled Pylint check (#943). The check recommendations are only valid for packages that exclusively support Python 3. The changes cannot be made in Python 2. The TODO was updated to clarify this.
* prepare release v1.34.0
* update development version to v1.34.1.dev0
* change: add MXNet 1.4.1 support (#886)
* change: format and add missing docstring placeholders (#945). This commit will format all existing docstrings to follow Google style: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html. This commit will also add docstring placeholders to any class or method previously missing it. An ideal approach would be to take the time to include meaningful docstrings in every file. However, since that is not a task that will be prioritized, I've declared docstring bankruptcy on this package, in order to enforce docstrings on all future code changes to this package.
* change: allow serving script to be defined for deploy() and transformer() with frameworks (#944)
* change: update PyTorch version (#947)
* change: improve documentation of some functions (#864). [pr-827][followups] Improve documentation of some functions. Also some unit test fixes. See comments from marcio in #827.
* doc: update using_tensorflow topic (#946)
* fix: update TensorFlow script mode dependency list (#869)
* change: improving Chainer integ tests (#872)
* change: enable line-too-long Pylint check (#948)
* doc: add instructions for setting up Cloud9 environment. (#949). Added instructions that allow for a low-cost ~10min environment setup.
* prepare release v1.34.1
* update development version to v1.34.2.dev0
* change: Replaced generic ValueError with custom subclass when reporting unexpected resource status (#919)
* doc: correct wording for Cloud9 environment setup instructions (#952). package => repo
* change: removing unnecessary test cases (#951)
* prepare release v1.34.2
* update development version to v1.34.3.dev0
* change: waiting for training tags to propagate in the test (#955)
* prepare release v1.34.3
* update development version to v1.34.4.dev0
* feature: allow serving image to be specified when calling MXNet.deploy (#959)
* prepare release v1.35.0
* update development version to v1.35.1.dev0
* doc: refactor and edit using_mxnet topic (#956)
* doc: refactor overview section per improvement plan
* Update doc/overview.rst. Co-Authored-By: Marcio Vinicius dos Santos <[email protected]>
* Update doc/overview.rst. Co-Authored-By: Marcio Vinicius dos Santos <[email protected]>
* doc: made changes per feedback comments
* doc: remove duplicate faq section and fixed heading
* doc: fix heading levels in overview.rst
* doc: update TensorFlow using topic
* doc: Update using_tf.rst to address feedback
* doc: fix comment in conf.py per build log
* doc: add newline to conf.py to fix error
* doc: addressed feedback for PR
* doc: update conf.py
* doc: remove duplicate byom section in overview.rst
* doc: remove duplicate headings in several rst files
* doc: Restructure and update Using MXNet topic
* doc: fix link
* doc: add link to mxnet readme container section in using_mxnet.rst topic
* fix: update sklearn document to include 3p dependency installation (#960)
* prepare release v1.35.1
* update development version to v1.35.2.dev0
* fix: allow Airflow enabled estimators to use absolute path entry_point (#965)
* change: ignore FI18 flake8 rule (#969)
* feature: support for TensorFlow 1.14 (#967)
* flake8 fixes
* black reformat
* Revert "Merging master branch in to eureka-master (#206)". This reverts commit 080d06d561aa88a177c67f08114902ab292f3883.
* Black + Pylint fixes
* add latest api service model
* skip eureka integ tests temporarily
* Fix integ tests to work with preview SDK model (#215)
* Fix integ tests to work with preview SDK model.
* Use search to find trial components for analytics dataframe (#219)
* move boto client arg to end of the arg list for all eureka APIs
* Use search to find trial components in TrialAnalytics
* add test to verify value error is thrown if no component filter specified
* drop trial name from analytics frame as trial components won't have trial name in them in the future
* remove trial name column for analytics frame
* Eureka master (#236)
* Add ExperimentConfig for estimator.fit and transformer.transform
* experiment_config can be passed to estimator.fit
* experiment_config can be passed to transformer.transform
* unit tests for corresponding changes.
* Remove include only experiment integ tests from tox.ini
* Delete experiments integ tests
* Update the service-2.json.
* Bring in latest sagemaker models
* Remove internal-only shapes and internal operations
* Doc the three optional keys for ExperimentConfig dictionary
* Eureka master (#237)
* Add ExperimentConfig for estimator.fit and transformer.transform
* experiment_config can be passed to estimator.fit
* experiment_config can be passed to transformer.transform
* unit tests for corresponding changes.
* Remove include only experiment integ tests from tox.ini
* Delete experiments integ tests
* Update the service-2.json.
* Bring in latest sagemaker models
* Remove internal-only shapes and internal operations
* Doc the three optional keys for ExperimentConfig dictionary
* Fix analytics component and search functionality
* Delete all experiments related classes and their tests.
* Change TrialAnalytics to ExperimentAnalytics.
* Fix ExperimentAnalytics for m-n model change.
* Fix/Modify Search functionality
* Fix Docs
* Remove exp management doc from index
* Fix pass None type ExperimentConfig to transform request.
* Fix formatting in test_session.py
* Do not build empty filters list when experiment name is not given
* Add DisplayName to analytics table
* Fix formatting.
* Add sortBy and sortOrder for ExperimentAnalytics
* Eureka master (#259). Fix bad merge.
* Add ExperimentConfig to Processor (#260)
* Add ExperimentConfig to Processor
* Remove broken experiment config from processor test (#261)
* Add ExperimentConfig to Processor
* Eureka master (#262)
* Remove old setup file and Eureka specific files.
* Eureka master (#264)
* Add back missing factorization machines integration test
* Minor style fixes (#265)
* Minor style fixes
* Fix broken SageMaker Experiments analytics integration tests (#267)
* Fix broken experiments_analytics integration tests
* Eureka master (#270)
* Remove experiment_config from analytics test
1 parent 7d2aae8 commit 8b5a3c0

25 files changed: +942 / -29 lines changed

doc/analytics.rst (+5)

@@ -15,3 +15,8 @@ Analytics
     :members:
     :undoc-members:
     :show-inheritance:
+
+.. autoclass:: sagemaker.analytics.ExperimentAnalytics
+    :members:
+    :undoc-members:
+    :show-inheritance:

doc/index.rst (+2)

@@ -175,3 +175,5 @@ SageMaker APIs to export configurations for creating and managing Airflow workfl
     :maxdepth: 2
 
     sagemaker.workflow.airflow
+
+

src/sagemaker/amazon/amazon_estimator.py (+16 -2)

@@ -178,7 +178,15 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
         self.feature_dim = feature_dim
         self.mini_batch_size = mini_batch_size
 
-    def fit(self, records, mini_batch_size=None, wait=True, logs=True, job_name=None):
+    def fit(
+        self,
+        records,
+        mini_batch_size=None,
+        wait=True,
+        logs=True,
+        job_name=None,
+        experiment_config=None,
+    ):
         """Fit this Estimator on serialized Record objects, stored in S3.
 
         ``records`` should be an instance of :class:`~RecordSet`. This
@@ -206,10 +214,16 @@ def fit(self, records, mini_batch_size=None, wait=True, logs=True, job_name=None
             job_name (str): Training job name. If not specified, the estimator
                 generates a default job name, based on the training image name
                 and current timestamp.
+            experiment_config (dict[str, str]): Experiment management configuration.
+                Dictionary contains three optional keys, 'ExperimentName',
+                'TrialName', and 'TrialComponentName'
+                (default: ``None``).
         """
         self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
 
-        self.latest_training_job = _TrainingJob.start_new(self, records)
+        self.latest_training_job = _TrainingJob.start_new(
+            self, records, experiment_config=experiment_config
+        )
         if wait:
             self.latest_training_job.wait(logs=logs)
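The new ``experiment_config`` argument is a plain dict whose keys the commit documents as optional. The snippet below is an illustrative sketch, not SDK code: ``ALLOWED_KEYS`` and ``validate_experiment_config`` are hypothetical helpers a caller might use to guard the dict before passing it to ``fit`` (the key names follow the estimator.py docstring in this commit).

```python
# Hypothetical helper, not part of the SageMaker SDK: checks that an
# experiment_config dict only uses the keys documented in this commit.
ALLOWED_KEYS = {"ExperimentName", "TrialName", "TrialComponentDisplayName"}


def validate_experiment_config(experiment_config):
    """Return the config unchanged if it is None or uses only documented keys."""
    if experiment_config is None:
        return None
    unknown = set(experiment_config) - ALLOWED_KEYS
    if unknown:
        raise ValueError("Unexpected experiment_config keys: %s" % sorted(unknown))
    return experiment_config


# Example of the dict shape fit() now accepts; all keys are optional.
config = {"ExperimentName": "my-experiment", "TrialName": "trial-1"}
```

A caller would then pass ``validate_experiment_config(config)`` as the ``experiment_config`` argument to ``fit``.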

src/sagemaker/analytics.py (+196 -1)

@@ -14,7 +14,7 @@
 from __future__ import print_function, absolute_import
 
 from abc import ABCMeta, abstractmethod
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 import datetime
 import logging
 
@@ -23,6 +23,7 @@
 from sagemaker.session import Session
 from sagemaker.utils import DeferredError
 
+
 try:
     import pandas as pd
 except ImportError as e:
@@ -413,3 +414,197 @@ def _metric_names_for_training_job(self):
         metric_names = [md["Name"] for md in metric_definitions]
 
         return metric_names
+
+
+class ExperimentAnalytics(AnalyticsMetricsBase):
+    """Fetch trial component data and make them accessible for analytics."""
+
+    MAX_TRIAL_COMPONENTS = 10000
+
+    def __init__(
+        self,
+        experiment_name=None,
+        search_expression=None,
+        sort_by=None,
+        sort_order=None,
+        metric_names=None,
+        parameter_names=None,
+        sagemaker_session=None,
+    ):
+        """Initialize a ``ExperimentAnalytics`` instance.
+
+        Args:
+            experiment_name (str, optional): Name of the experiment if you want to constrain the
+                search to only trial components belonging to an experiment.
+            search_expression (dict, optional): The search query to find the set of trial components
+                to use to populate the data frame.
+            sort_by (str, optional): The name of the resource property used to sort
+                the set of trial components.
+            sort_order (str, optional): How trial components are ordered, valid values are Ascending
+                and Descending. The default is Descending.
+            metric_names (list, optional): string names of all the metrics to be shown in the
+                data frame. If not specified, all metrics will be shown of all trials.
+            parameter_names (list, optional): string names of the parameters to be shown in the
+                data frame. If not specified, all parameters will be shown of all trials.
+            sagemaker_session (sagemaker.session.Session): Session object which manages interactions
+                with Amazon SageMaker APIs and any other AWS services needed. If not specified,
+                one is created using the default AWS configuration chain.
+        """
+        sagemaker_session = sagemaker_session or Session()
+        self._sage_client = sagemaker_session.sagemaker_client
+
+        if not experiment_name and not search_expression:
+            raise ValueError("Either experiment_name or search_expression must be supplied.")
+
+        self._experiment_name = experiment_name
+        self._search_expression = search_expression
+        self._sort_by = sort_by
+        self._sort_order = sort_order
+        self._metric_names = metric_names
+        self._parameter_names = parameter_names
+        self._trial_components = None
+        super(ExperimentAnalytics, self).__init__()
+        self.clear_cache()
+
+    @property
+    def name(self):
+        """Name of the Experiment being analyzed."""
+        return self._experiment_name
+
+    def __repr__(self):
+        return "<sagemaker.ExperimentAnalytics for %s>" % self.name
+
+    def clear_cache(self):
+        """Clear the object of all local caches of API methods."""
+        super(ExperimentAnalytics, self).clear_cache()
+        self._trial_components = None
+
+    def _reshape_parameters(self, parameters):
+        """Reshape trial component parameters to a pandas column.
+
+        Args:
+            parameters: trial component parameters
+
+        Returns:
+            dict: Key: Parameter name, Value: Parameter value
+        """
+        out = OrderedDict()
+        for name, value in sorted(parameters.items()):
+            if self._parameter_names and name not in self._parameter_names:
+                continue
+            out[name] = value.get("NumberValue", value.get("StringValue"))
+        return out
+
+    def _reshape_metrics(self, metrics):
+        """Reshape trial component metrics to a pandas column.
+
+        Args:
+            metrics: trial component metrics
+
+        Returns:
+            dict: Key: Metric name, Value: Metric value
+        """
+        statistic_types = ["Min", "Max", "Avg", "StdDev", "Last", "Count"]
+        out = OrderedDict()
+        for metric_summary in metrics:
+            metric_name = metric_summary["MetricName"]
+            if self._metric_names and metric_name not in self._metric_names:
+                continue
+
+            for stat_type in statistic_types:
+                stat_value = metric_summary.get(stat_type)
+                if stat_value is not None:
+                    out["{} - {}".format(metric_name, stat_type)] = stat_value
+        return out
+
+    def _reshape(self, trial_component):
+        """Reshape trial component data to pandas columns.
+
+        Args:
+            trial_component: dict representing a trial component
+
+        Returns:
+            dict: Key-Value pair representing the data in the pandas dataframe
+        """
+        out = OrderedDict()
+        for attribute in ["TrialComponentName", "DisplayName"]:
+            out[attribute] = trial_component.get(attribute, "")
+
+        source = trial_component.get("Source", "")
+        if source:
+            out["SourceArn"] = source["SourceArn"]
+
+        out.update(self._reshape_parameters(trial_component.get("Parameters", [])))
+        out.update(self._reshape_metrics(trial_component.get("Metrics", [])))
+        return out
+
+    def _fetch_dataframe(self):
+        """Return a pandas dataframe with all the trial_components,
+        along with their parameters and metrics.
+        """
+        df = pd.DataFrame([self._reshape(component) for component in self._get_trial_components()])
+        return df
+
+    def _get_trial_components(self, force_refresh=False):
+        """Get all trial components matching the given search query expression.
+
+        Args:
+            force_refresh (bool): Set to True to fetch the latest data from SageMaker API.
+
+        Returns:
+            list: List of dicts representing the trial components
+        """
+        if force_refresh:
+            self.clear_cache()
+        if self._trial_components is not None:
+            return self._trial_components
+
+        if not self._search_expression:
+            self._search_expression = {}
+
+        if self._experiment_name:
+            if not self._search_expression.get("Filters"):
+                self._search_expression["Filters"] = []
+
+            self._search_expression["Filters"].append(
+                {
+                    "Name": "Parents.ExperimentName",
+                    "Operator": "Equals",
+                    "Value": self._experiment_name,
+                }
+            )
+
+        return self._search(self._search_expression, self._sort_by, self._sort_order)
+
+    def _search(self, search_expression, sort_by, sort_order):
+        """Perform a search query using SageMaker Search and return the matching trial components.
+
+        Args:
+            search_expression: Search expression to filter trial components.
+            sort_by: The name of the resource property used to sort the trial components.
+            sort_order: How trial components are ordered, valid values are Ascending
+                and Descending. The default is Descending.
+
+        Returns:
+            list: List of dict representing trial components.
+        """
+        trial_components = []
+
+        search_args = {
+            "Resource": "ExperimentTrialComponent",
+            "SearchExpression": search_expression,
+        }
+
+        if sort_by:
+            search_args["SortBy"] = sort_by
+
+        if sort_order:
+            search_args["SortOrder"] = sort_order
+
+        while len(trial_components) < self.MAX_TRIAL_COMPONENTS:
+            search_response = self._sage_client.search(**search_args)
+            components = [result["TrialComponent"] for result in search_response["Results"]]
+            trial_components.extend(components)
+            if "NextToken" in search_response and len(components) > 0:
+                search_args["NextToken"] = search_response["NextToken"]
+            else:
+                break
+
+        return trial_components
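The two mechanisms this diff adds are the NextToken pagination loop in ``_search`` and the column flattening in ``_reshape``. Below is a self-contained sketch of both, with a ``FakeSageMakerClient`` standing in for the boto3 SageMaker client (the fake client and its two-page payload are invented for illustration; the loop and flattening mirror the committed logic).

```python
from collections import OrderedDict

MAX_TRIAL_COMPONENTS = 10000


class FakeSageMakerClient:
    """Stub client returning two pages of ExperimentTrialComponent results."""

    def __init__(self):
        self._pages = [
            {
                "Results": [{"TrialComponent": {
                    "TrialComponentName": "tc-1",
                    "DisplayName": "train",
                    "Metrics": [{"MetricName": "loss", "Min": 0.1, "Last": 0.2}],
                }}],
                "NextToken": "page-2",
            },
            {
                "Results": [{"TrialComponent": {
                    "TrialComponentName": "tc-2",
                    "DisplayName": "eval",
                    "Metrics": [],
                }}],
            },
        ]

    def search(self, **kwargs):
        # Second page when a NextToken is supplied, first page otherwise.
        return self._pages[1] if kwargs.get("NextToken") else self._pages[0]


def search_trial_components(client, search_expression):
    """Page through Search results until NextToken runs out or the cap hits."""
    trial_components = []
    search_args = {
        "Resource": "ExperimentTrialComponent",
        "SearchExpression": search_expression,
    }
    while len(trial_components) < MAX_TRIAL_COMPONENTS:
        response = client.search(**search_args)
        components = [r["TrialComponent"] for r in response["Results"]]
        trial_components.extend(components)
        if "NextToken" in response and components:
            search_args["NextToken"] = response["NextToken"]
        else:
            break
    return trial_components


def reshape(component):
    """Flatten one trial component into a single row of dataframe columns."""
    out = OrderedDict()
    for attribute in ["TrialComponentName", "DisplayName"]:
        out[attribute] = component.get(attribute, "")
    # Each present statistic becomes a "<metric> - <stat>" column.
    for summary in component.get("Metrics", []):
        for stat in ["Min", "Max", "Avg", "StdDev", "Last", "Count"]:
            if summary.get(stat) is not None:
                out["{} - {}".format(summary["MetricName"], stat)] = summary[stat]
    return out
```

Running ``[reshape(c) for c in search_trial_components(FakeSageMakerClient(), {})]`` yields one row per trial component, which is exactly the list ``_fetch_dataframe`` hands to ``pd.DataFrame``.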

src/sagemaker/estimator.py (+12 -3)

@@ -384,7 +384,7 @@ def _prepare_collection_configs(self):
         if self.debugger_hook_config is not None:
             self.collection_configs.update(self.debugger_hook_config.collection_configs or [])
 
-    def fit(self, inputs=None, wait=True, logs="All", job_name=None):
+    def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None):
         """Train a model using the input training dataset.
 
         The API calls the Amazon SageMaker CreateTrainingJob API to start
@@ -418,10 +418,14 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None):
                 Only meaningful when wait is True.
             job_name (str): Training job name. If not specified, the estimator generates
                 a default job name, based on the training image name and current timestamp.
+            experiment_config (dict[str, str]): Experiment management configuration.
+                Dictionary contains three optional keys,
+                'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+
         """
         self._prepare_for_training(job_name=job_name)
 
-        self.latest_training_job = _TrainingJob.start_new(self, inputs)
+        self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
         self.jobs.append(self.latest_training_job)
         if wait:
             self.latest_training_job.wait(logs=logs)
@@ -896,14 +900,18 @@ class _TrainingJob(_Job):
     """Placeholder docstring"""
 
     @classmethod
-    def start_new(cls, estimator, inputs):
+    def start_new(cls, estimator, inputs, experiment_config):
         """Create a new Amazon SageMaker training job from the estimator.
 
         Args:
             estimator (sagemaker.estimator.EstimatorBase): Estimator object
                 created by the user.
            inputs (str): Parameters used when called
                :meth:`~sagemaker.estimator.EstimatorBase.fit`.
+           experiment_config (dict[str, str]): Experiment management configuration used when called
+               :meth:`~sagemaker.estimator.EstimatorBase.fit`. Dictionary contains
+               three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+
 
         Returns:
            sagemaker.estimator._TrainingJob: Constructed object that captures
@@ -931,6 +939,7 @@ def start_new(cls, estimator, inputs):
         train_args["hyperparameters"] = hyperparameters
         train_args["tags"] = estimator.tags
         train_args["metric_definitions"] = estimator.metric_definitions
+        train_args["experiment_config"] = experiment_config
 
         if isinstance(inputs, s3_input):
             if "InputMode" in inputs.config: