
Commit 7000f25

knikure authored and bencrabtree committed
fix: Create custom tarfile extractall util to fix backward compatibility issue (aws#4476)
* fix: Create custom tarfile extractall util to fix backward compatibility issue
* Address review comments
* fix logger.error statements
1 parent b426c21 commit 7000f25

File tree

9 files changed (+240, -51 lines)
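The backward-compatibility issue, as the diff below shows, is that the removed check_tarfile_data_filter_attribute() raised PythonVersionError whenever tarfile.data_filter was missing, so extraction simply failed on interpreters without that attribute; the new custom_extractall_tarfile degrades to manual member vetting instead. A minimal sketch of that strategy, using only the standard library (the member check in the fallback branch is a crude stand-in for the diff's _get_safe_members, which also resolves symlink and hard-link targets):

import tarfile


def extract_all_compat(tar, dest):
    """Sketch of the commit's strategy: prefer the PEP 706 filter, else vet members."""
    if hasattr(tarfile, "data_filter"):
        # Newer interpreters: the "data" filter makes tarfile itself reject
        # absolute paths, traversal, and unsafe links.
        tar.extractall(path=dest, filter="data")
    else:
        # Older interpreters previously hit PythonVersionError at this point;
        # the commit screens members instead. This name check is only a crude
        # stand-in for _get_safe_members in the diff below.
        safe = [m for m in tar.getmembers() if not m.name.startswith(("/", ".."))]
        tar.extractall(path=dest, members=safe)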

src/sagemaker/local/image.py (+2, -3)

@@ -40,7 +40,7 @@
import sagemaker.local.data
import sagemaker.local.utils
import sagemaker.utils
-from sagemaker.utils import check_tarfile_data_filter_attribute
+from sagemaker.utils import custom_extractall_tarfile

CONTAINER_PREFIX = "algo"
STUDIO_HOST_NAME = "sagemaker-local"
@@ -687,8 +687,7 @@ def _prepare_serving_volumes(self, model_location):
    for filename in model_data_source.get_file_list():
        if tarfile.is_tarfile(filename):
            with tarfile.open(filename) as tar:
-                check_tarfile_data_filter_attribute()
-                tar.extractall(path=model_data_source.get_root_dir(), filter="data")
+                custom_extractall_tarfile(tar, model_data_source.get_root_dir())

    volumes.append(_Volume(model_data_source.get_root_dir(), "/opt/ml/model"))
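Every call site touched by this commit makes the same substitution, so one sketch covers the serving, workflow, and test changes below as well. A minimal, self-contained example of the new call pattern, assuming the sagemaker package from this commit is installed (the archive contents and paths are placeholders for illustration):

import os
import tarfile
import tempfile

from sagemaker.utils import custom_extractall_tarfile  # helper added by this commit

workdir = tempfile.mkdtemp()
archive = os.path.join(workdir, "model.tar.gz")

# Build a tiny stand-in archive so the snippet runs on its own.
payload = os.path.join(workdir, "weights.txt")
with open(payload, "w") as f:
    f.write("placeholder")
with tarfile.open(archive, mode="w:gz") as tar:
    tar.add(payload, arcname="weights.txt")

# New call pattern; it replaces the previous pair of calls:
#   check_tarfile_data_filter_attribute()
#   tar.extractall(path=dest, filter="data")
dest = os.path.join(workdir, "model")
with tarfile.open(archive, mode="r:gz") as tar:
    custom_extractall_tarfile(tar, dest)

print(os.listdir(dest))  # ['weights.txt']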

src/sagemaker/serve/model_server/djl_serving/prepare.py (+2, -3)

@@ -20,7 +20,7 @@
from typing import List
from pathlib import Path

-from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
+from sagemaker.utils import _tmpdir, custom_extractall_tarfile
from sagemaker.s3 import S3Downloader
from sagemaker.djl_inference import DJLModel
from sagemaker.djl_inference.model import _read_existing_serving_properties
@@ -53,8 +53,7 @@ def _extract_js_resource(js_model_dir: str, js_id: str):
    """Uncompress the jumpstart resource"""
    tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
    with tarfile.open(str(tmp_sourcedir)) as resources:
-        check_tarfile_data_filter_attribute()
-        resources.extractall(path=js_model_dir, filter="data")
+        custom_extractall_tarfile(resources, js_model_dir)


def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path):

src/sagemaker/serve/model_server/tgi/prepare.py (+2, -3)

@@ -19,7 +19,7 @@
from pathlib import Path

from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
-from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
+from sagemaker.utils import _tmpdir, custom_extractall_tarfile
from sagemaker.s3 import S3Downloader

logger = logging.getLogger(__name__)
@@ -29,8 +29,7 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
    """Uncompress the jumpstart resource"""
    tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
    with tarfile.open(str(tmp_sourcedir)) as resources:
-        check_tarfile_data_filter_attribute()
-        resources.extractall(path=code_dir, filter="data")
+        custom_extractall_tarfile(resources, code_dir)


def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:

src/sagemaker/utils.py (+87, -20)

@@ -22,7 +22,6 @@
import random
import re
import shutil
-import sys
import tarfile
import tempfile
import time
@@ -31,6 +30,7 @@
import abc
import uuid
from datetime import datetime
+from os.path import abspath, realpath, dirname, normpath, join as joinpath

from importlib import import_module
import botocore
@@ -592,8 +592,7 @@ def _create_or_update_code_dir(
        download_file_from_url(source_directory, local_code_path, sagemaker_session)

        with tarfile.open(name=local_code_path, mode="r:gz") as t:
-            check_tarfile_data_filter_attribute()
-            t.extractall(path=code_dir, filter="data")
+            custom_extractall_tarfile(t, code_dir)

    elif source_directory:
        if os.path.exists(code_dir):
@@ -630,8 +629,7 @@ def _extract_model(model_uri, sagemaker_session, tmp):
    else:
        local_model_path = model_uri.replace("file://", "")
    with tarfile.open(name=local_model_path, mode="r:gz") as t:
-        check_tarfile_data_filter_attribute()
-        t.extractall(path=tmp_model_dir, filter="data")
+        custom_extractall_tarfile(t, tmp_model_dir)
    return tmp_model_dir


@@ -1494,23 +1492,92 @@ def format_tags(tags: Tags) -> List[TagsDict]:
    return tags


-class PythonVersionError(Exception):
-    """Raise when a secure [/patched] version of Python is not used."""
+def _get_resolved_path(path):
+    """Return the normalized absolute path of a given path.

+    abspath - returns the absolute path without resolving symlinks
+    realpath - resolves the symlinks and gets the actual path
+    normpath - normalizes paths (e.g. remove redudant separators)
+    and handles platform-specific differences
+    """
+    return normpath(realpath(abspath(path)))

-def check_tarfile_data_filter_attribute():
-    """Check if tarfile has data_filter utility.

-    Tarfile-data_filter utility has guardrails against untrusted de-serialisation.
+def _is_bad_path(path, base):
+    """Checks if the joined path (base directory + file path) is rooted under the base directory

-    Raises:
-        PythonVersionError: if `tarfile.data_filter` is not available.
+    Ensuring that the file does not attempt to access paths
+    outside the expected directory structure.
+
+    Args:
+        path (str): The file path.
+        base (str): The base directory.
+
+    Returns:
+        bool: True if the path is not rooted under the base directory, False otherwise.
    """
-    # The function and it's usages can be deprecated post support of python >= 3.12
-    if not hasattr(tarfile, "data_filter"):
-        raise PythonVersionError(
-            f"Since tarfile extraction is unsafe the operation is prohibited "
-            f"per PEP-721. Please update your Python [{sys.version}] "
-            f"to latest patch [refer to https://www.python.org/downloads/] "
-            f"to consume the security patch"
-        )
+    # joinpath will ignore base if path is absolute
+    return not _get_resolved_path(joinpath(base, path)).startswith(base)
+
+
+def _is_bad_link(info, base):
+    """Checks if the link is rooted under the base directory.
+
+    Ensuring that the link does not attempt to access paths outside the expected directory structure
+
+    Args:
+        info (tarfile.TarInfo): The tar file info.
+        base (str): The base directory.
+
+    Returns:
+        bool: True if the link is not rooted under the base directory, False otherwise.
+    """
+    # Links are interpreted relative to the directory containing the link
+    tip = _get_resolved_path(joinpath(base, dirname(info.name)))
+    return _is_bad_path(info.linkname, base=tip)
+
+
+def _get_safe_members(members):
+    """A generator that yields members that are safe to extract.
+
+    It filters out bad paths and bad links.
+
+    Args:
+        members (list): A list of members to check.
+
+    Yields:
+        tarfile.TarInfo: The tar file info.
+    """
+    base = _get_resolved_path(".")
+
+    for file_info in members:
+        if _is_bad_path(file_info.name, base):
+            logger.error("%s is blocked (illegal path)", file_info.name)
+        elif file_info.issym() and _is_bad_link(file_info, base):
+            logger.error("%s is blocked: Symlink to %s", file_info.name, file_info.linkname)
+        elif file_info.islnk() and _is_bad_link(file_info, base):
+            logger.error("%s is blocked: Hard link to %s", file_info.name, file_info.linkname)
+        else:
+            yield file_info
+
+
+def custom_extractall_tarfile(tar, extract_path):
+    """Extract a tarfile, optionally using data_filter if available.
+
+    # TODO: The function and it's usages can be deprecated once SageMaker Python SDK
+    is upgraded to use Python 3.12+
+
+    If the tarfile has a data_filter attribute, it will be used to extract the contents of the file.
+    Otherwise, the _get_safe_members function will be used to filter bad paths and bad links.
+
+    Args:
+        tar (tarfile.TarFile): The opened tarfile object.
+        extract_path (str): The path to extract the contents of the tarfile.
+
+    Returns:
+        None
+    """
+    if hasattr(tarfile, "data_filter"):
+        tar.extractall(path=extract_path, filter="data")
+    else:
+        tar.extractall(path=extract_path, members=_get_safe_members(tar))
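To see what the fallback path filters out, the sketch below builds an in-memory archive with one benign member and one path-traversal member, then runs the new _get_safe_members helper over it. It assumes the sagemaker package from this commit is installed; the _member helper and file names are illustration only. Note that the helper vets names against the current working directory rather than the extraction path, so the check is run from the intended destination:

import io
import os
import tarfile
import tempfile

from sagemaker.utils import _get_safe_members  # private helper added by this commit


def _member(name, data):
    """Build a regular-file member with an arbitrary (possibly malicious) name."""
    info = tarfile.TarInfo(name=name)
    info.size = len(data)
    return info, io.BytesIO(data)


buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tar:
    for name in ("good.txt", "../escape.txt"):
        info, data = _member(name, b"payload")
        tar.addfile(info, data)
buf.seek(0)

with tempfile.TemporaryDirectory() as dest, tarfile.open(fileobj=buf) as tar:
    previous = os.getcwd()
    os.chdir(dest)  # _get_safe_members resolves member names against "."
    try:
        safe = [m.name for m in _get_safe_members(tar.getmembers())]
    finally:
        os.chdir(previous)

print(safe)  # ['good.txt'] -- '../escape.txt' is logged via logger.error and skipped

The same screening happens inside custom_extractall_tarfile on interpreters without tarfile.data_filter; on newer interpreters the "data" filter performs equivalent checks natively.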

src/sagemaker/workflow/_repack_model.py (+97, -1)

@@ -14,6 +14,7 @@
from __future__ import absolute_import

import argparse
+import logging
import os
import shutil
import tarfile
@@ -33,6 +34,101 @@
# repacking is some short-lived hackery, right??
from distutils.dir_util import copy_tree

+from os.path import abspath, realpath, dirname, normpath, join as joinpath
+
+logger = logging.getLogger(__name__)
+
+
+def _get_resolved_path(path):
+    """Return the normalized absolute path of a given path.
+
+    abspath - returns the absolute path without resolving symlinks
+    realpath - resolves the symlinks and gets the actual path
+    normpath - normalizes paths (e.g. remove redudant separators)
+    and handles platform-specific differences
+    """
+    return normpath(realpath(abspath(path)))
+
+
+def _is_bad_path(path, base):
+    """Checks if the joined path (base directory + file path) is rooted under the base directory
+
+    Ensuring that the file does not attempt to access paths
+    outside the expected directory structure.
+
+    Args:
+        path (str): The file path.
+        base (str): The base directory.
+
+    Returns:
+        bool: True if the path is not rooted under the base directory, False otherwise.
+    """
+    # joinpath will ignore base if path is absolute
+    return not _get_resolved_path(joinpath(base, path)).startswith(base)
+
+
+def _is_bad_link(info, base):
+    """Checks if the link is rooted under the base directory.
+
+    Ensuring that the link does not attempt to access paths outside the expected directory structure
+
+    Args:
+        info (tarfile.TarInfo): The tar file info.
+        base (str): The base directory.
+
+    Returns:
+        bool: True if the link is not rooted under the base directory, False otherwise.
+    """
+    # Links are interpreted relative to the directory containing the link
+    tip = _get_resolved_path(joinpath(base, dirname(info.name)))
+    return _is_bad_path(info.linkname, base=tip)
+
+
+def _get_safe_members(members):
+    """A generator that yields members that are safe to extract.
+
+    It filters out bad paths and bad links.
+
+    Args:
+        members (list): A list of members to check.
+
+    Yields:
+        tarfile.TarInfo: The tar file info.
+    """
+    base = _get_resolved_path(".")
+
+    for file_info in members:
+        if _is_bad_path(file_info.name, base):
+            logger.error("%s is blocked (illegal path)", file_info.name)
+        elif file_info.issym() and _is_bad_link(file_info, base):
+            logger.error("%s is blocked: Symlink to %s", file_info.name, file_info.linkname)
+        elif file_info.islnk() and _is_bad_link(file_info, base):
+            logger.error("%s is blocked: Hard link to %s", file_info.name, file_info.linkname)
+        else:
+            yield file_info
+
+
+def custom_extractall_tarfile(tar, extract_path):
+    """Extract a tarfile, optionally using data_filter if available.
+
+    # TODO: The function and it's usages can be deprecated once SageMaker Python SDK
+    is upgraded to use Python 3.12+
+
+    If the tarfile has a data_filter attribute, it will be used to extract the contents of the file.
+    Otherwise, the _get_safe_members function will be used to filter bad paths and bad links.
+
+    Args:
+        tar (tarfile.TarFile): The opened tarfile object.
+        extract_path (str): The path to extract the contents of the tarfile.
+
+    Returns:
+        None
+    """
+    if hasattr(tarfile, "data_filter"):
+        tar.extractall(path=extract_path, filter="data")
+    else:
+        tar.extractall(path=extract_path, members=_get_safe_members(tar))
+

def repack(inference_script, model_archive, dependencies=None, source_dir=None):  # pragma: no cover
    """Repack custom dependencies and code into an existing model TAR archive
@@ -60,7 +156,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
    # extract the contents of the previous training job's model archive to the "src"
    # directory of this training job
    with tarfile.open(name=local_path, mode="r:gz") as tf:
-        tf.extractall(path=src_dir)
+        custom_extractall_tarfile(tf, src_dir)

    if source_dir:
        # copy /opt/ml/code to code/

src/sagemaker/workflow/_utils.py (+2, -3)

@@ -36,7 +36,7 @@
    _save_model,
    download_file_from_url,
    format_tags,
-    check_tarfile_data_filter_attribute,
+    custom_extractall_tarfile,
)
from sagemaker.workflow.retry import RetryPolicy
from sagemaker.workflow.utilities import trim_request_dict
@@ -262,8 +262,7 @@ def _inject_repack_script_and_launcher(self):
    download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session)

    with tarfile.open(name=old_targz_path, mode="r:gz") as t:
-        check_tarfile_data_filter_attribute()
-        t.extractall(path=targz_contents_dir, filter="data")
+        custom_extractall_tarfile(t, targz_contents_dir)

    shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT))
    with open(

tests/integ/s3_utils.py (+2, -3)

@@ -19,7 +19,7 @@
import boto3
from six.moves.urllib.parse import urlparse

-from sagemaker.utils import check_tarfile_data_filter_attribute
+from sagemaker.utils import custom_extractall_tarfile


def assert_s3_files_exist(sagemaker_session, s3_url, files):
@@ -57,5 +57,4 @@ def extract_files_from_s3(s3_url, tmpdir, sagemaker_session):
    s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model)

    with tarfile.open(model, "r") as tar_file:
-        check_tarfile_data_filter_attribute()
-        tar_file.extractall(tmpdir, filter="data")
+        custom_extractall_tarfile(tar_file, tmpdir)

tests/unit/test_fw_utils.py (+2, -3)

@@ -24,7 +24,7 @@
from mock import Mock, patch

from sagemaker import fw_utils
-from sagemaker.utils import name_from_image, check_tarfile_data_filter_attribute
+from sagemaker.utils import name_from_image, custom_extractall_tarfile
from sagemaker.session_settings import SessionSettings
from sagemaker.instance_group import InstanceGroup

@@ -424,8 +424,7 @@ def list_tar_files(folder, tar_ball, tmpdir):
    startpath = str(tmpdir.ensure(folder, dir=True))

    with tarfile.open(name=tar_ball, mode="r:gz") as t:
-        check_tarfile_data_filter_attribute()
-        t.extractall(path=startpath, filter="data")
+        custom_extractall_tarfile(t, startpath)

    def walk():
        for root, dirs, files in os.walk(startpath):

0 commit comments