Skip to content

Commit 7ba36b7

Browse files
authored
Merge branch 'master' into tgi-optimum-0.0.18
2 parents f9d53a0 + 84989bb commit 7ba36b7

File tree

18 files changed

+460
-16
lines changed

18 files changed

+460
-16
lines changed

MANIFEST.in

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
recursive-include src/sagemaker *.py
22

33
include src/sagemaker/image_uri_config/*.json
4+
include src/sagemaker/serve/schema/*.json
45
include src/sagemaker/serve/requirements.txt
56
recursive-include requirements *
67

src/sagemaker/local/image.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import sagemaker.local.data
4141
import sagemaker.local.utils
4242
import sagemaker.utils
43+
from sagemaker.utils import check_tarfile_data_filter_attribute
4344

4445
CONTAINER_PREFIX = "algo"
4546
STUDIO_HOST_NAME = "sagemaker-local"
@@ -686,7 +687,8 @@ def _prepare_serving_volumes(self, model_location):
686687
for filename in model_data_source.get_file_list():
687688
if tarfile.is_tarfile(filename):
688689
with tarfile.open(filename) as tar:
689-
tar.extractall(path=model_data_source.get_root_dir())
690+
check_tarfile_data_filter_attribute()
691+
tar.extractall(path=model_data_source.get_root_dir(), filter="data")
690692

691693
volumes.append(_Volume(model_data_source.get_root_dir(), "/opt/ml/model"))
692694

src/sagemaker/serve/builder/model_builder.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
from sagemaker.predictor import Predictor
3939
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
4040
from sagemaker.serve.spec.inference_spec import InferenceSpec
41+
from sagemaker.serve.utils import task
42+
from sagemaker.serve.utils.exceptions import TaskNotFoundException
4143
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
4244
from sagemaker.serve.detector.image_detector import (
4345
auto_detect_container,
@@ -605,7 +607,7 @@ def build(
605607

606608
self.serve_settings = self._get_serve_setting()
607609

608-
self._is_custom_image_uri = self.image_uri is None
610+
self._is_custom_image_uri = self.image_uri is not None
609611

610612
if isinstance(self.model, str):
611613
if self._is_jumpstart_model_id():
@@ -616,7 +618,12 @@ def build(
616618
hf_model_md = get_huggingface_model_metadata(
617619
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
618620
)
619-
if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705
621+
622+
model_task = hf_model_md.get("pipeline_tag")
623+
if self.schema_builder is None and model_task:
624+
self._schema_builder_init(model_task)
625+
626+
if model_task == "text-generation": # pylint: disable=R1705
620627
return self._build_for_tgi()
621628
else:
622629
return self._build_for_transformers()
@@ -674,3 +681,18 @@ def validate(self, model_dir: str) -> Type[bool]:
674681
"""
675682

676683
return get_metadata(model_dir)
684+
685+
def _schema_builder_init(self, model_task: str):
    """Initialize the schema builder from the locally stored task schema.

    Retrieves the sample inputs and outputs registered for ``model_task``
    through ``task.retrieve_local_schemas`` and assigns a ``SchemaBuilder``
    constructed from them to ``self.schema_builder``.

    Args:
        model_task (str): Required, the task name

    Raises:
        TaskNotFoundException: If the I/O schema for the given task is not found.
    """
    try:
        sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task)
        self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
    except ValueError as err:
        # Chain the original ValueError so the underlying reason (missing
        # config file vs. unknown task) is preserved in the traceback.
        raise TaskNotFoundException(
            f"Schema builder for {model_task} could not be found."
        ) from err

src/sagemaker/serve/model_server/djl_serving/prepare.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import List
2121
from pathlib import Path
2222

23-
from sagemaker.utils import _tmpdir
23+
from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
2424
from sagemaker.s3 import S3Downloader
2525
from sagemaker.djl_inference import DJLModel
2626
from sagemaker.djl_inference.model import _read_existing_serving_properties
@@ -53,7 +53,8 @@ def _extract_js_resource(js_model_dir: str, js_id: str):
5353
"""Uncompress the jumpstart resource"""
5454
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
5555
with tarfile.open(str(tmp_sourcedir)) as resources:
56-
resources.extractall(path=js_model_dir)
56+
check_tarfile_data_filter_attribute()
57+
resources.extractall(path=js_model_dir, filter="data")
5758

5859

5960
def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path):

src/sagemaker/serve/model_server/tgi/prepare.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pathlib import Path
2020

2121
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
22-
from sagemaker.utils import _tmpdir
22+
from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
2323
from sagemaker.s3 import S3Downloader
2424

2525
logger = logging.getLogger(__name__)
@@ -29,7 +29,8 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
2929
"""Uncompress the jumpstart resource"""
3030
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
3131
with tarfile.open(str(tmp_sourcedir)) as resources:
32-
resources.extractall(path=code_dir)
32+
check_tarfile_data_filter_attribute()
33+
resources.extractall(path=code_dir, filter="data")
3334

3435

3536
def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:

src/sagemaker/serve/schema/task.json

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
{
2+
"fill-mask": {
3+
"sample_inputs": {
4+
"properties": {
5+
"inputs": "Paris is the <mask> of France.",
6+
"parameters": {}
7+
}
8+
},
9+
"sample_outputs": {
10+
"properties": [
11+
{
12+
"sequence": "Paris is the capital of France.",
13+
"score": 0.7
14+
}
15+
]
16+
}
17+
},
18+
"question-answering": {
19+
"sample_inputs": {
20+
"properties": {
21+
"context": "I have a German Shepherd dog, named Coco.",
22+
"question": "What is my dog's breed?"
23+
}
24+
},
25+
"sample_outputs": {
26+
"properties": [
27+
{
28+
"answer": "German Shepherd",
29+
"score": 0.972,
30+
"start": 9,
31+
"end": 24
32+
}
33+
]
34+
}
35+
},
36+
"text-classification": {
37+
"sample_inputs": {
38+
"properties": {
39+
"inputs": "Where is the capital of France?, Paris is the capital of France.",
40+
"parameters": {}
41+
}
42+
},
43+
"sample_outputs": {
44+
"properties": [
45+
{
46+
"label": "entailment",
47+
"score": 0.997
48+
}
49+
]
50+
}
51+
},
52+
"text-generation": {
53+
"sample_inputs": {
54+
"properties": {
55+
"inputs": "Hello, I'm a language model",
56+
"parameters": {}
57+
}
58+
},
59+
"sample_outputs": {
60+
"properties": [
61+
{
62+
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
63+
}
64+
]
65+
}
66+
}
67+
}

src/sagemaker/serve/utils/exceptions.py

+9
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,12 @@ class SkipTuningComboException(ModelBuilderException):
6060

6161
def __init__(self, message):
6262
super().__init__(message=message)
63+
64+
65+
class TaskNotFoundException(ModelBuilderException):
    """Signal that a HuggingFace task could not be found."""

    # Message template for this exception type.
    fmt = "Error Message: {message}"

    def __init__(self, message):
        """Create the exception, forwarding ``message`` to the base class."""
        super().__init__(message=message)

src/sagemaker/serve/utils/task.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve task fallback input/output schema"""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import os
18+
from typing import Any, Tuple
19+
20+
21+
def retrieve_local_schemas(task: str) -> Tuple[Any, Any]:
    """Retrieves task sample inputs and outputs locally.

    The schemas are read from ``schema/task.json``, resolved relative to the
    parent of the directory that contains this module.

    Args:
        task (str): Required, the task name

    Returns:
        Tuple[Any, Any]: A tuple that contains the sample input,
        at index 0, and output schema, at index 1.

    Raises:
        ValueError: If no tasks config found or the task does not exist in the local config.
    """
    config_dir = os.path.dirname(os.path.dirname(__file__))
    task_io_config_path = os.path.join(config_dir, "schema", "task.json")

    # Only the file access is guarded; JSON is read as UTF-8 explicitly so the
    # result does not depend on the platform's default locale encoding.
    try:
        with open(task_io_config_path, encoding="utf-8") as f:
            task_io_config = json.load(f)
    except FileNotFoundError as err:
        # Chain the cause so the missing path stays visible in the traceback.
        raise ValueError("Could not find tasks config file.") from err

    task_io_schemas = task_io_config.get(task)
    if task_io_schemas is None:
        raise ValueError(f"Could not find {task} I/O schema.")

    return (
        task_io_schemas["sample_inputs"]["properties"],
        task_io_schemas["sample_outputs"]["properties"],
    )

src/sagemaker/utils.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import random
2323
import re
2424
import shutil
25+
import sys
2526
import tarfile
2627
import tempfile
2728
import time
@@ -591,7 +592,8 @@ def _create_or_update_code_dir(
591592
download_file_from_url(source_directory, local_code_path, sagemaker_session)
592593

593594
with tarfile.open(name=local_code_path, mode="r:gz") as t:
594-
t.extractall(path=code_dir)
595+
check_tarfile_data_filter_attribute()
596+
t.extractall(path=code_dir, filter="data")
595597

596598
elif source_directory:
597599
if os.path.exists(code_dir):
@@ -628,7 +630,8 @@ def _extract_model(model_uri, sagemaker_session, tmp):
628630
else:
629631
local_model_path = model_uri.replace("file://", "")
630632
with tarfile.open(name=local_model_path, mode="r:gz") as t:
631-
t.extractall(path=tmp_model_dir)
633+
check_tarfile_data_filter_attribute()
634+
t.extractall(path=tmp_model_dir, filter="data")
632635
return tmp_model_dir
633636

634637

@@ -1489,3 +1492,25 @@ def format_tags(tags: Tags) -> List[TagsDict]:
14891492
return [{"Key": str(k), "Value": str(v)} for k, v in tags.items()]
14901493

14911494
return tags
1495+
1496+
1497+
class PythonVersionError(Exception):
    """Raised when the running Python is not a secure (patched) version."""
1499+
1500+
1501+
def check_tarfile_data_filter_attribute():
1502+
"""Check if tarfile has data_filter utility.
1503+
1504+
Tarfile-data_filter utility has guardrails against untrusted de-serialisation.
1505+
1506+
Raises:
1507+
PythonVersionError: if `tarfile.data_filter` is not available.
1508+
"""
1509+
# The function and it's usages can be deprecated post support of python >= 3.12
1510+
if not hasattr(tarfile, "data_filter"):
1511+
raise PythonVersionError(
1512+
f"Since tarfile extraction is unsafe the operation is prohibited "
1513+
f"per PEP-721. Please update your Python [{sys.version}] "
1514+
f"to latest patch [refer to https://www.python.org/downloads/] "
1515+
f"to consume the security patch"
1516+
)

src/sagemaker/workflow/_utils.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
Step,
3333
ConfigurableRetryStep,
3434
)
35-
from sagemaker.utils import _save_model, download_file_from_url, format_tags
35+
from sagemaker.utils import (
36+
_save_model,
37+
download_file_from_url,
38+
format_tags,
39+
check_tarfile_data_filter_attribute,
40+
)
3641
from sagemaker.workflow.retry import RetryPolicy
3742
from sagemaker.workflow.utilities import trim_request_dict
3843

@@ -257,7 +262,8 @@ def _inject_repack_script_and_launcher(self):
257262
download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session)
258263

259264
with tarfile.open(name=old_targz_path, mode="r:gz") as t:
260-
t.extractall(path=targz_contents_dir)
265+
check_tarfile_data_filter_attribute()
266+
t.extractall(path=targz_contents_dir, filter="data")
261267

262268
shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT))
263269
with open(

tests/integ/s3_utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import boto3
2020
from six.moves.urllib.parse import urlparse
2121

22+
from sagemaker.utils import check_tarfile_data_filter_attribute
23+
2224

2325
def assert_s3_files_exist(sagemaker_session, s3_url, files):
2426
parsed_url = urlparse(s3_url)
@@ -55,4 +57,5 @@ def extract_files_from_s3(s3_url, tmpdir, sagemaker_session):
5557
s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model)
5658

5759
with tarfile.open(model, "r") as tar_file:
58-
tar_file.extractall(tmpdir)
60+
check_tarfile_data_filter_attribute()
61+
tar_file.extractall(tmpdir, filter="data")

0 commit comments

Comments
 (0)