Skip to content

Commit 0f06897

Browse files
cj-zhangJoseph Zhang
authored andcommitted
Remove main function entrypoint in ModelBuilder dependency manager. (aws#5058)
* Remove main function entrypoint in ModelBuilder dependency manager. * Remove main function entrypoint in ModelBuilder dependency manager. --------- Co-authored-by: Joseph Zhang <[email protected]>
1 parent ed058b8 commit 0f06897

File tree

2 files changed

+18
-36
lines changed

2 files changed

+18
-36
lines changed

src/sagemaker/serve/detector/dependency_manager.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,34 @@ def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool =
3434
"""Placeholder docstring"""
3535
path = work_dir.joinpath("requirements.txt")
3636
if "auto" in dependencies and dependencies["auto"]:
37+
import site
38+
39+
pkl_path = work_dir.joinpath(PKL_FILE_NAME)
40+
dest_path = path
41+
site_packages_dir = site.getsitepackages()[0]
42+
pickle_command_dir = "/sagemaker/serve/detector"
43+
3744
command = [
3845
sys.executable,
39-
Path(__file__).parent.joinpath("pickle_dependencies.py"),
40-
"--pkl_path",
41-
work_dir.joinpath(PKL_FILE_NAME),
42-
"--dest",
43-
path,
46+
"-c",
4447
]
4548

4649
if capture_all:
47-
command.append("--capture_all")
50+
command.append(
51+
f"from pickle_dependencies import get_all_requirements;"
52+
f'get_all_requirements("{dest_path}")'
53+
)
54+
else:
55+
command.append(
56+
f"from pickle_dependencies import get_requirements_for_pkl_file;"
57+
f'get_requirements_for_pkl_file("{pkl_path}", "{dest_path}")'
58+
)
4859

4960
subprocess.run(
5061
command,
5162
env={"SETUPTOOLS_USE_DISTUTILS": "stdlib"},
5263
check=True,
64+
cwd=site_packages_dir + pickle_command_dir,
5365
)
5466

5567
with open(path, "r") as f:

src/sagemaker/serve/detector/pickle_dependencies.py

-30
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import absolute_import
44
from pathlib import Path
55
from typing import List
6-
import argparse
76
import email.parser
87
import email.policy
98
import json
@@ -129,32 +128,3 @@ def get_all_requirements(dest: Path):
129128
version = package_info.get("version")
130129

131130
out.write(f"{name}=={version}\n")
132-
133-
134-
def parse_args():
135-
"""Placeholder docstring"""
136-
parser = argparse.ArgumentParser(
137-
prog="pkl_requirements", description="Generates a requirements.txt for a cloudpickle file"
138-
)
139-
parser.add_argument("--pkl_path", required=True, help="path of the pkl file")
140-
parser.add_argument("--dest", required=True, help="path of the destination requirements.txt")
141-
parser.add_argument(
142-
"--capture_all",
143-
action="store_true",
144-
help="capture all dependencies in current environment",
145-
)
146-
args = parser.parse_args()
147-
return (Path(args.pkl_path), Path(args.dest), args.capture_all)
148-
149-
150-
def main():
151-
"""Placeholder docstring"""
152-
pkl_path, dest, capture_all = parse_args()
153-
if capture_all:
154-
get_all_requirements(dest)
155-
else:
156-
get_requirements_for_pkl_file(pkl_path, dest)
157-
158-
159-
if __name__ == "__main__":
160-
main()

0 commit comments

Comments
 (0)