|
19 | 19 | import re
|
20 | 20 | import sys
|
21 | 21 | import time
|
| 22 | +import typing |
22 | 23 | import warnings
|
23 | 24 | from typing import List, Dict, Any, Sequence
|
24 | 25 |
|
@@ -551,7 +552,6 @@ def train( # noqa: C901
|
551 | 552 | retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
|
552 | 553 | * max_retry_attsmpts (int): Number of times a job should be retried.
|
553 | 554 | The key in RetryStrategy is 'MaxRetryAttempts'.
|
554 |
| -
|
555 | 555 | Returns:
|
556 | 556 | str: ARN of the training job, if it is created.
|
557 | 557 | """
|
@@ -585,9 +585,13 @@ def train( # noqa: C901
|
585 | 585 | environment=environment,
|
586 | 586 | retry_strategy=retry_strategy,
|
587 | 587 | )
|
588 |
| - LOGGER.info("Creating training-job with name: %s", job_name) |
589 |
| - LOGGER.debug("train request: %s", json.dumps(train_request, indent=4)) |
590 |
| - self.sagemaker_client.create_training_job(**train_request) |
| 588 | + |
| 589 | + def submit(request): |
| 590 | + LOGGER.info("Creating training-job with name: %s", job_name) |
| 591 | + LOGGER.debug("train request: %s", json.dumps(request, indent=4)) |
| 592 | + self.sagemaker_client.create_training_job(**request) |
| 593 | + |
| 594 | + self._intercept_create_request(train_request, submit) |
591 | 595 |
|
592 | 596 | def _get_train_request( # noqa: C901
|
593 | 597 | self,
|
@@ -912,9 +916,13 @@ def process(
|
912 | 916 | tags=tags,
|
913 | 917 | experiment_config=experiment_config,
|
914 | 918 | )
|
915 |
| - LOGGER.info("Creating processing-job with name %s", job_name) |
916 |
| - LOGGER.debug("process request: %s", json.dumps(process_request, indent=4)) |
917 |
| - self.sagemaker_client.create_processing_job(**process_request) |
| 919 | + |
| 920 | + def submit(request): |
| 921 | + LOGGER.info("Creating processing-job with name %s", job_name) |
| 922 | + LOGGER.debug("process request: %s", json.dumps(request, indent=4)) |
| 923 | + self.sagemaker_client.create_processing_job(**request) |
| 924 | + |
| 925 | + self._intercept_create_request(process_request, submit) |
918 | 926 |
|
919 | 927 | def _get_process_request(
|
920 | 928 | self,
|
@@ -2086,9 +2094,12 @@ def create_tuning_job(
|
2086 | 2094 | tags=tags,
|
2087 | 2095 | )
|
2088 | 2096 |
|
2089 |
| - LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name) |
2090 |
| - LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4)) |
2091 |
| - self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request) |
| 2097 | + def submit(request): |
| 2098 | + LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name) |
| 2099 | + LOGGER.debug("tune request: %s", json.dumps(request, indent=4)) |
| 2100 | + self.sagemaker_client.create_hyper_parameter_tuning_job(**request) |
| 2101 | + |
| 2102 | + self._intercept_create_request(tune_request, submit) |
2092 | 2103 |
|
2093 | 2104 | def _get_tuning_request(
|
2094 | 2105 | self,
|
@@ -2553,9 +2564,12 @@ def transform(
|
2553 | 2564 | model_client_config=model_client_config,
|
2554 | 2565 | )
|
2555 | 2566 |
|
2556 |
| - LOGGER.info("Creating transform job with name: %s", job_name) |
2557 |
| - LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4)) |
2558 |
| - self.sagemaker_client.create_transform_job(**transform_request) |
| 2567 | + def submit(request): |
| 2568 | + LOGGER.info("Creating transform job with name: %s", job_name) |
| 2569 | + LOGGER.debug("Transform request: %s", json.dumps(request, indent=4)) |
| 2570 | + self.sagemaker_client.create_transform_job(**request) |
| 2571 | + |
| 2572 | + self._intercept_create_request(transform_request, submit) |
2559 | 2573 |
|
2560 | 2574 | def _create_model_request(
|
2561 | 2575 | self,
|
@@ -4161,6 +4175,18 @@ def account_id(self) -> str:
|
4161 | 4175 | )
|
4162 | 4176 | return sts_client.get_caller_identity()["Account"]
|
4163 | 4177 |
|
| 4178 | + def _intercept_create_request(self, request: typing.Dict, create): |
| 4179 | + """This function intercepts the create job request. |
| 4180 | +
|
| 4181 | + PipelineSession inherits this Session class and will override |
| 4182 | + this function to intercept the create request. |
| 4183 | +
|
| 4184 | + Args: |
| 4185 | + request (dict): the create job request |
| 4186 | + create (functor): a functor calls the sagemaker client create method |
| 4187 | + """ |
| 4188 | + create(request) |
| 4189 | + |
4164 | 4190 |
|
4165 | 4191 | def get_model_package_args(
|
4166 | 4192 | content_types,
|
|
0 commit comments