
Commit dbb82c1

fix: Create custom tarfile extractall util to fix backward compatibility issue
1 parent 427c7ba commit dbb82c1
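
The change removes the hard gate that previously raised `PythonVersionError` on interpreters lacking `tarfile.data_filter` and replaces it with a single helper, `custom_extractall_tarfile`, which every extraction call site now uses. A minimal sketch of the call-site change; the archive name and destination below are placeholders for illustration:

import tarfile

from sagemaker.utils import custom_extractall_tarfile

# Old pattern, removed in this commit:
#     check_tarfile_data_filter_attribute()   # raised PythonVersionError without tarfile.data_filter
#     tar.extractall(path=dest, filter="data")
#
# New pattern:
dest = "/opt/ml/model"                                   # placeholder destination
with tarfile.open("model.tar.gz", mode="r:gz") as tar:   # placeholder archive
    custom_extractall_tarfile(tar, dest)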

File tree

9 files changed: +233, -51 lines


src/sagemaker/local/image.py

Lines changed: 2 additions & 3 deletions
@@ -40,7 +40,7 @@
 import sagemaker.local.data
 import sagemaker.local.utils
 import sagemaker.utils
-from sagemaker.utils import check_tarfile_data_filter_attribute
+from sagemaker.utils import custom_extractall_tarfile

 CONTAINER_PREFIX = "algo"
 STUDIO_HOST_NAME = "sagemaker-local"
@@ -687,8 +687,7 @@ def _prepare_serving_volumes(self, model_location):
         for filename in model_data_source.get_file_list():
             if tarfile.is_tarfile(filename):
                 with tarfile.open(filename) as tar:
-                    check_tarfile_data_filter_attribute()
-                    tar.extractall(path=model_data_source.get_root_dir(), filter="data")
+                    custom_extractall_tarfile(tar, model_data_source.get_root_dir())

                 volumes.append(_Volume(model_data_source.get_root_dir(), "/opt/ml/model"))

src/sagemaker/serve/model_server/djl_serving/prepare.py

Lines changed: 2 additions & 3 deletions
@@ -20,7 +20,7 @@
 from typing import List
 from pathlib import Path

-from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
+from sagemaker.utils import _tmpdir, custom_extractall_tarfile
 from sagemaker.s3 import S3Downloader
 from sagemaker.djl_inference import DJLModel
 from sagemaker.djl_inference.model import _read_existing_serving_properties
@@ -53,8 +53,7 @@ def _extract_js_resource(js_model_dir: str, js_id: str):
     """Uncompress the jumpstart resource"""
     tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
     with tarfile.open(str(tmp_sourcedir)) as resources:
-        check_tarfile_data_filter_attribute()
-        resources.extractall(path=js_model_dir, filter="data")
+        custom_extractall_tarfile(resources, js_model_dir)


 def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path):

src/sagemaker/serve/model_server/tgi/prepare.py

Lines changed: 2 additions & 3 deletions
@@ -19,7 +19,7 @@
 from pathlib import Path

 from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
-from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
+from sagemaker.utils import _tmpdir, custom_extractall_tarfile
 from sagemaker.s3 import S3Downloader

 logger = logging.getLogger(__name__)
@@ -29,8 +29,7 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
     """Uncompress the jumpstart resource"""
     tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
     with tarfile.open(str(tmp_sourcedir)) as resources:
-        check_tarfile_data_filter_attribute()
-        resources.extractall(path=code_dir, filter="data")
+        custom_extractall_tarfile(resources, code_dir)


 def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:

src/sagemaker/utils.py

Lines changed: 85 additions & 20 deletions
@@ -22,7 +22,6 @@
 import random
 import re
 import shutil
-import sys
 import tarfile
 import tempfile
 import time
@@ -31,6 +30,8 @@
 import abc
 import uuid
 from datetime import datetime
+from os.path import abspath, realpath, dirname, normpath, join as joinpath
+from sys import stderr

 from importlib import import_module
 import botocore
@@ -592,8 +593,7 @@ def _create_or_update_code_dir(
         download_file_from_url(source_directory, local_code_path, sagemaker_session)

         with tarfile.open(name=local_code_path, mode="r:gz") as t:
-            check_tarfile_data_filter_attribute()
-            t.extractall(path=code_dir, filter="data")
+            custom_extractall_tarfile(t, code_dir)

     elif source_directory:
         if os.path.exists(code_dir):
@@ -630,8 +630,7 @@ def _extract_model(model_uri, sagemaker_session, tmp):
     else:
         local_model_path = model_uri.replace("file://", "")
     with tarfile.open(name=local_model_path, mode="r:gz") as t:
-        check_tarfile_data_filter_attribute()
-        t.extractall(path=tmp_model_dir, filter="data")
+        custom_extractall_tarfile(t, tmp_model_dir)
     return tmp_model_dir


@@ -1494,23 +1493,89 @@ def format_tags(tags: Tags) -> List[TagsDict]:
     return tags


-class PythonVersionError(Exception):
-    """Raise when a secure [/patched] version of Python is not used."""
+def _get_resolved_path(path):
+    """Return the normalized absolute path of a given path.

+    abspath - returns the absolute path without resolving symlinks
+    realpath - resolves the symlinks and gets the actual path
+    normpath - normalizes paths (e.g. remove redudant separators)
+    and handles platform-specific differences
+    """
+    return normpath(realpath(abspath(path)))

-def check_tarfile_data_filter_attribute():
-    """Check if tarfile has data_filter utility.

-    Tarfile-data_filter utility has guardrails against untrusted de-serialisation.
+def _is_bad_path(path, base):
+    """Checks if the joined path (base directory + file path) is rooted under the base directory

-    Raises:
-        PythonVersionError: if `tarfile.data_filter` is not available.
+    Ensuring that the file does not attempt to access paths
+    outside the expected directory structure.
+
+    Args:
+        path (str): The file path.
+        base (str): The base directory.
+
+    Returns:
+        bool: True if the path is not rooted under the base directory, False otherwise.
     """
-    # The function and it's usages can be deprecated post support of python >= 3.12
-    if not hasattr(tarfile, "data_filter"):
-        raise PythonVersionError(
-            f"Since tarfile extraction is unsafe the operation is prohibited "
-            f"per PEP-721. Please update your Python [{sys.version}] "
-            f"to latest patch [refer to https://www.python.org/downloads/] "
-            f"to consume the security patch"
-        )
+    # joinpath will ignore base if path is absolute
+    return not _get_resolved_path(joinpath(base, path)).startswith(base)
+
+
+def _is_bad_link(info, base):
+    """Checks if the link is rooted under the base directory.
+
+    Ensuring that the link does not attempt to access paths outside the expected directory structure
+
+    Args:
+        info (tarfile.TarInfo): The tar file info.
+        base (str): The base directory.
+
+    Returns:
+        bool: True if the link is not rooted under the base directory, False otherwise.
+    """
+    # Links are interpreted relative to the directory containing the link
+    tip = _get_resolved_path(joinpath(base, dirname(info.name)))
+    return _is_bad_path(info.linkname, base=tip)
+
+
+def safe_members(members):
+    """A generator that yields members that are safe to extract.
+
+    It checks for bad paths and bad links.
+
+    Args:
+        members (list): A list of members to check.
+
+    Yields:
+        tarfile.TarInfo: The tar file info.
+    """
+    base = _get_resolved_path(".")
+
+    for file_info in members:
+        if _is_bad_path(file_info.name, base):
+            print(stderr, file_info.name, "is blocked (illegal path)")
+        elif file_info.issym() and _is_bad_link(file_info, base):
+            print(stderr, file_info.name, "is blocked: Symlink to", file_info.linkname)
+        elif file_info.islnk() and _is_bad_link(file_info, base):
+            print(stderr, file_info.name, "is blocked: Hard link to", file_info.linkname)
+        else:
+            yield file_info
+
+
+def custom_extractall_tarfile(tar, extract_path):
+    """Extract a tarfile, optionally using data_filter if available.
+
+    If the tarfile has a data_filter attribute, it will be used to extract the contents of the file.
+    Otherwise, the safe_members function will be used to check for bad paths and bad links.
+
+    Args:
+        tar (tarfile.TarFile): The opened tarfile object.
+        extract_path (str): The path to extract the contents of the tarfile.
+
+    Returns:
+        None
+    """
+    if hasattr(tar, "data_filter"):
+        tar.extractall(path=extract_path, filter="data")
+    else:
+        tar.extractall(path=extract_path, members=safe_members(tar))
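
A usage note on the fallback branch above: when the opened `TarFile` does not expose `data_filter`, extraction goes through `safe_members`, which skips any member whose resolved path or link target escapes the base directory and prints a blocked message for it. A small self-contained sketch of that screening, with an in-memory archive and member names chosen purely for illustration:

import io
import tarfile

from sagemaker.utils import safe_members

# Build an in-memory archive with one well-behaved member and one
# path-traversal member (both names are illustrative).
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tar:
    tar.addfile(tarfile.TarInfo(name="model/weights.bin"), io.BytesIO(b""))

    payload = b"malicious"
    escape = tarfile.TarInfo(name="../escape.txt")
    escape.size = len(payload)
    tar.addfile(escape, io.BytesIO(payload))

buf.seek(0)
with tarfile.open(fileobj=buf, mode="r") as tar:
    kept = [member.name for member in safe_members(tar.getmembers())]

# Expected: only "model/weights.bin" survives; "../escape.txt" is reported as blocked.
print(kept)

Note that `safe_members` anchors its check at the process's current working directory (`_get_resolved_path(".")`) rather than at `extract_path`, as written in the diff above.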

src/sagemaker/workflow/_repack_model.py

Lines changed: 92 additions & 1 deletion
@@ -33,6 +33,97 @@
 # repacking is some short-lived hackery, right??
 from distutils.dir_util import copy_tree

+from os.path import abspath, realpath, dirname, normpath, join as joinpath
+from sys import stderr
+
+
+def _get_resolved_path(path):
+    """Return the normalized absolute path of a given path.
+
+    abspath - returns the absolute path without resolving symlinks
+    realpath - resolves the symlinks and gets the actual path
+    normpath - normalizes paths (e.g. remove redudant separators)
+    and handles platform-specific differences
+    """
+    return normpath(realpath(abspath(path)))
+
+
+def _is_bad_path(path, base):
+    """Checks if the joined path (base directory + file path) is rooted under the base directory
+
+    Ensuring that the file does not attempt to access paths
+    outside the expected directory structure.
+
+    Args:
+        path (str): The file path.
+        base (str): The base directory.
+
+    Returns:
+        bool: True if the path is not rooted under the base directory, False otherwise.
+    """
+    # joinpath will ignore base if path is absolute
+    return not _get_resolved_path(joinpath(base, path)).startswith(base)
+
+
+def _is_bad_link(info, base):
+    """Checks if the link is rooted under the base directory.
+
+    Ensuring that the link does not attempt to access paths outside the expected directory structure
+
+    Args:
+        info (tarfile.TarInfo): The tar file info.
+        base (str): The base directory.
+
+    Returns:
+        bool: True if the link is not rooted under the base directory, False otherwise.
+    """
+    # Links are interpreted relative to the directory containing the link
+    tip = _get_resolved_path(joinpath(base, dirname(info.name)))
+    return _is_bad_path(info.linkname, base=tip)
+
+
+def safe_members(members):
+    """A generator that yields members that are safe to extract.
+
+    It checks for bad paths and bad links.
+
+    Args:
+        members (list): A list of members to check.
+
+    Yields:
+        tarfile.TarInfo: The tar file info.
+    """
+    base = _get_resolved_path(".")
+
+    for file_info in members:
+        if _is_bad_path(file_info.name, base):
+            print(stderr, file_info.name, "is blocked (illegal path)")
+        elif file_info.issym() and _is_bad_link(file_info, base):
+            print(stderr, file_info.name, "is blocked: Symlink to", file_info.linkname)
+        elif file_info.islnk() and _is_bad_link(file_info, base):
+            print(stderr, file_info.name, "is blocked: Hard link to", file_info.linkname)
+        else:
+            yield file_info
+
+
+def custom_extractall_tarfile(tar, extract_path):
+    """Extract a tarfile, optionally using data_filter if available.
+
+    If the tarfile has a data_filter attribute, it will be used to extract the contents of the file.
+    Otherwise, the safe_members function will be used to check for bad paths and bad links.
+
+    Args:
+        tar (tarfile.TarFile): The opened tarfile object.
+        extract_path (str): The path to extract the contents of the tarfile.
+
+    Returns:
+        None
+    """
+    if hasattr(tar, "data_filter"):
+        tar.extractall(path=extract_path, filter="data")
+    else:
+        tar.extractall(path=extract_path, members=safe_members(tar))
+

 def repack(inference_script, model_archive, dependencies=None, source_dir=None):  # pragma: no cover
     """Repack custom dependencies and code into an existing model TAR archive
@@ -60,7 +151,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
     # extract the contents of the previous training job's model archive to the "src"
     # directory of this training job
     with tarfile.open(name=local_path, mode="r:gz") as tf:
-        tf.extractall(path=src_dir)
+        custom_extractall_tarfile(tf, src_dir)

     if source_dir:
         # copy /opt/ml/code to code/

src/sagemaker/workflow/_utils.py

Lines changed: 2 additions & 3 deletions
@@ -36,7 +36,7 @@
     _save_model,
     download_file_from_url,
     format_tags,
-    check_tarfile_data_filter_attribute,
+    custom_extractall_tarfile,
 )
 from sagemaker.workflow.retry import RetryPolicy
 from sagemaker.workflow.utilities import trim_request_dict
@@ -262,8 +262,7 @@ def _inject_repack_script_and_launcher(self):
         download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session)

         with tarfile.open(name=old_targz_path, mode="r:gz") as t:
-            check_tarfile_data_filter_attribute()
-            t.extractall(path=targz_contents_dir, filter="data")
+            custom_extractall_tarfile(t, targz_contents_dir)

         shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT))
         with open(

tests/integ/s3_utils.py

Lines changed: 2 additions & 3 deletions
@@ -19,7 +19,7 @@
 import boto3
 from six.moves.urllib.parse import urlparse

-from sagemaker.utils import check_tarfile_data_filter_attribute
+from sagemaker.utils import custom_extractall_tarfile


 def assert_s3_files_exist(sagemaker_session, s3_url, files):
@@ -57,5 +57,4 @@ def extract_files_from_s3(s3_url, tmpdir, sagemaker_session):
     s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model)

     with tarfile.open(model, "r") as tar_file:
-        check_tarfile_data_filter_attribute()
-        tar_file.extractall(tmpdir, filter="data")
+        custom_extractall_tarfile(tar_file, tmpdir)

tests/unit/test_fw_utils.py

Lines changed: 2 additions & 3 deletions
@@ -24,7 +24,7 @@
 from mock import Mock, patch

 from sagemaker import fw_utils
-from sagemaker.utils import name_from_image, check_tarfile_data_filter_attribute
+from sagemaker.utils import name_from_image, custom_extractall_tarfile
 from sagemaker.session_settings import SessionSettings
 from sagemaker.instance_group import InstanceGroup

@@ -424,8 +424,7 @@ def list_tar_files(folder, tar_ball, tmpdir):
     startpath = str(tmpdir.ensure(folder, dir=True))

     with tarfile.open(name=tar_ball, mode="r:gz") as t:
-        check_tarfile_data_filter_attribute()
-        t.extractall(path=startpath, filter="data")
+        custom_extractall_tarfile(t, startpath)

     def walk():
         for root, dirs, files in os.walk(startpath):
