|
11 | 11 | # ANY KIND, either express or implied. See the License for the specific
|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | """This module contains code related to SKLearn Processors which are used for Processing jobs.
|
14 |
| -
|
15 | 14 | These jobs let customers perform data pre-processing, post-processing, feature engineering,
|
16 | 15 | data validation, and model evaluation and interpretation on SageMaker.
|
17 | 16 | """
|
18 | 17 | from __future__ import absolute_import
|
19 | 18 |
|
20 |
| -from sagemaker.processing import FrameworkProcessor |
21 |
| -from sagemaker.sklearn.estimator import SKLearn |
22 |
| - |
23 |
| - |
24 |
| -class SKLearnProcessor(FrameworkProcessor): |
25 |
| - """Initialize an ``SKLearnProcessor`` instance. |
26 |
| -
|
27 |
| - The SKLearnProcessor handles Amazon SageMaker processing tasks for jobs using scikit-learn. |
28 |
| -
|
29 |
| - Unless ``image_uri`` is specified, the scikit-learn environment is an |
30 |
| - Amazon-built Docker container that executes functions defined in the supplied |
31 |
| - ``code`` Python script. |
32 |
| -
|
33 |
| - The arguments have the exact same meaning as in ``FrameworkProcessor``. |
34 |
| -
|
35 |
| - .. tip:: |
| 19 | +from sagemaker import image_uris, Session |
| 20 | +from sagemaker.processing import ScriptProcessor |
| 21 | +from sagemaker.sklearn import defaults |
36 | 22 |
|
37 |
| - You can find additional parameters for initializing this class at |
38 |
| - :class:`~sagemaker.processing.FrameworkProcessor`. |
39 |
| - """ |
40 | 23 |
|
41 |
| - estimator_cls = SKLearn |
| 24 | +class SKLearnProcessor(ScriptProcessor): |
| 25 | + """Handles Amazon SageMaker processing tasks for jobs using scikit-learn.""" |
42 | 26 |
|
43 | 27 | def __init__(
|
44 | 28 | self,
|
45 |
| - framework_version, # New arg |
| 29 | + framework_version, |
46 | 30 | role,
|
47 |
| - instance_count, |
48 | 31 | instance_type,
|
49 |
| - py_version="py3", # New kwarg |
50 |
| - image_uri=None, |
| 32 | + instance_count, |
51 | 33 | command=None,
|
52 | 34 | volume_size_in_gb=30,
|
53 | 35 | volume_kms_key=None,
|
54 | 36 | output_kms_key=None,
|
55 |
| - code_location=None, # New arg |
56 | 37 | max_runtime_in_seconds=None,
|
57 | 38 | base_job_name=None,
|
58 | 39 | sagemaker_session=None,
|
59 | 40 | env=None,
|
60 | 41 | tags=None,
|
61 | 42 | network_config=None,
|
62 | 43 | ):
|
63 |
| - """This processor executes a Python script in a scikit-learn execution environment.""" |
64 |
| - super().__init__( |
65 |
| - self.estimator_cls, |
66 |
| - framework_version, |
67 |
| - role, |
68 |
| - instance_count, |
69 |
| - instance_type, |
70 |
| - py_version, |
71 |
| - image_uri, |
72 |
| - command, |
73 |
| - volume_size_in_gb, |
74 |
| - volume_kms_key, |
75 |
| - output_kms_key, |
76 |
| - code_location, |
77 |
| - max_runtime_in_seconds, |
78 |
| - base_job_name, |
79 |
| - sagemaker_session, |
80 |
| - env, |
81 |
| - tags, |
82 |
| - network_config, |
| 44 | + """Initialize an ``SKLearnProcessor`` instance. |
| 45 | + The SKLearnProcessor handles Amazon SageMaker processing tasks for jobs using scikit-learn. |
| 46 | + Args: |
| 47 | + framework_version (str): The version of scikit-learn. |
| 48 | + role (str): An AWS IAM role name or ARN. The Amazon SageMaker training jobs |
| 49 | + and APIs that create Amazon SageMaker endpoints use this role |
| 50 | + to access training data and model artifacts. After the endpoint |
| 51 | + is created, the inference code might use the IAM role, if it |
| 52 | + needs to access an AWS resource. |
| 53 | + instance_type (str): Type of EC2 instance to use for |
| 54 | + processing, for example, 'ml.c4.xlarge'. |
| 55 | + instance_count (int): The number of instances to run |
| 56 | + the Processing job with. Defaults to 1. |
| 57 | + command ([str]): The command to run, along with any command-line flags. |
| 58 | + Example: ["python3", "-v"]. If not provided, ["python3"] or ["python2"] |
| 59 | + will be chosen based on the py_version parameter. |
| 60 | + volume_size_in_gb (int): Size in GB of the EBS volume to |
| 61 | + use for storing data during processing (default: 30). |
| 62 | + volume_kms_key (str): A KMS key for the processing |
| 63 | + volume. |
| 64 | + output_kms_key (str): The KMS key id for all ProcessingOutputs. |
| 65 | + max_runtime_in_seconds (int): Timeout in seconds. |
| 66 | + After this amount of time Amazon SageMaker terminates the job |
| 67 | + regardless of its current status. |
| 68 | + base_job_name (str): Prefix for processing name. If not specified, |
| 69 | + the processor generates a default job name, based on the |
| 70 | + training image name and current timestamp. |
| 71 | + sagemaker_session (sagemaker.session.Session): Session object which |
| 72 | + manages interactions with Amazon SageMaker APIs and any other |
| 73 | + AWS services needed. If not specified, the processor creates one |
| 74 | + using the default AWS configuration chain. |
| 75 | + env (dict): Environment variables to be passed to the processing job. |
| 76 | + tags ([dict]): List of tags to be passed to the processing job. |
| 77 | + network_config (sagemaker.network.NetworkConfig): A NetworkConfig |
| 78 | + object that configures network isolation, encryption of |
| 79 | + inter-container traffic, security group IDs, and subnets. |
| 80 | + """ |
| 81 | + if not command: |
| 82 | + command = ["python3"] |
| 83 | + |
| 84 | + session = sagemaker_session or Session() |
| 85 | + region = session.boto_region_name |
| 86 | + |
| 87 | + image_uri = image_uris.retrieve( |
| 88 | + defaults.SKLEARN_NAME, region, version=framework_version, instance_type=instance_type |
| 89 | + ) |
| 90 | + |
| 91 | + super(SKLearnProcessor, self).__init__( |
| 92 | + role=role, |
| 93 | + image_uri=image_uri, |
| 94 | + instance_count=instance_count, |
| 95 | + instance_type=instance_type, |
| 96 | + command=command, |
| 97 | + volume_size_in_gb=volume_size_in_gb, |
| 98 | + volume_kms_key=volume_kms_key, |
| 99 | + output_kms_key=output_kms_key, |
| 100 | + max_runtime_in_seconds=max_runtime_in_seconds, |
| 101 | + base_job_name=base_job_name, |
| 102 | + sagemaker_session=session, |
| 103 | + env=env, |
| 104 | + tags=tags, |
| 105 | + network_config=network_config, |
83 | 106 | )
|
0 commit comments