diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index 2e6531644b..e0fc4f76b9 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -140,7 +140,14 @@ def __init__( ) def create_model( - self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs + self, + model_server_workers=None, + role=None, + vpc_config_override=VPC_CONFIG_DEFAULT, + entry_point=None, + source_dir=None, + dependencies=None, + **kwargs ): """Create a SageMaker ``SKLearnModel`` object that can be deployed to an ``Endpoint``. @@ -156,7 +163,18 @@ def create_model( the model. Default: use subnets and security groups from this Estimator. * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. - **kwargs: Passed to initialization of ``SKLearnModel``. + entry_point (str): Path (absolute or relative) to the local Python source file which + should be executed as the entry point to training. If ``source_dir`` is specified, + then ``entry_point`` must point to a file located at the root of ``source_dir``. + If not specified, the training entry point is used. + source_dir (str): Path (absolute or relative) to a directory with any other serving + source code dependencies aside from the entry point file. + If not specified, the model source directory from training is used. + dependencies (list[str]): A list of paths to directories (absolute or relative) with + any additional libraries that will be exported to the container. + If not specified, the dependencies from training are used. + **kwargs: Additional kwargs passed to the :class:`~sagemaker.sklearn.model.SKLearnModel` + constructor. Returns: sagemaker.sklearn.model.SKLearnModel: A SageMaker ``SKLearnModel`` @@ -164,17 +182,8 @@ def create_model( """ role = role or self.role - # remove unwanted entry_point kwarg - if "entry_point" in kwargs: - logger.debug("removing unused entry_point argument: %s", str(kwargs["entry_point"])) - del kwargs["entry_point"] - - # remove image kwarg - if "image" in kwargs: - image = kwargs["image"] - del kwargs["image"] - else: - image = None + if "image" not in kwargs: + kwargs["image"] = self.image_name if "enable_network_isolation" not in kwargs: kwargs["enable_network_isolation"] = self.enable_network_isolation() @@ -185,17 +194,17 @@ def create_model( return SKLearnModel( self.model_data, role, - self.entry_point, - source_dir=self._model_source_dir(), + entry_point or self.entry_point, + source_dir=(source_dir or self._model_source_dir()), enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, container_log_level=self.container_log_level, code_location=self.code_location, py_version=self.py_version, framework_version=self.framework_version, model_server_workers=model_server_workers, - image=image or self.image_name, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), + dependencies=(dependencies or self.dependencies), **kwargs ) diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index c260a8c5b8..134235d829 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -125,7 +125,14 @@ def __init__( ) def create_model( - self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs + self, + model_server_workers=None, + role=None, + vpc_config_override=VPC_CONFIG_DEFAULT, + entry_point=None, + source_dir=None, + dependencies=None, + **kwargs ): """Create a SageMaker ``XGBoostModel`` object that can be deployed to an ``Endpoint``. @@ -139,7 +146,18 @@ def create_model( Default: use subnets and security groups from this Estimator. * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. - **kwargs: Passed to initialization of ``XGBoostModel``. + entry_point (str): Path (absolute or relative) to the local Python source file which + should be executed as the entry point to training. If ``source_dir`` is specified, + then ``entry_point`` must point to a file located at the root of ``source_dir``. + If not specified, the training entry point is used. + source_dir (str): Path (absolute or relative) to a directory with any other serving + source code dependencies aside from the entry point file. + If not specified, the model source directory from training is used. + dependencies (list[str]): A list of paths to directories (absolute or relative) with + any additional libraries that will be exported to the container. + If not specified, the dependencies from training are used. + **kwargs: Additional kwargs passed to the :class:`~sagemaker.xgboost.model.XGBoostModel` + constructor. Returns: sagemaker.xgboost.model.XGBoostModel: A SageMaker ``XGBoostModel`` object. @@ -147,10 +165,8 @@ def create_model( """ role = role or self.role - # Remove unwanted entry_point kwarg - if "entry_point" in kwargs: - logger.debug("Removing unused entry_point argument: %s", str(kwargs["entry_point"])) - del kwargs["entry_point"] + if "image" not in kwargs: + kwargs["image"] = self.image_name if "name" not in kwargs: kwargs["name"] = self._current_job_name @@ -158,17 +174,17 @@ def create_model( return XGBoostModel( self.model_data, role, - self.entry_point, + entry_point or self.entry_point, framework_version=self.framework_version, - source_dir=self._model_source_dir(), + source_dir=(source_dir or self._model_source_dir()), enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, container_log_level=self.container_log_level, code_location=self.code_location, py_version=self.py_version, model_server_workers=model_server_workers, - image=self.image_name, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), + dependencies=(dependencies or self.dependencies), **kwargs ) diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 45f74f415a..3acba6b8ee 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -25,6 +25,7 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") +SERVING_SCRIPT_FILE = "another_dummy_script.py" TIMESTAMP = "2017-11-06-14:14:15.672" TIME = 1507167947 BUCKET_NAME = "mybucket" @@ -249,20 +250,31 @@ def test_create_model_with_optional_params(sagemaker_session): sklearn.fit(inputs="s3://mybucket/train", job_name="new_name") + custom_image = "ubuntu:latest" new_role = "role" model_server_workers = 2 vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + new_source_dir = "s3://myotherbucket/source" + dependencies = ["/directory/a", "/directory/b"] model_name = "model-name" model = sklearn.create_model( + image=custom_image, role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config, + entry_point=SERVING_SCRIPT_FILE, + source_dir=new_source_dir, + dependencies=dependencies, name=model_name, ) + assert model.image == custom_image assert model.role == new_role assert model.model_server_workers == model_server_workers assert model.vpc_config == vpc_config + assert model.entry_point == SERVING_SCRIPT_FILE + assert model.source_dir == new_source_dir + assert model.dependencies == dependencies assert model.name == model_name diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 2e82d2556c..7aac29dbc6 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -27,6 +27,7 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") +SERVING_SCRIPT_FILE = "another_dummy_script.py" TIMESTAMP = "2017-11-06-14:14:15.672" TIME = 1507167947 BUCKET_NAME = "mybucket" @@ -238,20 +239,31 @@ def test_create_model_with_optional_params(sagemaker_session): xgboost.fit(inputs="s3://mybucket/train", job_name="new_name") + custom_image = "ubuntu:latest" new_role = "role" model_server_workers = 2 vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + new_source_dir = "s3://myotherbucket/source" + dependencies = ["/directory/a", "/directory/b"] model_name = "model-name" model = xgboost.create_model( + image=custom_image, role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config, + entry_point=SERVING_SCRIPT_FILE, + source_dir=new_source_dir, + dependencies=dependencies, name=model_name, ) + assert model.image == custom_image assert model.role == new_role assert model.model_server_workers == model_server_workers assert model.vpc_config == vpc_config + assert model.entry_point == SERVING_SCRIPT_FILE + assert model.source_dir == new_source_dir + assert model.dependencies == dependencies assert model.name == model_name