
fix: Create custom tarfile extractall util to fix backward compatibility issue #4476


Merged
merged 3 commits into from
Mar 6, 2024
5 changes: 2 additions & 3 deletions src/sagemaker/local/image.py
@@ -40,7 +40,7 @@
import sagemaker.local.data
import sagemaker.local.utils
import sagemaker.utils
from sagemaker.utils import check_tarfile_data_filter_attribute
from sagemaker.utils import custom_extractall_tarfile

CONTAINER_PREFIX = "algo"
STUDIO_HOST_NAME = "sagemaker-local"
@@ -687,8 +687,7 @@ def _prepare_serving_volumes(self, model_location):
for filename in model_data_source.get_file_list():
if tarfile.is_tarfile(filename):
with tarfile.open(filename) as tar:
check_tarfile_data_filter_attribute()
tar.extractall(path=model_data_source.get_root_dir(), filter="data")
custom_extractall_tarfile(tar, model_data_source.get_root_dir())

volumes.append(_Volume(model_data_source.get_root_dir(), "/opt/ml/model"))

5 changes: 2 additions & 3 deletions src/sagemaker/serve/model_server/djl_serving/prepare.py
@@ -20,7 +20,7 @@
from typing import List
from pathlib import Path

from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
from sagemaker.utils import _tmpdir, custom_extractall_tarfile
from sagemaker.s3 import S3Downloader
from sagemaker.djl_inference import DJLModel
from sagemaker.djl_inference.model import _read_existing_serving_properties
@@ -53,8 +53,7 @@ def _extract_js_resource(js_model_dir: str, js_id: str):
"""Uncompress the jumpstart resource"""
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
with tarfile.open(str(tmp_sourcedir)) as resources:
check_tarfile_data_filter_attribute()
resources.extractall(path=js_model_dir, filter="data")
custom_extractall_tarfile(resources, js_model_dir)


def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path):
5 changes: 2 additions & 3 deletions src/sagemaker/serve/model_server/tgi/prepare.py
@@ -19,7 +19,7 @@
from pathlib import Path

from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
from sagemaker.utils import _tmpdir, custom_extractall_tarfile
from sagemaker.s3 import S3Downloader

logger = logging.getLogger(__name__)
@@ -29,8 +29,7 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
"""Uncompress the jumpstart resource"""
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
with tarfile.open(str(tmp_sourcedir)) as resources:
check_tarfile_data_filter_attribute()
resources.extractall(path=code_dir, filter="data")
custom_extractall_tarfile(resources, code_dir)


def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:
107 changes: 87 additions & 20 deletions src/sagemaker/utils.py
@@ -22,7 +22,6 @@
import random
import re
import shutil
import sys
import tarfile
import tempfile
import time
@@ -31,6 +30,7 @@
import abc
import uuid
from datetime import datetime
from os.path import abspath, realpath, dirname, normpath, join as joinpath

from importlib import import_module
import botocore
@@ -592,8 +592,7 @@ def _create_or_update_code_dir(
download_file_from_url(source_directory, local_code_path, sagemaker_session)

with tarfile.open(name=local_code_path, mode="r:gz") as t:
check_tarfile_data_filter_attribute()
t.extractall(path=code_dir, filter="data")
custom_extractall_tarfile(t, code_dir)

elif source_directory:
if os.path.exists(code_dir):
@@ -630,8 +629,7 @@ def _extract_model(model_uri, sagemaker_session, tmp):
else:
local_model_path = model_uri.replace("file://", "")
with tarfile.open(name=local_model_path, mode="r:gz") as t:
check_tarfile_data_filter_attribute()
t.extractall(path=tmp_model_dir, filter="data")
custom_extractall_tarfile(t, tmp_model_dir)
return tmp_model_dir


@@ -1494,23 +1492,92 @@ def format_tags(tags: Tags) -> List[TagsDict]:
return tags


class PythonVersionError(Exception):
"""Raise when a secure [/patched] version of Python is not used."""
def _get_resolved_path(path):
"""Return the normalized absolute path of a given path.

abspath - returns the absolute path without resolving symlinks
realpath - resolves the symlinks and gets the actual path
normpath - normalizes paths (e.g. removes redundant separators)
and handles platform-specific differences
"""
return normpath(realpath(abspath(path)))

def check_tarfile_data_filter_attribute():
"""Check if tarfile has data_filter utility.

Tarfile-data_filter utility has guardrails against untrusted de-serialisation.
def _is_bad_path(path, base):
"""Checks if the joined path (base directory + file path) is rooted under the base directory

Raises:
PythonVersionError: if `tarfile.data_filter` is not available.
Ensuring that the file does not attempt to access paths
outside the expected directory structure.

Args:
path (str): The file path.
base (str): The base directory.

Returns:
bool: True if the path is not rooted under the base directory, False otherwise.
"""
# The function and it's usages can be deprecated post support of python >= 3.12
if not hasattr(tarfile, "data_filter"):
raise PythonVersionError(
f"Since tarfile extraction is unsafe the operation is prohibited "
f"per PEP-721. Please update your Python [{sys.version}] "
f"to latest patch [refer to https://www.python.org/downloads/] "
f"to consume the security patch"
)
# joinpath will ignore base if path is absolute
return not _get_resolved_path(joinpath(base, path)).startswith(base)


def _is_bad_link(info, base):
"""Checks if the link is rooted under the base directory.

Ensuring that the link does not attempt to access paths outside the expected directory structure

Args:
info (tarfile.TarInfo): The tar file info.
base (str): The base directory.

Returns:
bool: True if the link is not rooted under the base directory, False otherwise.
"""
# Links are interpreted relative to the directory containing the link
tip = _get_resolved_path(joinpath(base, dirname(info.name)))
return _is_bad_path(info.linkname, base=tip)
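For intuition, a short sketch (hypothetical paths, not part of this diff) of what the path check rejects:

# Hypothetical illustration of _is_bad_path; the values are made up.
base = _get_resolved_path("/opt/ml/model")

_is_bad_path("code/inference.py", base)   # False -> resolves under base, extracted
_is_bad_path("../../etc/passwd", base)    # True  -> escapes base, blocked
_is_bad_path("/etc/passwd", base)         # True  -> absolute path ignores base, blocked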


def _get_safe_members(members):
"""A generator that yields members that are safe to extract.

It filters out bad paths and bad links.

Args:
members (list): A list of members to check.

Yields:
tarfile.TarInfo: The tar file info.
"""
base = _get_resolved_path(".")

for file_info in members:
if _is_bad_path(file_info.name, base):
logger.error("%s is blocked (illegal path)", file_info.name)
elif file_info.issym() and _is_bad_link(file_info, base):
logger.error("%s is blocked: Symlink to %s", file_info.name, file_info.linkname)
elif file_info.islnk() and _is_bad_link(file_info, base):
logger.error("%s is blocked: Hard link to %s", file_info.name, file_info.linkname)
else:
yield file_info


def custom_extractall_tarfile(tar, extract_path):
"""Extract a tarfile, optionally using data_filter if available.

TODO: This function and its usages can be deprecated once the SageMaker Python SDK
is upgraded to use Python 3.12+.

If the tarfile module has a data_filter attribute, it will be used to extract the contents of the file.
Otherwise, the _get_safe_members function will be used to filter bad paths and bad links.

Args:
tar (tarfile.TarFile): The opened tarfile object.
extract_path (str): The path to extract the contents of the tarfile.

Returns:
None
"""
if hasattr(tarfile, "data_filter"):
tar.extractall(path=extract_path, filter="data")
else:
tar.extractall(path=extract_path, members=_get_safe_members(tar))
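A minimal usage sketch of the new helper (file names are hypothetical):

import tarfile
from sagemaker.utils import custom_extractall_tarfile

with tarfile.open("model.tar.gz", mode="r:gz") as tar:
    # Uses filter="data" when tarfile.data_filter is available (Python 3.12+
    # and patched older releases); otherwise unsafe members are skipped via
    # _get_safe_members.
    custom_extractall_tarfile(tar, "/tmp/model")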
98 changes: 97 additions & 1 deletion src/sagemaker/workflow/_repack_model.py
@@ -14,6 +14,7 @@
from __future__ import absolute_import

import argparse
import logging
import os
import shutil
import tarfile
@@ -33,6 +34,101 @@
# repacking is some short-lived hackery, right??
from distutils.dir_util import copy_tree

from os.path import abspath, realpath, dirname, normpath, join as joinpath

logger = logging.getLogger(__name__)


def _get_resolved_path(path):
"""Return the normalized absolute path of a given path.

abspath - returns the absolute path without resolving symlinks
realpath - resolves the symlinks and gets the actual path
normpath - normalizes paths (e.g. removes redundant separators)
and handles platform-specific differences
"""
return normpath(realpath(abspath(path)))


def _is_bad_path(path, base):
"""Checks if the joined path (base directory + file path) is rooted under the base directory

Ensuring that the file does not attempt to access paths
outside the expected directory structure.

Args:
path (str): The file path.
base (str): The base directory.

Returns:
bool: True if the path is not rooted under the base directory, False otherwise.
"""
# joinpath will ignore base if path is absolute
return not _get_resolved_path(joinpath(base, path)).startswith(base)


def _is_bad_link(info, base):
"""Checks if the link is rooted under the base directory.

Ensuring that the link does not attempt to access paths outside the expected directory structure

Args:
info (tarfile.TarInfo): The tar file info.
base (str): The base directory.

Returns:
bool: True if the link is not rooted under the base directory, False otherwise.
"""
# Links are interpreted relative to the directory containing the link
tip = _get_resolved_path(joinpath(base, dirname(info.name)))
return _is_bad_path(info.linkname, base=tip)


def _get_safe_members(members):
"""A generator that yields members that are safe to extract.

It filters out bad paths and bad links.

Args:
members (list): A list of members to check.

Yields:
tarfile.TarInfo: The tar file info.
"""
base = _get_resolved_path(".")

for file_info in members:
if _is_bad_path(file_info.name, base):
logger.error("%s is blocked (illegal path)", file_info.name)
elif file_info.issym() and _is_bad_link(file_info, base):
logger.error("%s is blocked: Symlink to %s", file_info.name, file_info.linkname)
elif file_info.islnk() and _is_bad_link(file_info, base):
logger.error("%s is blocked: Hard link to %s", file_info.name, file_info.linkname)
else:
yield file_info


def custom_extractall_tarfile(tar, extract_path):
"""Extract a tarfile, optionally using data_filter if available.

TODO: This function and its usages can be deprecated once the SageMaker Python SDK
is upgraded to use Python 3.12+.

If the tarfile module has a data_filter attribute, it will be used to extract the contents of the file.
Otherwise, the _get_safe_members function will be used to filter bad paths and bad links.

Args:
tar (tarfile.TarFile): The opened tarfile object.
extract_path (str): The path to extract the contents of the tarfile.

Returns:
None
"""
if hasattr(tarfile, "data_filter"):
tar.extractall(path=extract_path, filter="data")
else:
tar.extractall(path=extract_path, members=_get_safe_members(tar))


def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover
"""Repack custom dependencies and code into an existing model TAR archive
Expand Down Expand Up @@ -60,7 +156,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
# extract the contents of the previous training job's model archive to the "src"
# directory of this training job
with tarfile.open(name=local_path, mode="r:gz") as tf:
tf.extractall(path=src_dir)
custom_extractall_tarfile(tf, src_dir)

if source_dir:
# copy /opt/ml/code to code/
5 changes: 2 additions & 3 deletions src/sagemaker/workflow/_utils.py
@@ -36,7 +36,7 @@
_save_model,
download_file_from_url,
format_tags,
check_tarfile_data_filter_attribute,
custom_extractall_tarfile,
)
from sagemaker.workflow.retry import RetryPolicy
from sagemaker.workflow.utilities import trim_request_dict
@@ -262,8 +262,7 @@ def _inject_repack_script_and_launcher(self):
download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session)

with tarfile.open(name=old_targz_path, mode="r:gz") as t:
check_tarfile_data_filter_attribute()
t.extractall(path=targz_contents_dir, filter="data")
custom_extractall_tarfile(t, targz_contents_dir)

shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT))
with open(
5 changes: 2 additions & 3 deletions tests/integ/s3_utils.py
@@ -19,7 +19,7 @@
import boto3
from six.moves.urllib.parse import urlparse

from sagemaker.utils import check_tarfile_data_filter_attribute
from sagemaker.utils import custom_extractall_tarfile


def assert_s3_files_exist(sagemaker_session, s3_url, files):
@@ -57,5 +57,4 @@ def extract_files_from_s3(s3_url, tmpdir, sagemaker_session):
s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model)

with tarfile.open(model, "r") as tar_file:
check_tarfile_data_filter_attribute()
tar_file.extractall(tmpdir, filter="data")
custom_extractall_tarfile(tar_file, tmpdir)
5 changes: 2 additions & 3 deletions tests/unit/test_fw_utils.py
@@ -24,7 +24,7 @@
from mock import Mock, patch

from sagemaker import fw_utils
from sagemaker.utils import name_from_image, check_tarfile_data_filter_attribute
from sagemaker.utils import name_from_image, custom_extractall_tarfile
from sagemaker.session_settings import SessionSettings
from sagemaker.instance_group import InstanceGroup

@@ -424,8 +424,7 @@ def list_tar_files(folder, tar_ball, tmpdir):
startpath = str(tmpdir.ensure(folder, dir=True))

with tarfile.open(name=tar_ball, mode="r:gz") as t:
check_tarfile_data_filter_attribute()
t.extractall(path=startpath, filter="data")
custom_extractall_tarfile(t, startpath)

def walk():
for root, dirs, files in os.walk(startpath):
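A sketch of a unit test, not included in this PR, that could exercise the fallback filtering directly (helper names as defined in sagemaker.utils):

import tarfile
from sagemaker.utils import _get_safe_members

def test_get_safe_members_blocks_traversal():
    # A regular member stays; a path-traversal member is filtered out.
    safe = tarfile.TarInfo(name="model/weights.bin")
    unsafe = tarfile.TarInfo(name="../../etc/passwd")
    kept = [member.name for member in _get_safe_members([safe, unsafe])]
    assert kept == ["model/weights.bin"]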