Skip to content

Commit fe8f673

Browse files
authored
fix: explicitly handle arguments in create_model for sklearn and xgboost (#1535)
1 parent 18af12b commit fe8f673

File tree

4 files changed

+74
-25
lines changed

4 files changed

+74
-25
lines changed

src/sagemaker/sklearn/estimator.py

+25-16
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,14 @@ def __init__(
140140
)
141141

142142
def create_model(
143-
self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs
143+
self,
144+
model_server_workers=None,
145+
role=None,
146+
vpc_config_override=VPC_CONFIG_DEFAULT,
147+
entry_point=None,
148+
source_dir=None,
149+
dependencies=None,
150+
**kwargs
144151
):
145152
"""Create a SageMaker ``SKLearnModel`` object that can be deployed to an
146153
``Endpoint``.
@@ -156,25 +163,27 @@ def create_model(
156163
the model. Default: use subnets and security groups from this Estimator.
157164
* 'Subnets' (list[str]): List of subnet ids.
158165
* 'SecurityGroupIds' (list[str]): List of security group ids.
159-
**kwargs: Passed to initialization of ``SKLearnModel``.
166+
entry_point (str): Path (absolute or relative) to the local Python source file which
167+
should be executed as the entry point to model hosting. If ``source_dir`` is specified,
168+
then ``entry_point`` must point to a file located at the root of ``source_dir``.
169+
If not specified, the training entry point is used.
170+
source_dir (str): Path (absolute or relative) to a directory with any other serving
171+
source code dependencies aside from the entry point file.
172+
If not specified, the model source directory from training is used.
173+
dependencies (list[str]): A list of paths to directories (absolute or relative) with
174+
any additional libraries that will be exported to the container.
175+
If not specified, the dependencies from training are used.
176+
**kwargs: Additional kwargs passed to the :class:`~sagemaker.sklearn.model.SKLearnModel`
177+
constructor.
160178
161179
Returns:
162180
sagemaker.sklearn.model.SKLearnModel: A SageMaker ``SKLearnModel``
163181
object. See :func:`~sagemaker.sklearn.model.SKLearnModel` for full details.
164182
"""
165183
role = role or self.role
166184

167-
# remove unwanted entry_point kwarg
168-
if "entry_point" in kwargs:
169-
logger.debug("removing unused entry_point argument: %s", str(kwargs["entry_point"]))
170-
del kwargs["entry_point"]
171-
172-
# remove image kwarg
173-
if "image" in kwargs:
174-
image = kwargs["image"]
175-
del kwargs["image"]
176-
else:
177-
image = None
185+
if "image" not in kwargs:
186+
kwargs["image"] = self.image_name
178187

179188
if "enable_network_isolation" not in kwargs:
180189
kwargs["enable_network_isolation"] = self.enable_network_isolation()
@@ -185,17 +194,17 @@ def create_model(
185194
return SKLearnModel(
186195
self.model_data,
187196
role,
188-
self.entry_point,
189-
source_dir=self._model_source_dir(),
197+
entry_point or self.entry_point,
198+
source_dir=(source_dir or self._model_source_dir()),
190199
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
191200
container_log_level=self.container_log_level,
192201
code_location=self.code_location,
193202
py_version=self.py_version,
194203
framework_version=self.framework_version,
195204
model_server_workers=model_server_workers,
196-
image=image or self.image_name,
197205
sagemaker_session=self.sagemaker_session,
198206
vpc_config=self.get_vpc_config(vpc_config_override),
207+
dependencies=(dependencies or self.dependencies),
199208
**kwargs
200209
)
201210

src/sagemaker/xgboost/estimator.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,14 @@ def __init__(
125125
)
126126

127127
def create_model(
128-
self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs
128+
self,
129+
model_server_workers=None,
130+
role=None,
131+
vpc_config_override=VPC_CONFIG_DEFAULT,
132+
entry_point=None,
133+
source_dir=None,
134+
dependencies=None,
135+
**kwargs
129136
):
130137
"""Create a SageMaker ``XGBoostModel`` object that can be deployed to an ``Endpoint``.
131138
@@ -139,36 +146,45 @@ def create_model(
139146
Default: use subnets and security groups from this Estimator.
140147
* 'Subnets' (list[str]): List of subnet ids.
141148
* 'SecurityGroupIds' (list[str]): List of security group ids.
142-
**kwargs: Passed to initialization of ``XGBoostModel``.
149+
entry_point (str): Path (absolute or relative) to the local Python source file which
150+
should be executed as the entry point to model hosting. If ``source_dir`` is specified,
151+
then ``entry_point`` must point to a file located at the root of ``source_dir``.
152+
If not specified, the training entry point is used.
153+
source_dir (str): Path (absolute or relative) to a directory with any other serving
154+
source code dependencies aside from the entry point file.
155+
If not specified, the model source directory from training is used.
156+
dependencies (list[str]): A list of paths to directories (absolute or relative) with
157+
any additional libraries that will be exported to the container.
158+
If not specified, the dependencies from training are used.
159+
**kwargs: Additional kwargs passed to the :class:`~sagemaker.xgboost.model.XGBoostModel`
160+
constructor.
143161
144162
Returns:
145163
sagemaker.xgboost.model.XGBoostModel: A SageMaker ``XGBoostModel`` object.
146164
See :func:`~sagemaker.xgboost.model.XGBoostModel` for full details.
147165
"""
148166
role = role or self.role
149167

150-
# Remove unwanted entry_point kwarg
151-
if "entry_point" in kwargs:
152-
logger.debug("Removing unused entry_point argument: %s", str(kwargs["entry_point"]))
153-
del kwargs["entry_point"]
168+
if "image" not in kwargs:
169+
kwargs["image"] = self.image_name
154170

155171
if "name" not in kwargs:
156172
kwargs["name"] = self._current_job_name
157173

158174
return XGBoostModel(
159175
self.model_data,
160176
role,
161-
self.entry_point,
177+
entry_point or self.entry_point,
162178
framework_version=self.framework_version,
163-
source_dir=self._model_source_dir(),
179+
source_dir=(source_dir or self._model_source_dir()),
164180
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
165181
container_log_level=self.container_log_level,
166182
code_location=self.code_location,
167183
py_version=self.py_version,
168184
model_server_workers=model_server_workers,
169-
image=self.image_name,
170185
sagemaker_session=self.sagemaker_session,
171186
vpc_config=self.get_vpc_config(vpc_config_override),
187+
dependencies=(dependencies or self.dependencies),
172188
**kwargs
173189
)
174190

tests/unit/test_sklearn.py

+12
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
2727
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
28+
SERVING_SCRIPT_FILE = "another_dummy_script.py"
2829
TIMESTAMP = "2017-11-06-14:14:15.672"
2930
TIME = 1507167947
3031
BUCKET_NAME = "mybucket"
@@ -249,20 +250,31 @@ def test_create_model_with_optional_params(sagemaker_session):
249250

250251
sklearn.fit(inputs="s3://mybucket/train", job_name="new_name")
251252

253+
custom_image = "ubuntu:latest"
252254
new_role = "role"
253255
model_server_workers = 2
254256
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
257+
new_source_dir = "s3://myotherbucket/source"
258+
dependencies = ["/directory/a", "/directory/b"]
255259
model_name = "model-name"
256260
model = sklearn.create_model(
261+
image=custom_image,
257262
role=new_role,
258263
model_server_workers=model_server_workers,
259264
vpc_config_override=vpc_config,
265+
entry_point=SERVING_SCRIPT_FILE,
266+
source_dir=new_source_dir,
267+
dependencies=dependencies,
260268
name=model_name,
261269
)
262270

271+
assert model.image == custom_image
263272
assert model.role == new_role
264273
assert model.model_server_workers == model_server_workers
265274
assert model.vpc_config == vpc_config
275+
assert model.entry_point == SERVING_SCRIPT_FILE
276+
assert model.source_dir == new_source_dir
277+
assert model.dependencies == dependencies
266278
assert model.name == model_name
267279

268280

tests/unit/test_xgboost.py

+12
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
2929
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
30+
SERVING_SCRIPT_FILE = "another_dummy_script.py"
3031
TIMESTAMP = "2017-11-06-14:14:15.672"
3132
TIME = 1507167947
3233
BUCKET_NAME = "mybucket"
@@ -238,20 +239,31 @@ def test_create_model_with_optional_params(sagemaker_session):
238239

239240
xgboost.fit(inputs="s3://mybucket/train", job_name="new_name")
240241

242+
custom_image = "ubuntu:latest"
241243
new_role = "role"
242244
model_server_workers = 2
243245
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
246+
new_source_dir = "s3://myotherbucket/source"
247+
dependencies = ["/directory/a", "/directory/b"]
244248
model_name = "model-name"
245249
model = xgboost.create_model(
250+
image=custom_image,
246251
role=new_role,
247252
model_server_workers=model_server_workers,
248253
vpc_config_override=vpc_config,
254+
entry_point=SERVING_SCRIPT_FILE,
255+
source_dir=new_source_dir,
256+
dependencies=dependencies,
249257
name=model_name,
250258
)
251259

260+
assert model.image == custom_image
252261
assert model.role == new_role
253262
assert model.model_server_workers == model_server_workers
254263
assert model.vpc_config == vpc_config
264+
assert model.entry_point == SERVING_SCRIPT_FILE
265+
assert model.source_dir == new_source_dir
266+
assert model.dependencies == dependencies
255267
assert model.name == model_name
256268

257269

0 commit comments

Comments
 (0)