Skip to content

Commit d570911

Browse files
authored
Processing Jobs Python SDK support (aws#225)
This change adds support for Amazon SageMaker Processing jobs. New classes include Processor, ScriptModeProcessor, SKLearnProcessor, SparkMLJavaProcessor, SparkMLPythonProcessor, ProcessingJob, FileInput, FileOutput, S3Uploader, and S3Downloader.
1 parent b0201c5 commit d570911

23 files changed

+33224
-24
lines changed

src/sagemaker/fw_utils.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
import re
2020
import shutil
2121
import tempfile
22-
from six.moves.urllib.parse import urlparse
2322

2423
import sagemaker.utils
2524
from sagemaker.utils import get_ecr_image_uri_prefix, ECR_URI_PATTERN
25+
from sagemaker import s3
2626

2727
_TAR_SOURCE_FILENAME = "source.tar.gz"
2828

@@ -447,18 +447,17 @@ def framework_version_from_tag(image_tag):
447447

448448

449449
def parse_s3_url(url):
450-
"""Returns an (s3 bucket, key name/prefix) tuple from a url with an s3
451-
scheme
450+
"""Calls the method with the same name in the s3 module.
451+
452+
:func:`~sagemaker.s3.parse_s3_url`
453+
452454
Args:
453-
url (str):
454-
Returns:
455-
tuple: A tuple containing:
456-
str: S3 bucket name str: S3 key
455+
url: A URL, expected with an s3 scheme.
456+
457+
Returns: The return value of s3.parse_s3_url, which is a tuple containing:
458+
str: S3 bucket name str: S3 key
457459
"""
458-
parsed_url = urlparse(url)
459-
if parsed_url.scheme != "s3":
460-
raise ValueError("Expecting 's3' scheme, got: {} in {}".format(parsed_url.scheme, url))
461-
return parsed_url.netloc, parsed_url.path.lstrip("/")
460+
return s3.parse_s3_url(url)
462461

463462

464463
def model_code_key_prefix(code_location_key_prefix, model_name, image):

src/sagemaker/job.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ def start_new(self, estimator, inputs):
5959
def wait(self):
6060
"""Wait for the Amazon SageMaker job to finish."""
6161

62+
@abstractmethod
63+
def describe(self):
64+
"""Describe the job."""
65+
66+
@abstractmethod
67+
def stop(self):
68+
"""Stop the job."""
69+
6270
@staticmethod
6371
def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
6472
"""

src/sagemaker/network.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This file contains code related to network configuration, including
14+
encryption, network isolation, and VPC configurations.
15+
"""
16+
from __future__ import absolute_import
17+
18+
19+
class NetworkConfig(object):
    """Accepts network configuration parameters and provides a method to turn these parameters
    into a dictionary."""

    def __init__(
        self,
        enable_network_isolation=False,
        encrypt_inter_container_traffic=False,
        security_group_ids=None,
        subnets=None,
    ):
        """Initialize a ``NetworkConfig`` instance. NetworkConfig accepts network configuration
        parameters and provides a method to turn these parameters into a dictionary.

        Args:
            enable_network_isolation (bool): Boolean that determines whether to enable
                network isolation (default: False).
            encrypt_inter_container_traffic (bool): Boolean that determines whether to
                encrypt inter-container traffic (default: False).
            security_group_ids ([str]): A list of strings representing security group IDs
                (default: None).
            subnets ([str]): A list of strings representing subnets (default: None).
        """
        self.enable_network_isolation = enable_network_isolation
        self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
        self.security_group_ids = security_group_ids
        self.subnets = subnets

    def to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class.

        Returns:
            dict: A dictionary that always contains the keys
                ``EnableInterContainerTrafficEncryption`` and ``EnableNetworkIsolation``.
                The ``VpcConfig`` key is included only when ``security_group_ids`` and/or
                ``subnets`` is not None, and it contains only the fields that are not None.
        """
        network_config_request = {
            "EnableInterContainerTrafficEncryption": self.encrypt_inter_container_traffic,
            "EnableNetworkIsolation": self.enable_network_isolation,
        }

        # Only forward the VPC fields that were actually provided. Emitting
        # ``None`` for these list-typed fields would produce a request entry
        # that is not a valid list, so unset values are left out instead.
        vpc_config = {}
        if self.security_group_ids is not None:
            vpc_config["SecurityGroupIds"] = self.security_group_ids
        if self.subnets is not None:
            vpc_config["Subnets"] = self.subnets

        # Omit the VpcConfig key entirely when no VPC settings were given.
        if vpc_config:
            network_config_request["VpcConfig"] = vpc_config

        return network_config_request

0 commit comments

Comments
 (0)