Skip to content

Commit 7369e45

Browse files
committed
fix: use entry_point, source_dir, dependencies arguments during create_model for sklearn and xgboost
1 parent 18af12b commit 7369e45

File tree

4 files changed

+44
-23
lines changed

4 files changed

+44
-23
lines changed

src/sagemaker/sklearn/estimator.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,14 @@ def __init__(
140140
)
141141

142142
def create_model(
143-
self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs
143+
self,
144+
model_server_workers=None,
145+
role=None,
146+
vpc_config_override=VPC_CONFIG_DEFAULT,
147+
entry_point=None,
148+
source_dir=None,
149+
dependencies=None,
150+
**kwargs
144151
):
145152
"""Create a SageMaker ``SKLearnModel`` object that can be deployed to an
146153
``Endpoint``.
@@ -164,17 +171,8 @@ def create_model(
164171
"""
165172
role = role or self.role
166173

167-
# remove unwanted entry_point kwarg
168-
if "entry_point" in kwargs:
169-
logger.debug("removing unused entry_point argument: %s", str(kwargs["entry_point"]))
170-
del kwargs["entry_point"]
171-
172-
# remove image kwarg
173-
if "image" in kwargs:
174-
image = kwargs["image"]
175-
del kwargs["image"]
176-
else:
177-
image = None
174+
if "image" not in kwargs:
175+
kwargs["image"] = self.image_name
178176

179177
if "enable_network_isolation" not in kwargs:
180178
kwargs["enable_network_isolation"] = self.enable_network_isolation()
@@ -185,17 +183,17 @@ def create_model(
185183
return SKLearnModel(
186184
self.model_data,
187185
role,
188-
self.entry_point,
189-
source_dir=self._model_source_dir(),
186+
entry_point or self.entry_point,
187+
source_dir=(source_dir or self._model_source_dir()),
190188
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
191189
container_log_level=self.container_log_level,
192190
code_location=self.code_location,
193191
py_version=self.py_version,
194192
framework_version=self.framework_version,
195193
model_server_workers=model_server_workers,
196-
image=image or self.image_name,
197194
sagemaker_session=self.sagemaker_session,
198195
vpc_config=self.get_vpc_config(vpc_config_override),
196+
dependencies=(dependencies or self.dependencies),
199197
**kwargs
200198
)
201199

src/sagemaker/xgboost/estimator.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,14 @@ def __init__(
125125
)
126126

127127
def create_model(
128-
self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs
128+
self,
129+
model_server_workers=None,
130+
role=None,
131+
vpc_config_override=VPC_CONFIG_DEFAULT,
132+
entry_point=None,
133+
source_dir=None,
134+
dependencies=None,
135+
**kwargs
129136
):
130137
"""Create a SageMaker ``XGBoostModel`` object that can be deployed to an ``Endpoint``.
131138
@@ -147,28 +154,26 @@ def create_model(
147154
"""
148155
role = role or self.role
149156

150-
# Remove unwanted entry_point kwarg
151-
if "entry_point" in kwargs:
152-
logger.debug("Removing unused entry_point argument: %s", str(kwargs["entry_point"]))
153-
del kwargs["entry_point"]
157+
if "image" not in kwargs:
158+
kwargs["image"] = self.image_name
154159

155160
if "name" not in kwargs:
156161
kwargs["name"] = self._current_job_name
157162

158163
return XGBoostModel(
159164
self.model_data,
160165
role,
161-
self.entry_point,
166+
entry_point or self.entry_point,
162167
framework_version=self.framework_version,
163-
source_dir=self._model_source_dir(),
168+
source_dir=(source_dir or self._model_source_dir()),
164169
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
165170
container_log_level=self.container_log_level,
166171
code_location=self.code_location,
167172
py_version=self.py_version,
168173
model_server_workers=model_server_workers,
169-
image=self.image_name,
170174
sagemaker_session=self.sagemaker_session,
171175
vpc_config=self.get_vpc_config(vpc_config_override),
176+
dependencies=(dependencies or self.dependencies),
172177
**kwargs
173178
)
174179

tests/unit/test_sklearn.py

+9
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
2727
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
28+
SERVING_SCRIPT_FILE = "another_dummy_script.py"
2829
TIMESTAMP = "2017-11-06-14:14:15.672"
2930
TIME = 1507167947
3031
BUCKET_NAME = "mybucket"
@@ -249,20 +250,28 @@ def test_create_model_with_optional_params(sagemaker_session):
249250

250251
sklearn.fit(inputs="s3://mybucket/train", job_name="new_name")
251252

253+
custom_image = "ubuntu:latest"
252254
new_role = "role"
253255
model_server_workers = 2
254256
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
257+
new_source_dir = "s3://myotherbucket/source"
255258
model_name = "model-name"
256259
model = sklearn.create_model(
260+
image=custom_image,
257261
role=new_role,
258262
model_server_workers=model_server_workers,
259263
vpc_config_override=vpc_config,
264+
entry_point=SERVING_SCRIPT_FILE,
265+
source_dir=new_source_dir,
260266
name=model_name,
261267
)
262268

269+
assert model.image == custom_image
263270
assert model.role == new_role
264271
assert model.model_server_workers == model_server_workers
265272
assert model.vpc_config == vpc_config
273+
assert model.entry_point == SERVING_SCRIPT_FILE
274+
assert model.source_dir == new_source_dir
266275
assert model.name == model_name
267276

268277

tests/unit/test_xgboost.py

+9
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
2929
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
30+
SERVING_SCRIPT_FILE = "another_dummy_script.py"
3031
TIMESTAMP = "2017-11-06-14:14:15.672"
3132
TIME = 1507167947
3233
BUCKET_NAME = "mybucket"
@@ -238,20 +239,28 @@ def test_create_model_with_optional_params(sagemaker_session):
238239

239240
xgboost.fit(inputs="s3://mybucket/train", job_name="new_name")
240241

242+
custom_image = "ubuntu:latest"
241243
new_role = "role"
242244
model_server_workers = 2
243245
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
246+
new_source_dir = "s3://myotherbucket/source"
244247
model_name = "model-name"
245248
model = xgboost.create_model(
249+
image=custom_image,
246250
role=new_role,
247251
model_server_workers=model_server_workers,
248252
vpc_config_override=vpc_config,
253+
entry_point=SERVING_SCRIPT_FILE,
254+
source_dir=new_source_dir,
249255
name=model_name,
250256
)
251257

258+
assert model.image == custom_image
252259
assert model.role == new_role
253260
assert model.model_server_workers == model_server_workers
254261
assert model.vpc_config == vpc_config
262+
assert model.entry_point == SERVING_SCRIPT_FILE
263+
assert model.source_dir == new_source_dir
255264
assert model.name == model_name
256265

257266

0 commit comments

Comments
 (0)