-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathserver.py
59 lines (49 loc) · 2.06 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Module for SMD Server"""
from __future__ import absolute_import
import logging
import platform
from sagemaker.serve.utils.optimize_utils import _is_s3_uri
from sagemaker.session import Session
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url
from sagemaker import fw_utils
from sagemaker.serve.utils.uploader import upload
# Module-level logger; the artifact-upload helper in SageMakerSmdServer emits
# its debug messages (upload destination and resulting S3 path) through it.
logger = logging.getLogger(__name__)
class SageMakerSmdServer:
    """Artifact and environment preparation for the SMD model server.

    Exposes one helper that works out where the model artifacts live in S3
    (uploading a non-S3 ``model_path`` when asked to) and assembles the
    environment variables for the serving container.
    """

    def _upload_smd_artifacts(
        self,
        model_path: str,
        sagemaker_session: Session,
        secret_key: str,
        s3_model_data_url: str = None,
        image: str = None,
        should_upload_artifacts: bool = False,
    ):
        """Locate or upload the model artifacts and build the container env vars.

        Args:
            model_path: S3 URI of existing artifacts, or any other
                (presumably local) path.
            sagemaker_session: Session passed to the upload helpers; its
                ``boto_region_name`` is exported as ``SAGEMAKER_REGION``.
            secret_key: Value exported as ``SAGEMAKER_SERVE_SECRET_KEY``.
            s3_model_data_url: Optional S3 URL whose bucket and key prefix
                seed the upload destination.
            image: Image URI forwarded to ``fw_utils.model_code_key_prefix``
                when deriving the code key prefix.
            should_upload_artifacts: When truthy, a non-S3 ``model_path`` is
                uploaded via ``upload``; ignored for S3 URIs.

        Returns:
            tuple: ``(s3_upload_path, env_vars)``. ``s3_upload_path`` is
            ``model_path`` itself for S3 inputs, whatever ``upload`` returns
            after an upload, or ``None`` when ``model_path`` is not an S3 URI
            and no upload was requested. ``env_vars`` is the dict of
            environment variables for the container.
        """
        if _is_s3_uri(model_path):
            # Artifacts already live in S3: report the URI unchanged.
            artifact_uri = model_path
        elif not should_upload_artifacts:
            # Not in S3 and no upload requested: there is no S3 location.
            artifact_uri = None
        else:
            # Seed bucket/prefix from the explicit model-data URL when given.
            # Otherwise None is passed on to determine_bucket_and_prefix
            # (presumably filled from session defaults -- confirm).
            s3_bucket, base_prefix = (
                parse_s3_url(url=s3_model_data_url)
                if s3_model_data_url
                else (None, None)
            )
            s3_bucket, code_prefix = determine_bucket_and_prefix(
                bucket=s3_bucket,
                key_prefix=fw_utils.model_code_key_prefix(base_prefix, None, image),
                sagemaker_session=sagemaker_session,
            )
            logger.debug(
                "Uploading the model resources to bucket=%s, key_prefix=%s.",
                s3_bucket,
                code_prefix,
            )
            # NOTE(review): the previous docstring said the artifacts are tarred;
            # that presumably happens inside `upload` -- confirm.
            artifact_uri = upload(sagemaker_session, model_path, s3_bucket, code_prefix)
            logger.debug("Model resources uploaded to: %s", artifact_uri)

        # Environment handed to the serving container; built after any upload.
        container_env = {
            "SAGEMAKER_INFERENCE_CODE_DIRECTORY": "/opt/ml/model/code",
            "SAGEMAKER_INFERENCE_CODE": "inference.handler",
            "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
            "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
            "LOCAL_PYTHON": platform.python_version(),
        }
        return artifact_uri, container_env