Skip to content

Commit 89ec027

Browse files
committed
Support intelligent parameters (#1540)
* Support intelligent parameters * fix codestyle
1 parent dfa3f0b commit 89ec027

File tree

3 files changed

+264
-34
lines changed

3 files changed

+264
-34
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import argparse
import json
import os
import re
from typing import Any, Optional
5+
6+
# JSON file holding the hyperparameter names and values loaded by set_intelligent_params.
HYPERPARAMETERS_FILE_PATH = "/opt/ml/input/config/hyperparameters.json"
7+
8+
9+
def set_intelligent_params(path: str) -> None:
    """
    Set intelligent parameters for all python files under the given path.

    The hyperparameters are loaded once from HYPERPARAMETERS_FILE_PATH, then every
    ``.py`` file found while walking ``path`` is rewritten in place: assignment lines
    carrying a ``sm_hyper_param``, ``sm_hp_{name}`` or ``sm_{env_name}`` comment get their
    value replaced.

    Args:
        path (str): The folder path to set intelligent parameters

    Raises:
        OSError: If the hyperparameters file cannot be opened.
        json.JSONDecodeError: If the hyperparameters file is not valid JSON.
    """
    # Be explicit about the encoding so decoding does not depend on the container locale
    # (rewrite_file already reads and writes sources as UTF-8).
    with open(HYPERPARAMETERS_FILE_PATH, "r", encoding="utf-8") as f:
        hyperparameters = json.load(f)

    for root, _dirs, files in os.walk(path):
        for file in files:
            if file.endswith(".py"):
                rewrite_file(os.path.join(root, file), hyperparameters)
25+
26+
27+
def rewrite_file(file_path: str, hyperparameters: dict) -> None:
    """
    Rewrite a single python file with intelligent parameters.

    Every line is passed through rewrite_line; the file is written back only when at
    least one line actually changed.

    Args:
        file_path (str): The file path to rewrite
        hyperparameters (dict): The hyperparameter names and values
    """
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    new_lines = [rewrite_line(line, hyperparameters) for line in lines]

    # Files without any intelligent parameter are left untouched, so their
    # modification time and original line endings are preserved.
    if new_lines == lines:
        return

    with open(file_path, "w", encoding="utf-8") as f:
        f.writelines(new_lines)
41+
42+
43+
def rewrite_line(line: str, hyperparameters: dict) -> str:
    """
    Rewrite a single line of python code with intelligent parameters.

    A line is rewritten only when it has the shape ``name = value  # comment`` and
    get_parameter_value resolves the comment to a value. The rewritten line keeps the
    original leading whitespace, always ends with a newline, and its comment is replaced
    by ``# set by intelligent parameters``.

    Args:
        line (str): The python code to rewrite
        hyperparameters (dict): The hyperparameter names and values

    Returns:
        str: The rewritten line, or ``line`` unchanged when it is not an annotated
            assignment or no value is found for it.
    """
    # Blank out string literals so that "=" and "#" inside them are not mistaken for
    # the assignment operator or the start of a comment.
    line_without_strings = re.sub(r'".*?"', '""', line.strip())
    line_without_strings = re.sub(r"'.*?'", '""', line_without_strings)

    # Match lines with format "a = 1 # comment"
    assignment_pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*\s*=.*#.*"
    if not re.match(assignment_pattern, line_without_strings):
        return line

    variable = line_without_strings.split("=")[0].strip()
    comment = line_without_strings.split("#")[-1].strip()
    value = get_parameter_value(variable, comment, hyperparameters)
    if value is None:
        return line

    # Reuse the exact leading whitespace: rebuilding it from spaces would turn
    # tab-indented code into space-indented code and break the file.
    indent = line[: len(line) - len(line.lstrip())]

    if isinstance(value, str):
        # json.dumps emits a double-quoted literal with quotes, backslashes and control
        # characters escaped, so arbitrary values stay a single valid python string.
        literal = json.dumps(value, ensure_ascii=False)
    else:
        literal = str(value)
    return f"{indent}{variable} = {literal} # set by intelligent parameters\n"


def get_parameter_value(variable: str, comment: str, hyperparameters: dict) -> Optional[Any]:
    """
    Get the parameter value by the variable name and comment.

    Args:
        variable (str): The variable name
        comment (str): The comment string in the python code
        hyperparameters (dict): The hyperparameter names and values

    Returns:
        Optional[Any]: The hyperparameter value (``sm_hyper_param`` looks up ``variable``,
            ``sm_hp_{name}`` looks up ``name``), the environment variable named by the
            upper-cased comment for any other ``sm_`` comment, or None when the comment
            is not recognised or nothing is found.
    """
    if comment == "sm_hyper_param":
        # Get the hyperparameter value by the variable name
        return hyperparameters.get(variable)
    if comment.startswith("sm_hp_"):
        # Get the hyperparameter value by the suffix of comment
        return hyperparameters.get(comment[len("sm_hp_"):])
    # Get the hyperparameter value from environment variables
    if comment.startswith("sm_"):
        return os.environ.get(comment.upper())
    return None
91+
92+
93+
if __name__ == "__main__":
    # Command-line entry point: rewrite every python file under --path in place.
    parser = argparse.ArgumentParser(description="Intelligent parameters")
    parser.add_argument(
        "-p", "--path", help="The folder path to set intelligent parameters", required=True
    )

    args = parser.parse_args()

    set_intelligent_params(args.path)

src/sagemaker/modules/train/local_snapshot.py

+54-34
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
1-
import boto3
2-
import docker
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Utility function to capture local environment"""
314
import logging
415
import subprocess
516
import sys
617
from typing import Optional
18+
19+
import boto3
20+
import docker
721
import yaml
822

923
logger = logging.getLogger(__name__)
@@ -66,38 +80,41 @@ def capture_local_environment(
6680
"""
6781
Capture all dependency packages installed in the local environment and build a docker image.
6882
When using this utility method, the docker daemon must be active in the environment.
69-
Please note that this is an experimental feature. This utility function is not be able to detect the package
70-
compatability between platforms. It is also not able to detect dependency conflicts between the local environment
71-
and the additional dependencies.
83+
Please note that this is an experimental feature. This utility function is not able to
84+
detect the package compatibility between platforms. It is also not able to detect dependency
85+
conflicts between the local environment and the additional dependencies.
7286
7387
Args:
7488
image_name (str): The name of the docker image.
75-
env_name (str): The name of the virtual environment to be activated in the image, defaults to "saved_local_env".
89+
env_name (str): The name of the virtual environment to be activated in the image,
90+
defaults to "saved_local_env".
7691
package_manager (str): The package manager, must be one of "conda" or "pip".
77-
deploy_to_ecr (bool): Whether to deploy the docker image to AWS ECR, defaults to False. If set to True, the AWS
78-
credentials must be configured in the environment.
79-
base_image_name (Optional[str]): If provided will be used as the base image, else the utility will evaluate
80-
from local environment in following manner:
92+
deploy_to_ecr (bool): Whether to deploy the docker image to AWS ECR, defaults to False.
93+
If set to True, the AWS credentials must be configured in the environment.
94+
base_image_name (Optional[str]): If provided will be used as the base image, else the
95+
utility will evaluate from local environment in following manner:
8196
1. If package manager is conda, it will use ubuntu:latest.
82-
2. If package manager is pip, it is resolved to base python image with the same python version
83-
as the environment running the local code.
84-
job_conda_env (Optional[str]): If set, the dependencies will be captured from this specific conda Env,
85-
otherwise the dependencies will be the installed packages in the current active environment. This parameter
86-
is only valid when the package manager is conda.
87-
additional_dependencies (Optional[str]): Either the path to a dependencies file (conda environment.yml OR pip
88-
requirements.txt file). Regardless of this setting utility will automatically generate the dependencies
89-
file corresponding to the current active environment’s snapshot. In addition to this, additional dependencies
90-
is configurable.
91-
ecr_repo_name (Optional[str]): The AWS ECR repo to push the docker image. If not specified, it will use image_name as
92-
the ECR repo name. This parameter is only valid when deploy_to_ecr is True.
93-
boto_session (Optional[boto3.Session]): The boto3 session with AWS account info. If not provided, a new boto session
94-
will be created.
97+
2. If package manager is pip, it is resolved to base python image with the same
98+
python version as the environment running the local code.
99+
job_conda_env (Optional[str]): If set, the dependencies will be captured from this specific
100+
conda Env, otherwise the dependencies will be the installed packages in the current
101+
active environment. This parameter is only valid when the package manager is conda.
102+
additional_dependencies (Optional[str]): Either the path to a dependencies file (conda
103+
environment.yml OR pip requirements.txt file). Regardless of this setting utility will
104+
automatically generate the dependencies file corresponding to the current active
105+
environment’s snapshot. In addition to this, additional dependencies is configurable.
106+
ecr_repo_name (Optional[str]): The AWS ECR repo to push the docker image. If not specified,
107+
it will use image_name as the ECR repo name. This parameter is only valid when
108+
deploy_to_ecr is True.
109+
boto_session (Optional[boto3.Session]): The boto3 session with AWS account info. If not
110+
provided, a new boto session will be created.
95111
96112
Exceptions:
97113
docker.errors.DockerException: Error while fetching server API version:
98114
The docker engine is not running in your environment.
99-
docker.errors.BuildError: The docker failed to build the image. The most likely reason is: 1) Some packages are not
100-
supported in the base image. 2) There are dependency conflicts between your local environment and additional dependencies.
115+
docker.errors.BuildError: The docker failed to build the image. The most likely reason is:
116+
1) Some packages are not supported in the base image. 2) There are dependency conflicts
117+
between your local environment and additional dependencies.
101118
botocore.exceptions.ClientError: AWS credentials are not configured.
102119
"""
103120

@@ -118,7 +135,8 @@ def capture_local_environment(
118135
".yml"
119136
) and not additional_dependencies.endswith(".txt"):
120137
raise ValueError(
121-
"When package manager is conda, additional dependencies file must be a yml file or a txt file."
138+
"When package manager is conda, additional dependencies "
139+
"file must be a yml file or a txt file."
122140
)
123141
if additional_dependencies.endswith(".yml"):
124142
_merge_environment_ymls(
@@ -153,7 +171,7 @@ def capture_local_environment(
153171
additional_requirements = f.read()
154172
with open(REQUIREMENT_TXT_PATH, "a") as f:
155173
f.write(additional_requirements)
156-
logger.info(f"Merged requirements file saved to {REQUIREMENT_TXT_PATH}")
174+
logger.info("Merged requirements file saved to %s", REQUIREMENT_TXT_PATH)
157175

158176
if not base_image_name:
159177
version = sys.version_info
@@ -165,23 +183,24 @@ def capture_local_environment(
165183

166184
else:
167185
raise ValueError(
168-
"The provided package manager is not supported. Use conda or pip as the package manager."
186+
"The provided package manager is not supported. "
187+
"Use conda or pip as the package manager."
169188
)
170189

171190
# Create the Dockerfile
172191
with open(DOCKERFILE_PATH, "w") as f:
173192
f.write(dockerfile_contents)
174193

175194
client = docker.from_env()
176-
image, logs = client.images.build(
195+
_, logs = client.images.build(
177196
path="/tmp",
178197
dockerfile=DOCKERFILE_PATH,
179198
rm=True,
180199
tag=image_name,
181200
)
182201
for log in logs:
183202
logger.info(log.get("stream", "").strip())
184-
logger.info(f"Docker image {image_name} built successfully")
203+
logger.info("Docker image %s built successfully", image_name)
185204

186205
if deploy_to_ecr:
187206
if boto_session is None:
@@ -232,14 +251,15 @@ def _merge_environment_ymls(env_name: str, env_file1: str, env_file2: str, outpu
232251
with open(output_file, "w") as f:
233252
yaml.dump(merged_env, f, sort_keys=False)
234253

235-
logger.info(f"Merged environment file saved to '{output_file}'")
254+
logger.info("Merged environment file saved to '%s'", output_file)
236255

237256

238257
def _merge_environment_yml_with_requirement_txt(
239258
env_name: str, env_file: str, req_txt: str, output_file: str
240259
):
241260
"""
242-
Merge an environment.yml file with a requirements.txt file and save to a new environment.yml file.
261+
Merge an environment.yml file with a requirements.txt file and save to a new
262+
environment.yml file.
243263
244264
Args:
245265
env_name (str): The name of the virtual environment to be activated in the image.
@@ -278,7 +298,7 @@ def _merge_environment_yml_with_requirement_txt(
278298
with open(output_file, "w") as f:
279299
yaml.dump(merged_env, f, sort_keys=False)
280300

281-
logger.info(f"Merged environment file saved to '{output_file}'")
301+
logger.info("Merged environment file saved to '%s'", output_file)
282302

283303

284304
def _push_image_to_ecr(image_name: str, ecr_repo_name: str, boto_session: Optional[boto3.Session]):
@@ -317,4 +337,4 @@ def _push_image_to_ecr(image_name: str, ecr_repo_name: str, boto_session: Option
317337
docker_push_cmd = f"docker push {ecr_image_uri}"
318338
subprocess.run(docker_push_cmd, shell=True, check=True)
319339

320-
logger.info(f"Image {image_name} pushed to {ecr_image_uri}")
340+
logger.info("Image %s pushed to %s", image_name, ecr_image_uri)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import os
2+
import pytest
3+
from sagemaker.modules.scripts.intelligent_params import (
4+
rewrite_file,
5+
rewrite_line,
6+
)
7+
8+
9+
@pytest.fixture()
def hyperparameters():
    """Sample hyperparameters with integer, string and list values."""
    sample = {
        "n_estimators": 3,
        "epochs": 4,
        "state": "In Progress",
        "list_value": [1, 2, 3],
    }
    return sample
12+
13+
14+
@pytest.fixture()
15+
def python_code():
16+
return """
17+
import sagemaker
18+
19+
def main(args):
20+
n_estimators = 10 # sm_hp_n_estimators
21+
state = "Not Started" # sm_hyper_param
22+
random_seed = 0.1 # sm_hyper_param
23+
24+
output_dir = "local/dir/" # sm_model_dir
25+
epochs = 5 # sm_hp_epochs
26+
input_data = [0, 0, 0] # sm_hp_list_value
27+
28+
# Load the Iris dataset
29+
iris = load_iris()
30+
y = iris.target
31+
32+
# Make predictions on the test set
33+
y_pred = clf.predict(input_data)
34+
35+
accuracy = accuracy_score(y, y_pred) # calculate the accuracy
36+
print(f"# Model accuracy: {accuracy:.2f}")
37+
"""
38+
39+
40+
@pytest.fixture()
41+
def expected_output_code():
42+
return """
43+
import sagemaker
44+
45+
def main(args):
46+
n_estimators = 3 # set by intelligent parameters
47+
state = "In Progress" # set by intelligent parameters
48+
random_seed = 0.1 # sm_hyper_param
49+
50+
output_dir = "/opt/ml/input" # set by intelligent parameters
51+
epochs = 4 # set by intelligent parameters
52+
input_data = [1, 2, 3] # set by intelligent parameters
53+
54+
# Load the Iris dataset
55+
iris = load_iris()
56+
y = iris.target
57+
58+
# Make predictions on the test set
59+
y_pred = clf.predict(input_data)
60+
61+
accuracy = accuracy_score(y, y_pred) # calculate the accuracy
62+
print(f"# Model accuracy: {accuracy:.2f}")
63+
"""
64+
65+
66+
def test_rewrite_line(hyperparameters, monkeypatch):
    """Check rewrite_line on annotated, unannotated and unresolved assignment lines."""
    # monkeypatch restores the environment after the test, so the variable does not
    # leak into other tests the way a direct os.environ assignment would.
    monkeypatch.setenv("SM_MODEL_DIR", "/opt/ml/input")

    cases = [
        # Value looked up by variable name.
        (
            "n_estimators = 4 # sm_hyper_param",
            "n_estimators = 3 # set by intelligent parameters\n",
        ),
        # Value looked up by the comment suffix; indentation is kept.
        (
            " epochs = 5 # sm_hp_epochs",
            " epochs = 4 # set by intelligent parameters\n",
        ),
        # Value looked up from the environment variable.
        (
            'output_dir = "local/dir/" # sm_model_dir ',
            'output_dir = "/opt/ml/input" # set by intelligent parameters\n',
        ),
        # Ordinary comments leave the line untouched.
        (
            " random_state = 1 # not an intelligent parameter comment \n",
            " random_state = 1 # not an intelligent parameter comment \n",
        ),
        # Unknown hyperparameter names leave the line untouched.
        (
            "not_an_intelligent_parameter = 4 # sm_hyper_param\n",
            "not_an_intelligent_parameter = 4 # sm_hyper_param\n",
        ),
        (
            "not_found_in_hyper_params = 4 # sm_hp_not_found_in_hyper_params\n",
            "not_found_in_hyper_params = 4 # sm_hp_not_found_in_hyper_params\n",
        ),
        # Non-string values are written without quotes.
        (
            "list_value = [4, 5, 6] # sm_hyper_param",
            "list_value = [1, 2, 3] # set by intelligent parameters\n",
        ),
    ]
    for line, expected in cases:
        assert rewrite_line(line, hyperparameters) == expected, line
95+
96+
97+
def test_rewrite_file(hyperparameters, python_code, expected_output_code, tmp_path, monkeypatch):
    """Check that rewrite_file rewrites a whole python file in place."""
    # Use pytest's per-test temporary directory so nothing is left in the working
    # directory when the assertion fails, and let monkeypatch undo the env change.
    test_file_path = os.path.join(str(tmp_path), "temp_test.py")
    monkeypatch.setenv("SM_MODEL_DIR", "/opt/ml/input")

    with open(test_file_path, "w") as f:
        f.write(python_code)
    rewrite_file(test_file_path, hyperparameters)

    with open(test_file_path, "r") as f:
        new_python_code = f.read()
    assert new_python_code == expected_output_code

0 commit comments

Comments
 (0)