diff --git a/src/sagemaker/model_monitor/data_capture_config.py b/src/sagemaker/model_monitor/data_capture_config.py index f8aa88091a..22502ab50c 100644 --- a/src/sagemaker/model_monitor/data_capture_config.py +++ b/src/sagemaker/model_monitor/data_capture_config.py @@ -41,6 +41,7 @@ def __init__( capture_options=None, csv_content_types=None, json_content_types=None, + sagemaker_session=None, ): """Initialize a DataCaptureConfig object for capturing data from Amazon SageMaker Endpoints. @@ -56,14 +57,21 @@ def __init__( which data to capture between request and response. csv_content_types ([str]): Optional. Default=["text/csv"]. json_content_types([str]): Optional. Default=["application/json"]. - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (default: None). If not + specified, one is created using the default AWS configuration + chain. """ self.enable_capture = enable_capture self.sampling_percentage = sampling_percentage self.destination_s3_uri = destination_s3_uri if self.destination_s3_uri is None: + sagemaker_session = sagemaker_session or Session() self.destination_s3_uri = os.path.join( - "s3://", Session().default_bucket(), _MODEL_MONITOR_S3_PATH, _DATA_CAPTURE_S3_PATH + "s3://", + sagemaker_session.default_bucket(), + _MODEL_MONITOR_S3_PATH, + _DATA_CAPTURE_S3_PATH, ) self.kms_key_id = kms_key_id diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 9de8642252..80da8a551c 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -192,14 +192,22 @@ def enable_data_capture(self): to enable data capture. For a more customized experience, refer to update_data_capture_config, instead. """ - self.update_data_capture_config(data_capture_config=DataCaptureConfig(enable_capture=True)) + self.update_data_capture_config( + data_capture_config=DataCaptureConfig( + enable_capture=True, sagemaker_session=self.sagemaker_session + ) + ) def disable_data_capture(self): """Updates the DataCaptureConfig for the Predictor's associated Amazon SageMaker Endpoint to disable data capture. For a more customized experience, refer to update_data_capture_config, instead. """ - self.update_data_capture_config(data_capture_config=DataCaptureConfig(enable_capture=False)) + self.update_data_capture_config( + data_capture_config=DataCaptureConfig( + enable_capture=False, sagemaker_session=self.sagemaker_session + ) + ) def update_data_capture_config(self, data_capture_config): """Updates the DataCaptureConfig for the Predictor's associated Amazon SageMaker Endpoint diff --git a/tests/integ/test_data_capture_config.py b/tests/integ/test_data_capture_config.py index b1780b7a40..afa8d23ebd 100644 --- a/tests/integ/test_data_capture_config.py +++ b/tests/integ/test_data_capture_config.py @@ -126,6 +126,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status( capture_options=CUSTOM_CAPTURE_OPTIONS, csv_content_types=CUSTOM_CSV_CONTENT_TYPES, json_content_types=CUSTOM_JSON_CONTENT_TYPES, + sagemaker_session=sagemaker_session, ), ) @@ -224,6 +225,7 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status( capture_options=CUSTOM_CAPTURE_OPTIONS, csv_content_types=CUSTOM_CSV_CONTENT_TYPES, json_content_types=CUSTOM_JSON_CONTENT_TYPES, + sagemaker_session=sagemaker_session, ) ) diff --git a/tests/integ/test_model_monitor.py b/tests/integ/test_model_monitor.py index 2703e197b6..cdbc6e9b7a 100644 --- a/tests/integ/test_model_monitor.py +++ b/tests/integ/test_model_monitor.py @@ -107,7 +107,7 @@ def predictor(sagemaker_session, tf_full_version): INSTANCE_COUNT, INSTANCE_TYPE, endpoint_name=endpoint_name, - data_capture_config=DataCaptureConfig(True), + data_capture_config=DataCaptureConfig(True, sagemaker_session=sagemaker_session), ) yield predictor diff --git a/tests/unit/sagemaker/monitor/test_data_capture_config.py b/tests/unit/sagemaker/monitor/test_data_capture_config.py index 327b72858d..a27f4171c3 100644 --- a/tests/unit/sagemaker/monitor/test_data_capture_config.py +++ b/tests/unit/sagemaker/monitor/test_data_capture_config.py @@ -12,12 +12,14 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +from mock import Mock + from sagemaker.model_monitor import DataCaptureConfig DEFAULT_ENABLE_CAPTURE = True DEFAULT_SAMPLING_PERCENTAGE = 20 DEFAULT_BUCKET_NAME = "default-bucket" -DEFAULT_DESTINATION_S3_URI = "s3://" + DEFAULT_BUCKET_NAME + "/model-monitor/data-capture" +DEFAULT_DESTINATION_S3_URI = "s3://{}/model-monitor/data-capture".format(DEFAULT_BUCKET_NAME) DEFAULT_KMS_KEY_ID = None DEFAULT_CAPTURE_MODES = ["REQUEST", "RESPONSE"] DEFAULT_CSV_CONTENT_TYPES = ["text/csv"] @@ -33,7 +35,7 @@ NON_DEFAULT_JSON_CONTENT_TYPES = ["custom/json-format"] -def test_to_request_dict_returns_correct_params_when_non_defaults_provided(): +def test_init_when_non_defaults_provided(): data_capture_config = DataCaptureConfig( enable_capture=NON_DEFAULT_ENABLE_CAPTURE, sampling_percentage=NON_DEFAULT_SAMPLING_PERCENTAGE, @@ -51,9 +53,12 @@ def test_to_request_dict_returns_correct_params_when_non_defaults_provided(): assert data_capture_config.json_content_types == NON_DEFAULT_JSON_CONTENT_TYPES -def test_to_request_dict_returns_correct_default_params_when_optionals_not_provided(): +def test_init_when_optionals_not_provided(): + sagemaker_session = Mock() + sagemaker_session.default_bucket.return_value = DEFAULT_BUCKET_NAME + data_capture_config = DataCaptureConfig( - enable_capture=DEFAULT_ENABLE_CAPTURE, destination_s3_uri=DEFAULT_DESTINATION_S3_URI + enable_capture=DEFAULT_ENABLE_CAPTURE, sagemaker_session=sagemaker_session ) assert data_capture_config.enable_capture == DEFAULT_ENABLE_CAPTURE