
Commit 2c3d606

knikure, makungaj1, Captainia, Jonathan Makunga, evakravi authored
feat: Benchmark feature initial commit (aws#1463)
* Sync Master benchmark feature (aws#1461) * feat: support config_name in all JumpStart interfaces (aws#4583) (aws#4607) * add-config-name * address comments * updates for set config * docstyle * updates * fix * format * format * remove tests * Add ReadOnly APIs (aws#4606) * Add ReadOnly APIs * Resolving PR review comments * Resolve PR review comments * Refactoring * Refactoring * Add Caching * Refactore * Resolving conflicts * Add Unit Tests * Fix Unit Tests * Fix unit tests * Fix UT * Refactoring * Fix Integ tests * refactoring after Notebook testing * Fix code styles --------- Co-authored-by: Jonathan Makunga <[email protected]> * feat: tag JumpStart resource with config names (aws#4608) * tag config name * format * resolving comments * format * format * update * fix * format * updates inference component config name * fix: tests * ModelBuilder: Add functionalities to get and set deployment config. (aws#4614) * Add funtionalities to get and set deployment config * Resolve PR comments * ModelBuilder-JS * Add Unit tests * Refactoring * Testing with Notebook * Test backward compatibility * Remove Accelerated column if all not enabled * Fix docstring * Resolved PR Review comments * Docstring * increase code coverage --------- Co-authored-by: Jonathan Makunga <[email protected]> * Benchmark feature v2 (aws#4618) * Add funtionalities to get and set deployment config * Resolve PR comments * ModelBuilder-JS * Add Unit tests * Refactoring * Testing with Notebook * Test backward compatibility * Remove Accelerated column if all not enabled * Fix docstring * Resolved PR Review comments * Docstring * increase code coverage * Testing fix with Notebook * Only fetch instance rate metrics if not present * Increase code coverage --------- Co-authored-by: Jonathan Makunga <[email protected]> * fix: populate default config name to model (aws#4617) * fix: populate default config name to model * update condition * fix * format * flake8 * fix tests * fix coverage * temporarily skip 
integ test vulnerbility * fix tolerate attach method * format * fix predictor * format * Fix fetch instance rate bug (aws#4624) Co-authored-by: Jonathan Makunga <[email protected]> * chore: require config name and instance type in set_deployment_config (aws#4625) * require config_name and instance_type in set config * docstring * add supported instance types check * add more tests * format * fix tests * Deployment Configs - Follow-ups (aws#4626) * Init Deployment configs outside Model init. * Testing with NB * Testing with NB-V2 * Refactoring, NB testing * NB Testing and Refactoring * Testing * Refactoring * Testing with NB * Debug * Debug display API * Debug with NB * Testing with NB * Refactoring * Refactoring * Refactoring and NB testing * Testing with NB * Refactoring * Prefix instance type with ml * Fix unit tests --------- Co-authored-by: Jonathan Makunga <[email protected]> * fix: use different separator to flatten dict (aws#4629) * Use separate tags for inference and training configs (aws#4635) * Use separate tags for inference and training * format * format * format * format * Add supported inference and incremental training configs (aws#4637) * supported inference configs * add tests * format * tests * tests * address comments * format and address comments * updates * formt * format * Benchmark feature fixes (aws#4632) * Filter down Benchmark Metrics * Filter down Benchmark Metrics * Testing NB * Testing MB * Testing * Refactoring * Unit tests * Display instance type first, and instance rate last * Display unbalanced metrics * Testing with NB * Testing with NB * Debug * Debug * Testing with NB * Testing with NB * Testing with NB * Refactoring * Refactoring * Refactoring * Unit tests * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Refactoring * Debug * Config ranking * Debug * Debug * Debug * Debug * Debug * Ranking * Ranking-Debug * Ranking-Debug * Ranking-Debug * 
Ranking-Debug * Ranking-Debug * Ranking-Debug * Debug * Debug * Debug * Debug * Refactoring * Contact JumpStart team to fix flaky test. test_list_jumpstart_models_script_filter --------- Co-authored-by: Jonathan Makunga <[email protected]> * fix: typo and merge with master branch (aws#4649) * Merge master into benchmark feature (aws#4652) * Merge master into master-benchmark-feature (aws#4656) * Master benchmark feature (aws#4658) * Master benchmark feature merge master (aws#4661) * Master benchmark feature (aws#4672) * fix: mainline alt config parsing (aws#4602) * fix: parsing * fix: commit tests * fix: types * updated * fix * Add Triton v24.03 URI (aws#4605) Co-authored-by: Nikhil Kulkarni <[email protected]> * feature: support session tag chaining for training job (aws#4596) * feature: support session tag chaining for training job * fix: resolve typo * fix: resolve typo and build failure * fix: resolve typo and unit test failure --------- Co-authored-by: Jessica Zhu <jessicazhu3@[email protected]> * prepare release v2.217.0 * update development version to v2.217.1.dev0 * fix: properly close files in lineage queries and tests (aws#4587) Closes aws#4458 * feature: set default allow_pickle param to False (aws#4557) * breaking: set default allow_pickle param to False * breaking: fix unit tests and linting NumpyDeserializer will not allow deserialization unless allow_pickle flag is set to True explicitly * fix: black-check --------- Co-authored-by: Ashwin Krishna <[email protected]> * Fix:invalid component error with new metadata (aws#4634) * fix: invalid component name * tests * format * fix vulnerable model integ tests llama 2 * updated * fix: training dataset location * prepare release v2.218.0 * update development version to v2.218.1.dev0 * chore: update skipped flaky tests (aws#4644) * Update skipped flaky tests * flake8 * format * format * chore: release tgi 2.0.1 (aws#4642) * chore: release tgi 2.0.1 * minor fix --------- Co-authored-by: Zhaoqi <[email 
protected]> * fix: Fix UserAgent logging in Python SDK (aws#4647) * prepare release v2.218.1 * update development version to v2.218.2.dev0 * feature: allow choosing js payload by alias in private method * Updates for SMP v2.3.1 (aws#4660) Co-authored-by: Suhit Kodgule <[email protected]> * chore(deps): bump jinja2 from 3.1.3 to 3.1.4 in /doc (aws#4655) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](pallets/jinja@3.1.3...3.1.4) --- updated-dependencies: - dependency-name: jinja2 dependency-type: direct:production ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump tqdm from 4.66.2 to 4.66.3 in /tests/data/serve_resources/mlflow/pytorch (aws#4650) Bumps [tqdm](https://github.com/tqdm/tqdm) from 4.66.2 to 4.66.3. - [Release notes](https://github.com/tqdm/tqdm/releases) - [Commits](tqdm/tqdm@v4.66.2...v4.66.3) --- updated-dependencies: - dependency-name: tqdm dependency-type: direct:production ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump jinja2 from 3.1.3 to 3.1.4 in /requirements/extras (aws#4654) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](pallets/jinja@3.1.3...3.1.4) --- updated-dependencies: - dependency-name: jinja2 dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * prepare release v2.219.0 * update development version to v2.219.1.dev0 * fix: skip flakey tests pending investigation (aws#4667) * change: update image_uri_configs 05-09-2024 07:17:41 PST * Add tensorflow_serving support for mlflow models and enable lineage tracking for mlflow models (aws#4662) * Initial commit for tensorflow_serving support of MLflow * Add integ tests for mlflow tf_serving * fix style issues * remove unused attributes from tf builder * Add deep ping for tf_serving local mode * Initial commit for lineage impl * Initial commit for tensorflow_serving support of MLflow * Add integ tests for mlflow tf_serving * fix style issues * remove unused attributes from tf builder * Add deep ping for tf_serving local mode * Add integ tests and uts * fix local mode for tf_serving * Allow lineage tracking only in sagemaker endpoint mode * fix regex pattern * fix style issues * fix regex pattern and hard coded py version in ut * fix missing session * Resolve pr comments and fix regex for mlflow registry and ids * fix: model builder race condition on sagemaker session (aws#4673) Co-authored-by: Jonathan Makunga <[email protected]> * feat: Add telemetry support for mlflow models (aws#4674) * Initial commit for telemetry support * Fix style issues and add more logger messages * fix value error messages in ut * feat: add new images for HF TGI release (aws#4677) * chore: add new images for HF TGI release * test * feature: AutoGluon 1.1.0 image_uris update (aws#4679) Co-authored-by: Ubuntu <[email protected]> * change: add debug logs to workflow container dist creation (aws#4682) * prepare release v2.220.0 * update development version to v2.220.1.dev0 * fix: Image URI should take precedence for HF models (aws#4684) * Fix: Image URI should take precedence for HF models * Fix formatting * Fix formatting * Fix formatting * Increase coverage 
- UT pass * feat: support config_name in all JumpStart interfaces (aws#4583) (aws#4607) * add-config-name * address comments * updates for set config * docstyle * updates * fix * format * format * remove tests * Add ReadOnly APIs (aws#4606) * Add ReadOnly APIs * Resolving PR review comments * Resolve PR review comments * Refactoring * Refactoring * Add Caching * Refactore * Resolving conflicts * Add Unit Tests * Fix Unit Tests * Fix unit tests * Fix UT * Refactoring * Fix Integ tests * refactoring after Notebook testing * Fix code styles --------- Co-authored-by: Jonathan Makunga <[email protected]> * feat: tag JumpStart resource with config names (aws#4608) * tag config name * format * resolving comments * format * format * update * fix * format * updates inference component config name * fix: tests * ModelBuilder: Add functionalities to get and set deployment config. (aws#4614) * Add funtionalities to get and set deployment config * Resolve PR comments * ModelBuilder-JS * Add Unit tests * Refactoring * Testing with Notebook * Test backward compatibility * Remove Accelerated column if all not enabled * Fix docstring * Resolved PR Review comments * Docstring * increase code coverage --------- Co-authored-by: Jonathan Makunga <[email protected]> * Benchmark feature v2 (aws#4618) * Add funtionalities to get and set deployment config * Resolve PR comments * ModelBuilder-JS * Add Unit tests * Refactoring * Testing with Notebook * Test backward compatibility * Remove Accelerated column if all not enabled * Fix docstring * Resolved PR Review comments * Docstring * increase code coverage * Testing fix with Notebook * Only fetch instance rate metrics if not present * Increase code coverage --------- Co-authored-by: Jonathan Makunga <[email protected]> * fix: populate default config name to model (aws#4617) * fix: populate default config name to model * update condition * fix * format * flake8 * fix tests * fix coverage * temporarily skip integ test vulnerbility * fix 
tolerate attach method * format * fix predictor * format * Fix fetch instance rate bug (aws#4624) Co-authored-by: Jonathan Makunga <[email protected]> * chore: require config name and instance type in set_deployment_config (aws#4625) * require config_name and instance_type in set config * docstring * add supported instance types check * add more tests * format * fix tests * Deployment Configs - Follow-ups (aws#4626) * Init Deployment configs outside Model init. * Testing with NB * Testing with NB-V2 * Refactoring, NB testing * NB Testing and Refactoring * Testing * Refactoring * Testing with NB * Debug * Debug display API * Debug with NB * Testing with NB * Refactoring * Refactoring * Refactoring and NB testing * Testing with NB * Refactoring * Prefix instance type with ml * Fix unit tests --------- Co-authored-by: Jonathan Makunga <[email protected]> * fix: use different separator to flatten dict (aws#4629) * Use separate tags for inference and training configs (aws#4635) * Use separate tags for inference and training * format * format * format * format * Add supported inference and incremental training configs (aws#4637) * supported inference configs * add tests * format * tests * tests * address comments * format and address comments * updates * formt * format * Benchmark feature fixes (aws#4632) * Filter down Benchmark Metrics * Filter down Benchmark Metrics * Testing NB * Testing MB * Testing * Refactoring * Unit tests * Display instance type first, and instance rate last * Display unbalanced metrics * Testing with NB * Testing with NB * Debug * Debug * Testing with NB * Testing with NB * Testing with NB * Refactoring * Refactoring * Refactoring * Unit tests * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Refactoring * Debug * Config ranking * Debug * Debug * Debug * Debug * Debug * Ranking * Ranking-Debug * Ranking-Debug * Ranking-Debug * Ranking-Debug * Ranking-Debug * 
Ranking-Debug * Debug * Debug * Debug * Debug * Refactoring * Contact JumpStart team to fix flaky test. test_list_jumpstart_models_script_filter --------- Co-authored-by: Jonathan Makunga <[email protected]> * fix: typo and merge with master branch (aws#4649) * Merge master into benchmark feature (aws#4652) * Merge master into master-benchmark-feature (aws#4656) * Master benchmark feature (aws#4658) * Remove duplicate line in types.py * Remove duplicate lines * Remove duplicate lines * Remove duplicate lines * Remove duplicate lines * fix unit test --------- Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: Haotian An <[email protected]> Co-authored-by: Nikhil Kulkarni <[email protected]> Co-authored-by: Nikhil Kulkarni <[email protected]> Co-authored-by: jessicazhu3 <[email protected]> Co-authored-by: Jessica Zhu <jessicazhu3@[email protected]> Co-authored-by: ci <ci> Co-authored-by: Justin <[email protected]> Co-authored-by: ASHWIN KRISHNA <[email protected]> Co-authored-by: Ashwin Krishna <[email protected]> Co-authored-by: Haixin Wang <[email protected]> Co-authored-by: Zhaoqi <[email protected]> Co-authored-by: Kalyani Nikure <[email protected]> Co-authored-by: Keerthan Vasist <[email protected]> Co-authored-by: SuhitK <[email protected]> Co-authored-by: Suhit Kodgule <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: sagemaker-bot <[email protected]> Co-authored-by: jiapinw <[email protected]> Co-authored-by: Jonathan Makunga <[email protected]> Co-authored-by: Jonathan Makunga <[email protected]> Co-authored-by: Prateek M Desai <[email protected]> Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Mufaddal Rohawala <[email protected]> Co-authored-by: Samrudhi Sharma <[email protected]> Co-authored-by: evakravi <[email protected]> * fix benchmark feature read-only apis (aws#4675) * Rearrange benchmark metric table * Refactoring * Refactoring * Refactoring * 
Refactoring * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Refactoring * Refactoring * Refactoring * Refactoring * Refactoring * Add Unit tests * Refactoring * Refactoring * hide index from DataFrame --------- Co-authored-by: Jonathan Makunga <[email protected]> * feat: update alt config to work with model packages (aws#4706) * feat: update alt config to work with model packages * format * remove env vars for model package * fix tests * Update: ReadOnly APIs (aws#4707) * Model data arn * Refactoring * Refactoring * acceleration_configs * Refactoring * UT * Add Filter * UT * Revert "UT" * UT * UT --------- Co-authored-by: Jonathan Makunga <[email protected]> * ModelBuilder to support display with filter. (aws#4712) Co-authored-by: Jonathan Makunga <[email protected]> * Sync branch (aws#4718) * fix: mainline alt config parsing (aws#4602) * fix: parsing * fix: commit tests * fix: types * updated * fix * Add Triton v24.03 URI (aws#4605) Co-authored-by: Nikhil Kulkarni <[email protected]> * feature: support session tag chaining for training job (aws#4596) * feature: support session tag chaining for training job * fix: resolve typo * fix: resolve typo and build failure * fix: resolve typo and unit test failure --------- Co-authored-by: Jessica Zhu <jessicazhu3@[email protected]> * prepare release v2.217.0 * update development version to v2.217.1.dev0 * fix: properly close files in lineage queries and tests (aws#4587) Closes aws#4458 * feature: set default allow_pickle param to False (aws#4557) * breaking: set default allow_pickle param to False * breaking: fix unit tests and linting NumpyDeserializer will not allow deserialization unless allow_pickle flag is set to True explicitly * fix: black-check --------- Co-authored-by: Ashwin Krishna <[email protected]> * Fix:invalid component error with new metadata (aws#4634) * fix: invalid component name * tests * format * fix vulnerable 
model integ tests llama 2 * updated * fix: training dataset location * prepare release v2.218.0 * update development version to v2.218.1.dev0 * chore: update skipped flaky tests (aws#4644) * Update skipped flaky tests * flake8 * format * format * chore: release tgi 2.0.1 (aws#4642) * chore: release tgi 2.0.1 * minor fix --------- Co-authored-by: Zhaoqi <[email protected]> * fix: Fix UserAgent logging in Python SDK (aws#4647) * prepare release v2.218.1 * update development version to v2.218.2.dev0 * feature: allow choosing js payload by alias in private method * Updates for SMP v2.3.1 (aws#4660) Co-authored-by: Suhit Kodgule <[email protected]> * chore(deps): bump jinja2 from 3.1.3 to 3.1.4 in /doc (aws#4655) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](pallets/jinja@3.1.3...3.1.4) --- updated-dependencies: - dependency-name: jinja2 dependency-type: direct:production ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump tqdm from 4.66.2 to 4.66.3 in /tests/data/serve_resources/mlflow/pytorch (aws#4650) Bumps [tqdm](https://github.com/tqdm/tqdm) from 4.66.2 to 4.66.3. - [Release notes](https://github.com/tqdm/tqdm/releases) - [Commits](tqdm/tqdm@v4.66.2...v4.66.3) --- updated-dependencies: - dependency-name: tqdm dependency-type: direct:production ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump jinja2 from 3.1.3 to 3.1.4 in /requirements/extras (aws#4654) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4. 
- [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](pallets/jinja@3.1.3...3.1.4) --- updated-dependencies: - dependency-name: jinja2 dependency-type: direct:production ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * prepare release v2.219.0 * update development version to v2.219.1.dev0 * fix: skip flakey tests pending investigation (aws#4667) * change: update image_uri_configs 05-09-2024 07:17:41 PST * Add tensorflow_serving support for mlflow models and enable lineage tracking for mlflow models (aws#4662) * Initial commit for tensorflow_serving support of MLflow * Add integ tests for mlflow tf_serving * fix style issues * remove unused attributes from tf builder * Add deep ping for tf_serving local mode * Initial commit for lineage impl * Initial commit for tensorflow_serving support of MLflow * Add integ tests for mlflow tf_serving * fix style issues * remove unused attributes from tf builder * Add deep ping for tf_serving local mode * Add integ tests and uts * fix local mode for tf_serving * Allow lineage tracking only in sagemaker endpoint mode * fix regex pattern * fix style issues * fix regex pattern and hard coded py version in ut * fix missing session * Resolve pr comments and fix regex for mlflow registry and ids * fix: model builder race condition on sagemaker session (aws#4673) Co-authored-by: Jonathan Makunga <[email protected]> * feat: Add telemetry support for mlflow models (aws#4674) * Initial commit for telemetry support * Fix style issues and add more logger messages * fix value error messages in ut * feat: add new images for HF TGI release (aws#4677) * chore: add new images for HF TGI release * test * feature: AutoGluon 1.1.0 image_uris update (aws#4679) Co-authored-by: Ubuntu <[email protected]> * change: add debug logs to workflow container dist creation (aws#4682) 
* prepare release v2.220.0 * update development version to v2.220.1.dev0 * fix: Image URI should take precedence for HF models (aws#4684) * Fix: Image URI should take precedence for HF models * Fix formatting * Fix formatting * Fix formatting * Increase coverage - UT pass * feat: onboard tei image config to pysdk (aws#4681) * feat: onboard tei image config to pysdk * fix formatting issue * minor fix func name * fix unit tests --------- Co-authored-by: Mufaddal Rohawala <[email protected]> * fix: model builder limited container support for endpoint mode. (aws#4683) * Allow ModelBuilder's endpoint mode for Jumpstart models packaged with containers other than TGI and DJL * increase coverage * Add JS Support for MMS Serving * Add JS Support for MMS Serving * Unit tests * Refactoring * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga <[email protected]> * change: Add more debuging (aws#4687) * change: cover tei with image_uris.retrieve API (aws#4689) * fix: JS Model with non-TGI/non-DJL deployment failure (aws#4688) * Debug * Debug * Debug * Debug * Debug * Debug * fix docstyle * Refactoring * Add Integ tests --------- Co-authored-by: Jonathan Makunga <[email protected]> * Feat: Pull latest tei container for sentence similiarity models on HuggingFace hub (aws#4686) * Update: Pull latest tei container for sentence similiarity models * Fix formatting * Address PR comments * Fix formatting * Fix check * Switch sentence similarity to be deployed on tgi * Fix formatting * Fix formatting * Fix formatting * Fix formatting * Introduce TEI builder with TGI server * Fix formmatting * Add integ test * Fix formatting * Add integ test * Add integ test * Add integ test * Add integ test * Add integ test * Fix formatting * Move to G5 for integ test * Fix formatting * Integ test updates * Integ test updates * Integ test updates * Fix formatting * Integ test updates * Move back to generate for ping * Integ test updates * Integ test updates * Fix: Add Image URI 
overrides for transformers models (aws#4693) * Fix: Add Image URI overrides for transformers models * Increase coverage * Fix formatting * prepare release v2.221.0 * update development version to v2.221.1.dev0 * Add tei cpu image (aws#4695) * Add tei cpu image * fix format issue * fix unit tests * fix typo * fix typo * Feat: Add TEI support for ModelBuilder (aws#4694) * Add TEI Serving * Add TEI Serving * Add TEI Serving * Add TEI Serving * Add TEI Serving * Add TEI Serving * Notebook testing * Notebook testing * Notebook testing * Refactoring * Refactoring * UT * UT * Refactoring * Test coverage * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga <[email protected]> * Convert pytorchddp distribution to smdistributed distribution (aws#4698) * rewrite pytorchddp to smdistributed * remove instance type check * Update estimator.py * remove validate_pytorch_distribution * fix * fix unit tests * fix formatting * check instance type not None * prepare release v2.221.1 * update development version to v2.221.2.dev0 * Update: SM Endpoint Routing Strategy Support. (aws#4702) * RoutingConfig * Refactoring * Docstring * UT * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga <[email protected]> * change: update image_uri_configs 05-29-2024 07:17:35 PST * Making project name in workflow files dynamic (aws#4708) * fix: Fix ci unit-tests (aws#4713) * chore(deps): bump requests from 2.31.0 to 2.32.2 in /tests/data/serve_resources/mlflow/pytorch (aws#4709) Bumps [requests](https://github.com/psf/requests) from 2.31.0 to 2.32.2. - [Release notes](https://github.com/psf/requests/releases) - [Changelog](https://github.com/psf/requests/blob/main/HISTORY.md) - [Commits](psf/requests@v2.31.0...v2.32.2) --- updated-dependencies: - dependency-name: requests dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump apache-airflow from 2.9.0 to 2.9.1 in /requirements/extras (aws#4703) * chore(deps): bump apache-airflow in /requirements/extras Bumps [apache-airflow](https://github.com/apache/airflow) from 2.9.0 to 2.9.1. - [Release notes](https://github.com/apache/airflow/releases) - [Changelog](https://github.com/apache/airflow/blob/main/RELEASE_NOTES.rst) - [Commits](apache/airflow@2.9.0...2.9.1) --- updated-dependencies: - dependency-name: apache-airflow dependency-type: direct:production ... Signed-off-by: dependabot[bot] <[email protected]> * Update tox.ini to bump apache-airflow --------- Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Kalyani Nikure <[email protected]> * chore(deps): bump mlflow from 2.10.2 to 2.12.1 in /tests/data/serve_resources/mlflow/pytorch (aws#4690) Bumps [mlflow](https://github.com/mlflow/mlflow) from 2.10.2 to 2.12.1. - [Release notes](https://github.com/mlflow/mlflow/releases) - [Changelog](https://github.com/mlflow/mlflow/blob/master/CHANGELOG.md) - [Commits](mlflow/mlflow@v2.10.2...v2.12.1) --- updated-dependencies: - dependency-name: mlflow dependency-type: direct:production ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump mlflow from 2.11.1 to 2.12.1 in /tests/data/serve_resources/mlflow/xgboost (aws#4692) Bumps [mlflow](https://github.com/mlflow/mlflow) from 2.11.1 to 2.12.1. - [Release notes](https://github.com/mlflow/mlflow/releases) - [Changelog](https://github.com/mlflow/mlflow/blob/master/CHANGELOG.md) - [Commits](mlflow/mlflow@v2.11.1...v2.12.1) --- updated-dependencies: - dependency-name: mlflow dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump mlflow from 2.11.1 to 2.12.1 in /tests/data/serve_resources/mlflow/tensorflow (aws#4691) Bumps [mlflow](https://github.com/mlflow/mlflow) from 2.11.1 to 2.12.1. - [Release notes](https://github.com/mlflow/mlflow/releases) - [Changelog](https://github.com/mlflow/mlflow/blob/master/CHANGELOG.md) - [Commits](mlflow/mlflow@v2.11.1...v2.12.1) --- updated-dependencies: - dependency-name: mlflow dependency-type: direct:production ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * change: Updates for DJL 0.28.0 release (aws#4701) * Sync Branch --------- Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: Haotian An <[email protected]> Co-authored-by: Nikhil Kulkarni <[email protected]> Co-authored-by: Nikhil Kulkarni <[email protected]> Co-authored-by: jessicazhu3 <[email protected]> Co-authored-by: Jessica Zhu <jessicazhu3@[email protected]> Co-authored-by: ci <ci> Co-authored-by: Justin <[email protected]> Co-authored-by: ASHWIN KRISHNA <[email protected]> Co-authored-by: Ashwin Krishna <[email protected]> Co-authored-by: Haixin Wang <[email protected]> Co-authored-by: Zhaoqi <[email protected]> Co-authored-by: Kalyani Nikure <[email protected]> Co-authored-by: Keerthan Vasist <[email protected]> Co-authored-by: SuhitK <[email protected]> Co-authored-by: Suhit Kodgule <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: sagemaker-bot <[email protected]> Co-authored-by: jiapinw <[email protected]> Co-authored-by: Jonathan Makunga <[email protected]> Co-authored-by: Prateek M Desai <[email protected]> Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Mufaddal Rohawala <[email protected]> Co-authored-by: Samrudhi Sharma <[email protected]> 
Co-authored-by: Tom Bousso <[email protected]> Co-authored-by: Zhaoqi <[email protected]> Co-authored-by: Tyler Osterberg <[email protected]> * Merge --------- Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: Haotian An <[email protected]> Co-authored-by: Jonathan Makunga <[email protected]> Co-authored-by: evakravi <[email protected]> Co-authored-by: Erick Benitez-Ramos <[email protected]> Co-authored-by: Nikhil Kulkarni <[email protected]> Co-authored-by: Nikhil Kulkarni <[email protected]> Co-authored-by: jessicazhu3 <[email protected]> Co-authored-by: Jessica Zhu <jessicazhu3@[email protected]> Co-authored-by: Justin <[email protected]> Co-authored-by: ASHWIN KRISHNA <[email protected]> Co-authored-by: Ashwin Krishna <[email protected]> Co-authored-by: Haixin Wang <[email protected]> Co-authored-by: Zhaoqi <[email protected]> Co-authored-by: Kalyani Nikure <[email protected]> Co-authored-by: Keerthan Vasist <[email protected]> Co-authored-by: SuhitK <[email protected]> Co-authored-by: Suhit Kodgule <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: sagemaker-bot <[email protected]> Co-authored-by: jiapinw <[email protected]> Co-authored-by: Prateek M Desai <[email protected]> Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Mufaddal Rohawala <[email protected]> Co-authored-by: Samrudhi Sharma <[email protected]> Co-authored-by: Tom Bousso <[email protected]> Co-authored-by: Zhaoqi <[email protected]> Co-authored-by: Tyler Osterberg <[email protected]> * Fix UT (aws#1465) Co-authored-by: Jonathan Makunga <[email protected]> --------- Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: Jonathan Makunga <[email protected]> Co-authored-by: Haotian An <[email protected]> Co-authored-by: Jonathan Makunga <[email protected]> Co-authored-by: evakravi <[email protected]> Co-authored-by: Erick Benitez-Ramos <[email protected]> Co-authored-by: Nikhil Kulkarni 
<[email protected]> Co-authored-by: Nikhil Kulkarni <[email protected]> Co-authored-by: jessicazhu3 <[email protected]> Co-authored-by: Jessica Zhu <jessicazhu3@[email protected]> Co-authored-by: Justin <[email protected]> Co-authored-by: ASHWIN KRISHNA <[email protected]> Co-authored-by: Ashwin Krishna <[email protected]> Co-authored-by: Haixin Wang <[email protected]> Co-authored-by: Zhaoqi <[email protected]> Co-authored-by: Keerthan Vasist <[email protected]> Co-authored-by: SuhitK <[email protected]> Co-authored-by: Suhit Kodgule <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: sagemaker-bot <[email protected]> Co-authored-by: jiapinw <[email protected]> Co-authored-by: Prateek M Desai <[email protected]> Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Mufaddal Rohawala <[email protected]> Co-authored-by: Samrudhi Sharma <[email protected]> Co-authored-by: Tom Bousso <[email protected]> Co-authored-by: Zhaoqi <[email protected]> Co-authored-by: Tyler Osterberg <[email protected]>
1 parent ed43b07 commit 2c3d606


53 files changed: +3900 -423 lines changed

src/sagemaker/accept_types.py

+3
@@ -77,6 +77,7 @@ def retrieve_default(
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+    config_name: Optional[str] = None,
 ) -> str:
     """Retrieves the default accept type for the model matching the given arguments.

@@ -98,6 +99,7 @@ def retrieve_default(
             object, used for SageMaker interactions. If not
             specified, one is created using the default AWS configuration
             chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         str: The default accept type to use for the model.

@@ -117,4 +119,5 @@ def retrieve_default(
         tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
         model_type=model_type,
+        config_name=config_name,
     )
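
The hunks above add an optional `config_name` keyword to the public `retrieve_default` and forward it unchanged to the underlying JumpStart helper. A minimal, self-contained sketch of that pass-through pattern follows. The stub helper, the lookup table, and the config name `"example-config"` are hypothetical stand-ins for illustration; they are not the SDK's implementation, and the diff does not show how the SDK resolves a config.

```python
from typing import Dict, Optional

# Hypothetical lookup table standing in for JumpStart model specs.
# The key None represents "no config requested".
_EXAMPLE_ACCEPT_TYPES: Dict[Optional[str], str] = {
    None: "application/json",
    "example-config": "text/csv",
}


def _retrieve_default_accept_type_stub(
    model_id: str,
    config_name: Optional[str] = None,
) -> str:
    """Hypothetical stand-in for the private artifact helper."""
    return _EXAMPLE_ACCEPT_TYPES[config_name]


def retrieve_default(
    model_id: str,
    config_name: Optional[str] = None,
) -> str:
    """Public entry point: forwards config_name to the helper unchanged."""
    return _retrieve_default_accept_type_stub(model_id, config_name=config_name)
```

Because the new keyword defaults to `None`, a caller that never passes `config_name` takes the same path as before the change in this sketch.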

src/sagemaker/content_types.py

+3
@@ -77,6 +77,7 @@ def retrieve_default(
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+    config_name: Optional[str] = None,
 ) -> str:
     """Retrieves the default content type for the model matching the given arguments.

@@ -98,6 +99,7 @@ def retrieve_default(
             object, used for SageMaker interactions. If not
             specified, one is created using the default AWS configuration
             chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         str: The default content type to use for the model.

@@ -117,6 +119,7 @@ def retrieve_default(
         tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
         model_type=model_type,
+        config_name=config_name,
     )


src/sagemaker/deserializers.py

+3
@@ -97,6 +97,7 @@ def retrieve_default(
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+    config_name: Optional[str] = None,
 ) -> BaseDeserializer:
     """Retrieves the default deserializer for the model matching the given arguments.

@@ -118,6 +119,7 @@ def retrieve_default(
             object, used for SageMaker interactions. If not
             specified, one is created using the default AWS configuration
             chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         BaseDeserializer: The default deserializer to use for the model.

@@ -138,4 +140,5 @@ def retrieve_default(
         tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
         model_type=model_type,
+        config_name=config_name,
     )

src/sagemaker/environment_variables.py

+3
@@ -36,6 +36,7 @@ def retrieve_default(
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     instance_type: Optional[str] = None,
     script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
+    config_name: Optional[str] = None,
 ) -> Dict[str, str]:
     """Retrieves the default container environment variables for the model matching the arguments.

@@ -65,6 +66,7 @@ def retrieve_default(
             variables specific for the instance type.
         script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
             variables.
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         dict: The variables to use for the model.

@@ -87,4 +89,5 @@ def retrieve_default(
         sagemaker_session=sagemaker_session,
         instance_type=instance_type,
         script=script,
+        config_name=config_name,
     )

src/sagemaker/hyperparameters.py

+3
@@ -36,6 +36,7 @@ def retrieve_default(
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+    config_name: Optional[str] = None,
 ) -> Dict[str, str]:
     """Retrieves the default training hyperparameters for the model matching the given arguments.

@@ -66,6 +67,7 @@ def retrieve_default(
             object, used for SageMaker interactions. If not
             specified, one is created using the default AWS configuration
             chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         dict: The hyperparameters to use for the model.

@@ -86,6 +88,7 @@ def retrieve_default(
         tolerate_vulnerable_model=tolerate_vulnerable_model,
         tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
+        config_name=config_name,
     )


src/sagemaker/image_uris.py

+3
@@ -70,6 +70,7 @@ def retrieve(
     inference_tool=None,
     serverless_inference_config=None,
     sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+    config_name=None,
 ) -> str:
     """Retrieves the ECR URI for the Docker image matching the given arguments.

@@ -123,6 +124,7 @@ def retrieve(
             object, used for SageMaker interactions. If not
             specified, one is created using the default AWS configuration
             chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

     Returns:
         str: The ECR URI for the corresponding SageMaker Docker image.
@@ -162,6 +164,7 @@ def retrieve(
             tolerate_vulnerable_model,
             tolerate_deprecated_model,
             sagemaker_session=sagemaker_session,
+            config_name=config_name,
         )

     if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):

src/sagemaker/instance_types.py

+3
@@ -36,6 +36,7 @@ def retrieve_default(
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     training_instance_type: Optional[str] = None,
     model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+    config_name: Optional[str] = None,
 ) -> str:
     """Retrieves the default instance type for the model matching the given arguments.

@@ -64,6 +65,7 @@ def retrieve_default(
             Optionally supply this to get a inference instance type conditioned
             on the training instance, to ensure compatability of training artifact to inference
             instance. (Default: None).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         str: The default instance type to use for the model.

@@ -88,6 +90,7 @@ def retrieve_default(
         sagemaker_session=sagemaker_session,
         training_instance_type=training_instance_type,
         model_type=model_type,
+        config_name=config_name,
     )


src/sagemaker/jumpstart/artifacts/environment_variables.py

+7
@@ -39,6 +39,7 @@ def _retrieve_default_environment_variables(
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     instance_type: Optional[str] = None,
     script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
+    config_name: Optional[str] = None,
 ) -> Dict[str, str]:
     """Retrieves the inference environment variables for the model matching the given arguments.

@@ -68,6 +69,7 @@ def _retrieve_default_environment_variables(
             environment variables specific for the instance type.
         script (JumpStartScriptScope): The JumpStart script for which to retrieve
             environment variables.
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         dict: the inference environment variables to use for the model.
     """
@@ -84,6 +86,7 @@ def _retrieve_default_environment_variables(
         tolerate_vulnerable_model=tolerate_vulnerable_model,
         tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
+        config_name=config_name,
     )

     default_environment_variables: Dict[str, str] = {}
@@ -121,6 +124,7 @@ def _retrieve_default_environment_variables(
                 tolerate_deprecated_model=tolerate_deprecated_model,
                 sagemaker_session=sagemaker_session,
                 instance_type=instance_type,
+                config_name=config_name,
             )
         )

@@ -167,6 +171,7 @@ def _retrieve_gated_model_uri_env_var_value(
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     instance_type: Optional[str] = None,
+    config_name: Optional[str] = None,
 ) -> Optional[str]:
     """Retrieves the gated model env var URI matching the given arguments.

@@ -190,6 +195,7 @@ def _retrieve_gated_model_uri_env_var_value(
             chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
         instance_type (str): An instance type to optionally supply in order to get
             environment variables specific for the instance type.
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

     Returns:
         Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
@@ -211,6 +217,7 @@ def _retrieve_gated_model_uri_env_var_value(
         tolerate_vulnerable_model=tolerate_vulnerable_model,
         tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
+        config_name=config_name,
     )

     s3_key: Optional[str] = (
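
Since this file forwards `config_name` at several call sites, a unit test that checks the forwarding could look roughly like the sketch below. It uses only `unittest.mock` from the standard library. The functions `_lookup_specs` and `retrieve_env` are hypothetical stand-ins written for this example; they are not the SDK functions touched in this commit.

```python
from typing import Dict, Optional
from unittest import mock


def _lookup_specs(model_id: str, config_name: Optional[str] = None) -> Dict[str, str]:
    """Hypothetical stand-in for a spec lookup; not an SDK function."""
    return {"MODEL_ID": model_id, "CONFIG": config_name or "default"}


def retrieve_env(model_id: str, config_name: Optional[str] = None) -> Dict[str, str]:
    """Hypothetical wrapper that forwards config_name unchanged."""
    return _lookup_specs(model_id, config_name=config_name)


def check_forwarding() -> bool:
    """Replace the helper with a spy and verify the keyword is passed through."""
    spy = mock.Mock(return_value={})
    # patch.dict swaps the module-level name for the duration of the block
    # and restores the original helper afterwards.
    with mock.patch.dict(globals(), {"_lookup_specs": spy}):
        retrieve_env("example-model", config_name="example-config")
    spy.assert_called_once_with("example-model", config_name="example-config")
    return True
```

Asserting on the exact keyword catches the easy mistake of adding the parameter to a signature but forgetting to pass it on at one of the call sites.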

src/sagemaker/jumpstart/artifacts/hyperparameters.py

+3
@@ -36,6 +36,7 @@ def _retrieve_default_hyperparameters(
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     instance_type: Optional[str] = None,
+    config_name: Optional[str] = None,
 ):
     """Retrieves the training hyperparameters for the model matching the given arguments.

@@ -66,6 +67,7 @@ def _retrieve_default_hyperparameters(
             chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
         instance_type (str): An instance type to optionally supply in order to get hyperparameters
             specific for the instance type.
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         dict: the hyperparameters to use for the model.
     """
@@ -82,6 +84,7 @@ def _retrieve_default_hyperparameters(
         tolerate_vulnerable_model=tolerate_vulnerable_model,
         tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
+        config_name=config_name,
     )

     default_hyperparameters: Dict[str, str] = {}

src/sagemaker/jumpstart/artifacts/image_uris.py

+4
@@ -46,6 +46,7 @@ def _retrieve_image_uri(
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+    config_name: Optional[str] = None,
 ):
     """Retrieves the container image URI for JumpStart models.

@@ -95,6 +96,7 @@ def _retrieve_image_uri(
             object, used for SageMaker interactions. If not
             specified, one is created using the default AWS configuration
             chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         str: the ECR URI for the corresponding SageMaker Docker image.

@@ -116,6 +118,7 @@ def _retrieve_image_uri(
         tolerate_vulnerable_model=tolerate_vulnerable_model,
         tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
+        config_name=config_name,
     )

     if image_scope == JumpStartScriptScope.INFERENCE:
@@ -200,4 +203,5 @@ def _retrieve_image_uri(
         distribution=distribution,
         base_framework_version=base_framework_version_override or base_framework_version,
         training_compiler_config=training_compiler_config,
+        config_name=config_name,
     )

src/sagemaker/jumpstart/artifacts/incremental_training.py

+3
@@ -33,6 +33,7 @@ def _model_supports_incremental_training(
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+    config_name: Optional[str] = None,
 ) -> bool:
     """Returns True if the model supports incremental training.

@@ -54,6 +55,7 @@ def _model_supports_incremental_training(
             object, used for SageMaker interactions. If not
             specified, one is created using the default AWS configuration
             chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         bool: the support status for incremental training.
     """
@@ -70,6 +72,7 @@ def _model_supports_incremental_training(
         tolerate_vulnerable_model=tolerate_vulnerable_model,
         tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
+        config_name=config_name,
     )

     return model_specs.supports_incremental_training()
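
The function above fetches specs for the requested config and then asks them a yes/no question. How the SDK resolves a config name to specs is not shown in this diff, so the sketch below is only one plausible shape, under the assumption that a model carries base specs plus a mapping of named overrides. `ExampleSpecs`, `ExampleModel`, and `resolve_specs` are hypothetical names introduced for illustration.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class ExampleSpecs:
    """Hypothetical specs object exposing the same query as the diff."""

    incremental_training_supported: bool = False

    def supports_incremental_training(self) -> bool:
        return self.incremental_training_supported


@dataclass
class ExampleModel:
    """Hypothetical model: base specs plus named config overrides."""

    base: ExampleSpecs
    configs: Dict[str, ExampleSpecs] = field(default_factory=dict)


def resolve_specs(model: ExampleModel, config_name: Optional[str] = None) -> ExampleSpecs:
    """Return base specs when no config is named, else the named config's specs."""
    if config_name is None:
        return model.base
    try:
        return model.configs[config_name]
    except KeyError:
        raise ValueError(f"Unknown config name: {config_name!r}") from None


def model_supports_incremental_training(
    model: ExampleModel, config_name: Optional[str] = None
) -> bool:
    return resolve_specs(model, config_name).supports_incremental_training()
```

In this sketch an unknown name raises instead of silently falling back to the base specs; whether the SDK does the same is not stated in the commit.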

src/sagemaker/jumpstart/artifacts/instance_types.py

+6
@@ -40,6 +40,7 @@ def _retrieve_default_instance_type(
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     training_instance_type: Optional[str] = None,
     model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+    config_name: Optional[str] = None,
 ) -> str:
     """Retrieves the default instance type for the model.

@@ -68,6 +69,7 @@ def _retrieve_default_instance_type(
             Optionally supply this to get a inference instance type conditioned
             on the training instance, to ensure compatability of training artifact to inference
             instance. (Default: None).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         str: the default instance type to use for the model or None.

@@ -89,6 +91,7 @@ def _retrieve_default_instance_type(
         tolerate_deprecated_model=tolerate_deprecated_model,
         model_type=model_type,
         sagemaker_session=sagemaker_session,
+        config_name=config_name,
     )

     if scope == JumpStartScriptScope.INFERENCE:
@@ -128,6 +131,7 @@ def _retrieve_instance_types(
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     training_instance_type: Optional[str] = None,
+    config_name: Optional[str] = None,
 ) -> List[str]:
     """Retrieves the supported instance types for the model.

@@ -156,6 +160,7 @@ def _retrieve_instance_types(
             Optionally supply this to get a inference instance type conditioned
             on the training instance, to ensure compatability of training artifact to inference
             instance. (Default: None).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
     Returns:
         list: the supported instance types to use for the model or None.

@@ -176,6 +181,7 @@ def _retrieve_instance_types(
         tolerate_vulnerable_model=tolerate_vulnerable_model,
         tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
+        config_name=config_name,
     )

     if scope == JumpStartScriptScope.INFERENCE:
