Skip to content

Commit 4014a62

Browse files
authored
Merge pull request #4 from aws/master
update with aws:master
2 parents 20c2def + 425390e commit 4014a62

15 files changed

+682
-335
lines changed

src/sagemaker/local/local_session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def __init__(self, boto_session=None):
379379
if platform.system() == "Windows":
380380
logger.warning("Windows Support for Local Mode is Experimental")
381381

382-
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
382+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
383383
"""Initialize this Local SageMaker Session.
384384
385385
Args:
@@ -413,6 +413,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
413413

414414
self.config = yaml.load(open(sagemaker_config_file, "r"))
415415

416+
self._default_bucket = None
417+
self._desired_default_bucket_name = default_bucket
418+
416419
def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
417420
"""
418421

src/sagemaker/processing.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,9 @@ def run(
369369
"""
370370
self._current_job_name = self._generate_current_job_name(job_name=job_name)
371371

372-
user_script_name = self._get_user_script_name(code)
373-
user_code_s3_uri = self._upload_code(code)
372+
user_code_s3_uri = self._handle_user_code_url(code)
373+
user_script_name = self._get_user_code_name(code)
374+
374375
inputs_with_code = self._convert_code_and_add_to_inputs(inputs, user_code_s3_uri)
375376

376377
self._set_entrypoint(self.command, user_script_name)
@@ -389,25 +390,59 @@ def run(
389390
if wait:
390391
self.latest_job.wait(logs=logs)
391392

392-
def _get_user_script_name(self, code):
393-
"""Finds the user script name using the provided code file,
394-
directory, or script name.
393+
def _get_user_code_name(self, code):
394+
"""Gets the basename of the user's code from the URL the customer provided.
395395
396396
Args:
397-
code (str): This can be an S3 uri or a local path to either
398-
a directory or a file.
397+
code (str): A URL to the user's code.
398+
399+
Returns:
400+
str: The basename of the user's code.
401+
402+
"""
403+
code_url = urlparse(code)
404+
return os.path.basename(code_url.path)
405+
406+
def _handle_user_code_url(self, code):
407+
"""Gets the S3 URL containing the user's code.
408+
409+
Inspects the scheme the customer passed in ("s3://" for code in S3, "file://" or nothing
410+
for absolute or local file paths. Uploads the code to S3 if the code is a local file.
411+
412+
Args:
413+
code (str): A URL to the customer's code.
399414
400415
Returns:
401-
str: The script name from the S3 uri or from the file found
402-
on the user's local machine.
416+
str: The S3 URL to the customer's code.
417+
403418
"""
404-
if os.path.isdir(code) is None or not os.path.splitext(code)[1]:
419+
code_url = urlparse(code)
420+
if code_url.scheme == "s3":
421+
user_code_s3_uri = code
422+
elif code_url.scheme == "" or code_url.scheme == "file":
423+
# Validate that the file exists locally and is not a directory.
424+
if not os.path.exists(code):
425+
raise ValueError(
426+
"""code {} wasn't found. Please make sure that the file exists.
427+
""".format(
428+
code
429+
)
430+
)
431+
if not os.path.isfile(code):
432+
raise ValueError(
433+
"""code {} must be a file, not a directory. Please pass a path to a file.
434+
""".format(
435+
code
436+
)
437+
)
438+
user_code_s3_uri = self._upload_code(code)
439+
else:
405440
raise ValueError(
406-
"""'code' must be a file, not a directory. Please pass a path to a file, not a
407-
directory.
408-
"""
441+
"code {} url scheme {} is not recognized. Please pass a file path or S3 url".format(
442+
code, code_url.scheme
443+
)
409444
)
410-
return os.path.basename(code)
445+
return user_code_s3_uri
411446

412447
def _upload_code(self, code):
413448
"""Uploads a code file or directory specified as a string

src/sagemaker/session.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,13 @@ class Session(object): # pylint: disable=too-many-public-methods
7676
bucket based on a naming convention which includes the current AWS account ID.
7777
"""
7878

79-
def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None):
79+
def __init__(
80+
self,
81+
boto_session=None,
82+
sagemaker_client=None,
83+
sagemaker_runtime_client=None,
84+
default_bucket=None,
85+
):
8086
"""Initialize a SageMaker ``Session``.
8187
8288
Args:
@@ -91,15 +97,23 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
9197
``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
9298
using this ``Session`` use this client. If not provided, one will be created using
9399
this instance's ``boto_session``.
100+
default_bucket (str): The default s3 bucket to be used by this session.
101+
Ex: "sagemaker-us-west-2"
102+
94103
"""
95104
self._default_bucket = None
96105

97106
# currently is used for local_code in local mode
98107
self.config = None
99108

100-
self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client)
109+
self._initialize(
110+
boto_session=boto_session,
111+
sagemaker_client=sagemaker_client,
112+
sagemaker_runtime_client=sagemaker_runtime_client,
113+
default_bucket=default_bucket,
114+
)
101115

102-
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
116+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
103117
"""Initialize this SageMaker Session.
104118
105119
Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
@@ -126,6 +140,12 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
126140

127141
prepend_user_agent(self.sagemaker_runtime_client)
128142

143+
self._default_bucket = None
144+
self._desired_default_bucket_name = default_bucket
145+
146+
# Create default bucket on session init to verify that desired name, if specified, is valid
147+
self.default_bucket()
148+
129149
self.local_mode = False
130150

131151
@property
@@ -314,11 +334,14 @@ def default_bucket(self):
314334
if self._default_bucket:
315335
return self._default_bucket
316336

337+
default_bucket = self._desired_default_bucket_name
317338
region = self.boto_session.region_name
318-
account = self.boto_session.client(
319-
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
320-
).get_caller_identity()["Account"]
321-
default_bucket = "sagemaker-{}-{}".format(region, account)
339+
340+
if not default_bucket:
341+
account = self.boto_session.client(
342+
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
343+
).get_caller_identity()["Account"]
344+
default_bucket = "sagemaker-{}-{}".format(region, account)
322345

323346
s3 = self.boto_session.resource("s3")
324347
try:

src/sagemaker/sklearn/processing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
class SKLearnProcessor(ScriptProcessor):
2727
"""Handles Amazon SageMaker processing tasks for jobs using scikit-learn."""
2828

29+
_valid_framework_versions = ["0.20.0"]
30+
2931
def __init__(
3032
self,
3133
framework_version,
@@ -84,6 +86,13 @@ def __init__(
8486
session = sagemaker_session or Session()
8587
region = session.boto_region_name
8688

89+
if framework_version not in self._valid_framework_versions:
90+
raise ValueError(
91+
"scikit-learn version {} is not supported. Supported versions are {}".format(
92+
framework_version, self._valid_framework_versions
93+
)
94+
)
95+
8796
if not command:
8897
command = ["python3"]
8998

tests/integ/test_local_mode.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class LocalNoS3Session(LocalSession):
4343
def __init__(self):
4444
super(LocalSession, self).__init__()
4545

46-
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
46+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
4747
self.boto_session = boto3.Session(region_name=DEFAULT_REGION)
4848
if self.config is None:
4949
self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}}
@@ -53,6 +53,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
5353
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
5454
self.local_mode = True
5555

56+
self._default_bucket = None
57+
self._desired_default_bucket_name = default_bucket
58+
5659

5760
@pytest.fixture(scope="module")
5861
def mxnet_model(sagemaker_local_session, mxnet_full_version):

tests/integ/test_processing.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
import os
1616

17+
import boto3
1718
import pytest
19+
from botocore.config import Config
20+
from sagemaker import Session
1821
from sagemaker.fw_registry import default_framework_uri
1922

2023
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, Processor
@@ -23,6 +26,35 @@
2326
from tests.integ.kms_utils import get_or_create_kms_key
2427

2528
ROLE = "SageMakerRole"
29+
DEFAULT_REGION = "us-west-2"
30+
CUSTOM_BUCKET_PATH = "sagemaker-custom-bucket"
31+
32+
33+
@pytest.fixture(scope="module")
34+
def sagemaker_session_with_custom_bucket(
35+
boto_config, sagemaker_client_config, sagemaker_runtime_config
36+
):
37+
boto_session = (
38+
boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
39+
)
40+
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
41+
sagemaker_client = (
42+
boto_session.client("sagemaker", **sagemaker_client_config)
43+
if sagemaker_client_config
44+
else None
45+
)
46+
runtime_client = (
47+
boto_session.client("sagemaker-runtime", **sagemaker_runtime_config)
48+
if sagemaker_runtime_config
49+
else None
50+
)
51+
52+
return Session(
53+
boto_session=boto_session,
54+
sagemaker_client=sagemaker_client,
55+
sagemaker_runtime_client=runtime_client,
56+
default_bucket=CUSTOM_BUCKET_PATH,
57+
)
2658

2759

2860
@pytest.fixture(scope="module")
@@ -170,6 +202,90 @@ def test_sklearn_with_customizations(
170202
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
171203

172204

205+
def test_sklearn_with_custom_default_bucket(
206+
sagemaker_session_with_custom_bucket,
207+
image_uri,
208+
sklearn_full_version,
209+
cpu_instance_type,
210+
output_kms_key,
211+
):
212+
213+
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
214+
215+
sklearn_processor = SKLearnProcessor(
216+
framework_version=sklearn_full_version,
217+
role=ROLE,
218+
command=["python3"],
219+
instance_type=cpu_instance_type,
220+
instance_count=1,
221+
volume_size_in_gb=100,
222+
volume_kms_key=None,
223+
output_kms_key=output_kms_key,
224+
max_runtime_in_seconds=3600,
225+
base_job_name="test-sklearn-with-customizations",
226+
env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
227+
tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
228+
sagemaker_session=sagemaker_session_with_custom_bucket,
229+
)
230+
231+
sklearn_processor.run(
232+
code=os.path.join(DATA_DIR, "dummy_script.py"),
233+
inputs=[
234+
ProcessingInput(
235+
source=input_file_path,
236+
destination="/opt/ml/processing/input/container/path/",
237+
input_name="dummy_input",
238+
s3_data_type="S3Prefix",
239+
s3_input_mode="File",
240+
s3_data_distribution_type="FullyReplicated",
241+
s3_compression_type="None",
242+
)
243+
],
244+
outputs=[
245+
ProcessingOutput(
246+
source="/opt/ml/processing/output/container/path/",
247+
output_name="dummy_output",
248+
s3_upload_mode="EndOfJob",
249+
)
250+
],
251+
arguments=["-v"],
252+
wait=True,
253+
logs=True,
254+
)
255+
256+
job_description = sklearn_processor.latest_job.describe()
257+
258+
assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
259+
assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]
260+
261+
assert job_description["ProcessingInputs"][1]["InputName"] == "code"
262+
assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]
263+
264+
assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations")
265+
266+
assert job_description["ProcessingJobStatus"] == "Completed"
267+
268+
assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
269+
assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"
270+
271+
assert job_description["ProcessingResources"] == {
272+
"ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100}
273+
}
274+
275+
assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
276+
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
277+
"python3",
278+
"/opt/ml/processing/input/code/dummy_script.py",
279+
]
280+
assert job_description["AppSpecification"]["ImageUri"] == image_uri
281+
282+
assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}
283+
284+
assert ROLE in job_description["RoleArn"]
285+
286+
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
287+
288+
173289
def test_sklearn_with_no_inputs_or_outputs(
174290
sagemaker_session, image_uri, sklearn_full_version, cpu_instance_type
175291
):

tests/unit/test_create_deploy_entities.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@
3434
@pytest.fixture()
3535
def sagemaker_session():
3636
boto_mock = Mock(name="boto_session", region_name=REGION)
37+
client_mock = Mock()
38+
client_mock.get_caller_identity.return_value = {
39+
"UserId": "mock_user_id",
40+
"Account": "012345678910",
41+
"Arn": "arn:aws:iam::012345678910:user/mock-user",
42+
}
43+
boto_mock.client.return_value = client_mock
3744
ims = sagemaker.Session(boto_session=boto_mock)
3845
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
3946
return ims

0 commit comments

Comments
 (0)