Skip to content

fix: explicitly handle arguments in create_model for sklearn and xgboost #1535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,14 @@ def __init__(
)

def create_model(
self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs
self,
model_server_workers=None,
role=None,
vpc_config_override=VPC_CONFIG_DEFAULT,
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a SageMaker ``SKLearnModel`` object that can be deployed to an
``Endpoint``.
Expand All @@ -156,25 +163,27 @@ def create_model(
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Passed to initialization of ``SKLearnModel``.
entry_point (str): Path (absolute or relative) to the local Python source file which
should be executed as the entry point to model hosting. If ``source_dir`` is specified,
then ``entry_point`` must point to a file located at the root of ``source_dir``.
If not specified, the training entry point is used.
source_dir (str): Path (absolute or relative) to a directory with any other serving
source code dependencies aside from the entry point file.
If not specified, the model source directory from training is used.
dependencies (list[str]): A list of paths to directories (absolute or relative) with
any additional libraries that will be exported to the container.
If not specified, the dependencies from training are used.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.sklearn.model.SKLearnModel`
constructor.

Returns:
sagemaker.sklearn.model.SKLearnModel: A SageMaker ``SKLearnModel``
object. See :func:`~sagemaker.sklearn.model.SKLearnModel` for full details.
"""
role = role or self.role

# remove unwanted entry_point kwarg
if "entry_point" in kwargs:
logger.debug("removing unused entry_point argument: %s", str(kwargs["entry_point"]))
del kwargs["entry_point"]

# remove image kwarg
if "image" in kwargs:
image = kwargs["image"]
del kwargs["image"]
else:
image = None
if "image" not in kwargs:
kwargs["image"] = self.image_name

if "enable_network_isolation" not in kwargs:
kwargs["enable_network_isolation"] = self.enable_network_isolation()
Expand All @@ -185,17 +194,17 @@ def create_model(
return SKLearnModel(
self.model_data,
role,
self.entry_point,
source_dir=self._model_source_dir(),
entry_point or self.entry_point,
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
framework_version=self.framework_version,
model_server_workers=model_server_workers,
image=image or self.image_name,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
dependencies=(dependencies or self.dependencies),
**kwargs
)

Expand Down
34 changes: 25 additions & 9 deletions src/sagemaker/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,14 @@ def __init__(
)

def create_model(
self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs
self,
model_server_workers=None,
role=None,
vpc_config_override=VPC_CONFIG_DEFAULT,
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a SageMaker ``XGBoostModel`` object that can be deployed to an ``Endpoint``.

Expand All @@ -139,36 +146,45 @@ def create_model(
Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Passed to initialization of ``XGBoostModel``.
entry_point (str): Path (absolute or relative) to the local Python source file which
should be executed as the entry point to model hosting. If ``source_dir`` is specified,
then ``entry_point`` must point to a file located at the root of ``source_dir``.
If not specified, the training entry point is used.
source_dir (str): Path (absolute or relative) to a directory with any other serving
source code dependencies aside from the entry point file.
If not specified, the model source directory from training is used.
dependencies (list[str]): A list of paths to directories (absolute or relative) with
any additional libraries that will be exported to the container.
If not specified, the dependencies from training are used.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.xgboost.model.XGBoostModel`
constructor.

Returns:
sagemaker.xgboost.model.XGBoostModel: A SageMaker ``XGBoostModel`` object.
See :func:`~sagemaker.xgboost.model.XGBoostModel` for full details.
"""
role = role or self.role

# Remove unwanted entry_point kwarg
if "entry_point" in kwargs:
logger.debug("Removing unused entry_point argument: %s", str(kwargs["entry_point"]))
del kwargs["entry_point"]
if "image" not in kwargs:
kwargs["image"] = self.image_name

if "name" not in kwargs:
kwargs["name"] = self._current_job_name

return XGBoostModel(
self.model_data,
role,
self.entry_point,
entry_point or self.entry_point,
framework_version=self.framework_version,
source_dir=self._model_source_dir(),
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
model_server_workers=model_server_workers,
image=self.image_name,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
dependencies=(dependencies or self.dependencies),
**kwargs
)

Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
SERVING_SCRIPT_FILE = "another_dummy_script.py"
TIMESTAMP = "2017-11-06-14:14:15.672"
TIME = 1507167947
BUCKET_NAME = "mybucket"
Expand Down Expand Up @@ -249,20 +250,31 @@ def test_create_model_with_optional_params(sagemaker_session):

sklearn.fit(inputs="s3://mybucket/train", job_name="new_name")

custom_image = "ubuntu:latest"
new_role = "role"
model_server_workers = 2
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
new_source_dir = "s3://myotherbucket/source"
dependencies = ["/directory/a", "/directory/b"]
model_name = "model-name"
model = sklearn.create_model(
image=custom_image,
role=new_role,
model_server_workers=model_server_workers,
vpc_config_override=vpc_config,
entry_point=SERVING_SCRIPT_FILE,
source_dir=new_source_dir,
dependencies=dependencies,
name=model_name,
)

assert model.image == custom_image
assert model.role == new_role
assert model.model_server_workers == model_server_workers
assert model.vpc_config == vpc_config
assert model.entry_point == SERVING_SCRIPT_FILE
assert model.source_dir == new_source_dir
assert model.dependencies == dependencies
assert model.name == model_name


Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
SERVING_SCRIPT_FILE = "another_dummy_script.py"
TIMESTAMP = "2017-11-06-14:14:15.672"
TIME = 1507167947
BUCKET_NAME = "mybucket"
Expand Down Expand Up @@ -238,20 +239,31 @@ def test_create_model_with_optional_params(sagemaker_session):

xgboost.fit(inputs="s3://mybucket/train", job_name="new_name")

custom_image = "ubuntu:latest"
new_role = "role"
model_server_workers = 2
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
new_source_dir = "s3://myotherbucket/source"
dependencies = ["/directory/a", "/directory/b"]
model_name = "model-name"
model = xgboost.create_model(
image=custom_image,
role=new_role,
model_server_workers=model_server_workers,
vpc_config_override=vpc_config,
entry_point=SERVING_SCRIPT_FILE,
source_dir=new_source_dir,
dependencies=dependencies,
name=model_name,
)

assert model.image == custom_image
assert model.role == new_role
assert model.model_server_workers == model_server_workers
assert model.vpc_config == vpc_config
assert model.entry_point == SERVING_SCRIPT_FILE
assert model.source_dir == new_source_dir
assert model.dependencies == dependencies
assert model.name == model_name


Expand Down