Skip to content

Commit 84989bb

Browse files
fix: add fixes for tarfile extractall functionality PEP-721 (aws#4441)
* fix: add fixes for tarfile extractall functionality PEP-721
1 parent 4db9aba commit 84989bb

File tree

10 files changed

+67
-14
lines changed

10 files changed

+67
-14
lines changed

src/sagemaker/local/image.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import sagemaker.local.data
4141
import sagemaker.local.utils
4242
import sagemaker.utils
43+
from sagemaker.utils import check_tarfile_data_filter_attribute
4344

4445
CONTAINER_PREFIX = "algo"
4546
STUDIO_HOST_NAME = "sagemaker-local"
@@ -686,7 +687,8 @@ def _prepare_serving_volumes(self, model_location):
686687
for filename in model_data_source.get_file_list():
687688
if tarfile.is_tarfile(filename):
688689
with tarfile.open(filename) as tar:
689-
tar.extractall(path=model_data_source.get_root_dir())
690+
check_tarfile_data_filter_attribute()
691+
tar.extractall(path=model_data_source.get_root_dir(), filter="data")
690692

691693
volumes.append(_Volume(model_data_source.get_root_dir(), "/opt/ml/model"))
692694

src/sagemaker/serve/model_server/djl_serving/prepare.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import List
2121
from pathlib import Path
2222

23-
from sagemaker.utils import _tmpdir
23+
from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
2424
from sagemaker.s3 import S3Downloader
2525
from sagemaker.djl_inference import DJLModel
2626
from sagemaker.djl_inference.model import _read_existing_serving_properties
@@ -53,7 +53,8 @@ def _extract_js_resource(js_model_dir: str, js_id: str):
5353
"""Uncompress the jumpstart resource"""
5454
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
5555
with tarfile.open(str(tmp_sourcedir)) as resources:
56-
resources.extractall(path=js_model_dir)
56+
check_tarfile_data_filter_attribute()
57+
resources.extractall(path=js_model_dir, filter="data")
5758

5859

5960
def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path):

src/sagemaker/serve/model_server/tgi/prepare.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pathlib import Path
2020

2121
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
22-
from sagemaker.utils import _tmpdir
22+
from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
2323
from sagemaker.s3 import S3Downloader
2424

2525
logger = logging.getLogger(__name__)
@@ -29,7 +29,8 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
2929
"""Uncompress the jumpstart resource"""
3030
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
3131
with tarfile.open(str(tmp_sourcedir)) as resources:
32-
resources.extractall(path=code_dir)
32+
check_tarfile_data_filter_attribute()
33+
resources.extractall(path=code_dir, filter="data")
3334

3435

3536
def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:

src/sagemaker/utils.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import random
2323
import re
2424
import shutil
25+
import sys
2526
import tarfile
2627
import tempfile
2728
import time
@@ -591,7 +592,8 @@ def _create_or_update_code_dir(
591592
download_file_from_url(source_directory, local_code_path, sagemaker_session)
592593

593594
with tarfile.open(name=local_code_path, mode="r:gz") as t:
594-
t.extractall(path=code_dir)
595+
check_tarfile_data_filter_attribute()
596+
t.extractall(path=code_dir, filter="data")
595597

596598
elif source_directory:
597599
if os.path.exists(code_dir):
@@ -628,7 +630,8 @@ def _extract_model(model_uri, sagemaker_session, tmp):
628630
else:
629631
local_model_path = model_uri.replace("file://", "")
630632
with tarfile.open(name=local_model_path, mode="r:gz") as t:
631-
t.extractall(path=tmp_model_dir)
633+
check_tarfile_data_filter_attribute()
634+
t.extractall(path=tmp_model_dir, filter="data")
632635
return tmp_model_dir
633636

634637

@@ -1489,3 +1492,25 @@ def format_tags(tags: Tags) -> List[TagsDict]:
14891492
return [{"Key": str(k), "Value": str(v)} for k, v in tags.items()]
14901493

14911494
return tags
1495+
1496+
1497+
class PythonVersionError(Exception):
1498+
"""Raised when a secure (patched) version of Python is not in use."""
1499+
1500+
1501+
def check_tarfile_data_filter_attribute():
1502+
"""Check if tarfile has data_filter utility.
1503+
1504+
The `tarfile.data_filter` utility provides guardrails against untrusted de-serialisation.
1505+
1506+
Raises:
1507+
PythonVersionError: if `tarfile.data_filter` is not available.
1508+
"""
1509+
# This function and its usages can be deprecated once only Python >= 3.12 is supported
1510+
if not hasattr(tarfile, "data_filter"):
1511+
raise PythonVersionError(
1512+
f"Since tarfile extraction is unsafe the operation is prohibited "
1513+
f"per PEP-721. Please update your Python [{sys.version}] "
1514+
f"to latest patch [refer to https://www.python.org/downloads/] "
1515+
f"to consume the security patch"
1516+
)

src/sagemaker/workflow/_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
Step,
3333
ConfigurableRetryStep,
3434
)
35-
from sagemaker.utils import _save_model, download_file_from_url, format_tags
35+
from sagemaker.utils import (
36+
_save_model,
37+
download_file_from_url,
38+
format_tags,
39+
check_tarfile_data_filter_attribute,
40+
)
3641
from sagemaker.workflow.retry import RetryPolicy
3742
from sagemaker.workflow.utilities import trim_request_dict
3843

@@ -257,7 +262,8 @@ def _inject_repack_script_and_launcher(self):
257262
download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session)
258263

259264
with tarfile.open(name=old_targz_path, mode="r:gz") as t:
260-
t.extractall(path=targz_contents_dir)
265+
check_tarfile_data_filter_attribute()
266+
t.extractall(path=targz_contents_dir, filter="data")
261267

262268
shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT))
263269
with open(

tests/integ/s3_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import boto3
2020
from six.moves.urllib.parse import urlparse
2121

22+
from sagemaker.utils import check_tarfile_data_filter_attribute
23+
2224

2325
def assert_s3_files_exist(sagemaker_session, s3_url, files):
2426
parsed_url = urlparse(s3_url)
@@ -55,4 +57,5 @@ def extract_files_from_s3(s3_url, tmpdir, sagemaker_session):
5557
s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model)
5658

5759
with tarfile.open(model, "r") as tar_file:
58-
tar_file.extractall(tmpdir)
60+
check_tarfile_data_filter_attribute()
61+
tar_file.extractall(tmpdir, filter="data")

tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,4 +272,4 @@ def test_extract_js_resources_success(self, mock_tarfile, mock_path):
272272

273273
mock_path.assert_called_once_with(js_model_dir)
274274
mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz")
275-
mock_resource_obj.extractall.assert_called_once_with(path=js_model_dir)
275+
mock_resource_obj.extractall.assert_called_once_with(path=js_model_dir, filter="data")

tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,4 @@ def test_extract_js_resources_success(self, mock_tarfile, mock_path):
156156

157157
mock_path.assert_called_once_with(js_model_dir)
158158
mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz")
159-
mock_resource_obj.extractall.assert_called_once_with(path=code_dir)
159+
mock_resource_obj.extractall.assert_called_once_with(path=code_dir, filter="data")

tests/unit/test_fw_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from mock import Mock, patch
2525

2626
from sagemaker import fw_utils
27-
from sagemaker.utils import name_from_image
27+
from sagemaker.utils import name_from_image, check_tarfile_data_filter_attribute
2828
from sagemaker.session_settings import SessionSettings
2929
from sagemaker.instance_group import InstanceGroup
3030

@@ -424,7 +424,8 @@ def list_tar_files(folder, tar_ball, tmpdir):
424424
startpath = str(tmpdir.ensure(folder, dir=True))
425425

426426
with tarfile.open(name=tar_ball, mode="r:gz") as t:
427-
t.extractall(path=startpath)
427+
check_tarfile_data_filter_attribute()
428+
t.extractall(path=startpath, filter="data")
428429

429430
def walk():
430431
for root, dirs, files in os.walk(startpath):

tests/unit/test_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
resolve_nested_dict_value_from_config,
4343
update_list_of_dicts_with_values_from_config,
4444
volume_size_supported,
45+
PythonVersionError,
46+
check_tarfile_data_filter_attribute,
4547
)
4648
from tests.unit.sagemaker.workflow.helpers import CustomStep
4749
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
@@ -1748,3 +1750,15 @@ def test_instance_family_from_full_instance_type(self):
17481750

17491751
for instance_type, family in instance_type_to_family_test_dict.items():
17501752
self.assertEqual(family, get_instance_type_family(instance_type))
1753+
1754+
1755+
class TestCheckTarfileDataFilterAttribute(TestCase):
1756+
def test_check_tarfile_data_filter_attribute_unhappy_case(self):
1757+
with pytest.raises(PythonVersionError):
1758+
with patch("tarfile.data_filter", None):
1759+
delattr(tarfile, "data_filter")
1760+
check_tarfile_data_filter_attribute()
1761+
1762+
def test_check_tarfile_data_filter_attribute_happy_case(self):
1763+
with patch("tarfile.data_filter", "some_value"):
1764+
check_tarfile_data_filter_attribute()

0 commit comments

Comments
 (0)