|
22 | 22 |
|
23 | 23 | from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY
|
24 | 24 | from sagemaker.model_monitor import DataCaptureConfig
|
25 |
| -from sagemaker.session import Session |
| 25 | +from sagemaker.session import production_variant, Session |
26 | 26 | from sagemaker.utils import name_from_base
|
27 | 27 |
|
28 | 28 | from sagemaker.model_monitor.model_monitoring import (
|
@@ -157,6 +157,106 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
|
157 | 157 | args["Body"] = data
|
158 | 158 | return args
|
159 | 159 |
|
def update_endpoint(
    self,
    initial_instance_count=None,
    instance_type=None,
    accelerator_type=None,
    model_name=None,
    tags=None,
    kms_key=None,
    data_capture_config_dict=None,
    wait=True,
):
    """Update the existing endpoint with the provided attributes.

    This creates a new EndpointConfig in the process. If ``initial_instance_count``,
    ``instance_type``, ``accelerator_type``, or ``model_name`` is specified, then a new
    ``ProductionVariant`` configuration is created; values from the existing configuration
    are not preserved if any of those parameters are specified.

    The predictor's own bookkeeping (``_endpoint_config_name`` and ``_model_names``) is
    only updated after the endpoint update call has returned without raising, so a failed
    update leaves this object describing the configuration it had before the call.

    Args:
        initial_instance_count (int): The initial number of instances to run in the endpoint.
            This is required if ``instance_type``, ``accelerator_type``, or ``model_name`` is
            specified. Otherwise, the values from the existing endpoint configuration's
            ``ProductionVariant``s are used.
        instance_type (str): The EC2 instance type to deploy the endpoint to.
            This is required if ``initial_instance_count``, ``accelerator_type``, or
            ``model_name`` is specified. Otherwise, the values from the existing endpoint
            configuration's ``ProductionVariant``s are used.
        accelerator_type (str): The type of Elastic Inference accelerator to attach to
            the endpoint, e.g. 'ml.eia1.medium'. If not specified, and
            ``initial_instance_count``, ``instance_type``, and ``model_name`` are also ``None``,
            the values from the existing endpoint configuration's ``ProductionVariant``s are
            used. Otherwise, no Elastic Inference accelerator is attached to the endpoint.
        model_name (str): The name of the model to be associated with the endpoint.
            This is required if ``initial_instance_count``, ``instance_type``, or
            ``accelerator_type`` is specified and if there is more than one model associated
            with the endpoint. Otherwise, the existing model for the endpoint is used.
        tags (list[dict[str, str]]): The list of tags to add to the endpoint
            config. If not specified, the tags of the existing endpoint configuration are used.
            If any of the existing tags are reserved AWS ones (i.e. begin with "aws"),
            they are not carried over to the new endpoint configuration.
        kms_key (str): The KMS key that is used to encrypt the data on the storage volume
            attached to the instance hosting the endpoint. If not specified,
            the KMS key of the existing endpoint configuration is used.
        data_capture_config_dict (dict): The endpoint data capture configuration
            for use with Amazon SageMaker Model Monitoring. If not specified,
            the data capture configuration of the existing endpoint configuration is used.
        wait (bool): Forwarded unchanged to ``sagemaker_session.update_endpoint``
            (default: True). Presumably controls whether the call blocks until the
            endpoint update has finished; confirm against the session implementation.

    Raises:
        ValueError: If there is not enough information to create a new ``ProductionVariant``:

            - If ``initial_instance_count``, ``instance_type``, ``accelerator_type``, or
              ``model_name`` is specified, but ``initial_instance_count`` or
              ``instance_type`` is ``None``.
            - If ``initial_instance_count``, ``instance_type``, or ``accelerator_type`` is
              specified, ``model_name`` is ``None``, and the endpoint has either no
              associated model or more than one.
    """
    production_variants = None
    # Staged replacement for ``self._model_names``; committed only after the
    # endpoint update call below has returned, so a failure cannot leave this
    # object claiming a model that was never deployed.
    staged_model_names = None

    # Truthiness (not ``is not None``) decides whether a new variant is requested,
    # so falsy values such as 0 or "" count as "not specified" here.
    if initial_instance_count or instance_type or accelerator_type or model_name:
        if instance_type is None or initial_instance_count is None:
            raise ValueError(
                "Missing initial_instance_count and/or instance_type. Provided values: "
                "initial_instance_count={}, instance_type={}, accelerator_type={}, "
                "model_name={}.".format(
                    initial_instance_count, instance_type, accelerator_type, model_name
                )
            )

        if model_name is None:
            existing_model_names = self._model_names
            if not existing_model_names:
                # Without this guard the ``[0]`` lookup below would surface as a
                # bare IndexError instead of the documented ValueError.
                raise ValueError(
                    "Unable to choose a default model for a new EndpointConfig because "
                    "the endpoint has no associated models and model_name was not specified."
                )
            if len(existing_model_names) > 1:
                raise ValueError(
                    "Unable to choose a default model for a new EndpointConfig because "
                    "the endpoint has multiple models: {}".format(
                        ", ".join(existing_model_names)
                    )
                )
            model_name = existing_model_names[0]
        else:
            staged_model_names = [model_name]

        production_variant_config = production_variant(
            model_name,
            instance_type,
            initial_instance_count=initial_instance_count,
            accelerator_type=accelerator_type,
        )
        production_variants = [production_variant_config]

    new_endpoint_config_name = name_from_base(self._endpoint_config_name)
    self.sagemaker_session.create_endpoint_config_from_existing(
        self._endpoint_config_name,
        new_endpoint_config_name,
        new_tags=tags,
        new_kms_key=kms_key,
        new_data_capture_config_dict=data_capture_config_dict,
        new_production_variants=production_variants,
    )
    self.sagemaker_session.update_endpoint(
        self.endpoint_name, new_endpoint_config_name, wait=wait
    )

    # Both remote calls returned: now it is safe to record the new state.
    self._endpoint_config_name = new_endpoint_config_name
    if staged_model_names is not None:
        self._model_names = staged_model_names
| 259 | + |
160 | 260 | def _delete_endpoint_config(self):
|
161 | 261 | """Delete the Amazon SageMaker endpoint configuration"""
|
162 | 262 | self.sagemaker_session.delete_endpoint_config(self._endpoint_config_name)
|
|
0 commit comments