|
22 | 22 | import random
|
23 | 23 | import re
|
24 | 24 | import shutil
|
25 |
| -import sys |
26 | 25 | import tarfile
|
27 | 26 | import tempfile
|
28 | 27 | import time
|
|
31 | 30 | import abc
|
32 | 31 | import uuid
|
33 | 32 | from datetime import datetime
|
| 33 | +from os.path import abspath, realpath, dirname, normpath, join as joinpath |
34 | 34 |
|
35 | 35 | from importlib import import_module
|
36 | 36 | import botocore
|
@@ -592,8 +592,7 @@ def _create_or_update_code_dir(
|
592 | 592 | download_file_from_url(source_directory, local_code_path, sagemaker_session)
|
593 | 593 |
|
594 | 594 | with tarfile.open(name=local_code_path, mode="r:gz") as t:
|
595 |
| - check_tarfile_data_filter_attribute() |
596 |
| - t.extractall(path=code_dir, filter="data") |
| 595 | + custom_extractall_tarfile(t, code_dir) |
597 | 596 |
|
598 | 597 | elif source_directory:
|
599 | 598 | if os.path.exists(code_dir):
|
@@ -630,8 +629,7 @@ def _extract_model(model_uri, sagemaker_session, tmp):
|
630 | 629 | else:
|
631 | 630 | local_model_path = model_uri.replace("file://", "")
|
632 | 631 | with tarfile.open(name=local_model_path, mode="r:gz") as t:
|
633 |
| - check_tarfile_data_filter_attribute() |
634 |
| - t.extractall(path=tmp_model_dir, filter="data") |
| 632 | + custom_extractall_tarfile(t, tmp_model_dir) |
635 | 633 | return tmp_model_dir
|
636 | 634 |
|
637 | 635 |
|
@@ -1494,23 +1492,92 @@ def format_tags(tags: Tags) -> List[TagsDict]:
|
1494 | 1492 | return tags
|
1495 | 1493 |
|
1496 | 1494 |
|
1497 |
| -class PythonVersionError(Exception): |
1498 |
| - """Raise when a secure [/patched] version of Python is not used.""" |
| 1495 | +def _get_resolved_path(path): |
| 1496 | + """Return the normalized absolute path of a given path. |
1499 | 1497 |
|
| 1498 | + abspath - returns the absolute path without resolving symlinks |
| 1499 | + realpath - resolves the symlinks and gets the actual path |
| 1500 | + normpath - normalizes paths (e.g. remove redudant separators) |
| 1501 | + and handles platform-specific differences |
| 1502 | + """ |
| 1503 | + return normpath(realpath(abspath(path))) |
1500 | 1504 |
|
1501 |
| -def check_tarfile_data_filter_attribute(): |
1502 |
| - """Check if tarfile has data_filter utility. |
1503 | 1505 |
|
1504 |
| - Tarfile-data_filter utility has guardrails against untrusted de-serialisation. |
| 1506 | +def _is_bad_path(path, base): |
| 1507 | + """Checks if the joined path (base directory + file path) is rooted under the base directory |
1505 | 1508 |
|
1506 |
| - Raises: |
1507 |
| - PythonVersionError: if `tarfile.data_filter` is not available. |
| 1509 | + Ensuring that the file does not attempt to access paths |
| 1510 | + outside the expected directory structure. |
| 1511 | +
|
| 1512 | + Args: |
| 1513 | + path (str): The file path. |
| 1514 | + base (str): The base directory. |
| 1515 | +
|
| 1516 | + Returns: |
| 1517 | + bool: True if the path is not rooted under the base directory, False otherwise. |
1508 | 1518 | """
|
1509 |
| - # The function and it's usages can be deprecated post support of python >= 3.12 |
1510 |
| - if not hasattr(tarfile, "data_filter"): |
1511 |
| - raise PythonVersionError( |
1512 |
| - f"Since tarfile extraction is unsafe the operation is prohibited " |
1513 |
| - f"per PEP-721. Please update your Python [{sys.version}] " |
1514 |
| - f"to latest patch [refer to https://www.python.org/downloads/] " |
1515 |
| - f"to consume the security patch" |
1516 |
| - ) |
| 1519 | + # joinpath will ignore base if path is absolute |
| 1520 | + return not _get_resolved_path(joinpath(base, path)).startswith(base) |
| 1521 | + |
| 1522 | + |
| 1523 | +def _is_bad_link(info, base): |
| 1524 | + """Checks if the link is rooted under the base directory. |
| 1525 | +
|
| 1526 | + Ensuring that the link does not attempt to access paths outside the expected directory structure |
| 1527 | +
|
| 1528 | + Args: |
| 1529 | + info (tarfile.TarInfo): The tar file info. |
| 1530 | + base (str): The base directory. |
| 1531 | +
|
| 1532 | + Returns: |
| 1533 | + bool: True if the link is not rooted under the base directory, False otherwise. |
| 1534 | + """ |
| 1535 | + # Links are interpreted relative to the directory containing the link |
| 1536 | + tip = _get_resolved_path(joinpath(base, dirname(info.name))) |
| 1537 | + return _is_bad_path(info.linkname, base=tip) |
| 1538 | + |
| 1539 | + |
| 1540 | +def _get_safe_members(members): |
| 1541 | + """A generator that yields members that are safe to extract. |
| 1542 | +
|
| 1543 | + It filters out bad paths and bad links. |
| 1544 | +
|
| 1545 | + Args: |
| 1546 | + members (list): A list of members to check. |
| 1547 | +
|
| 1548 | + Yields: |
| 1549 | + tarfile.TarInfo: The tar file info. |
| 1550 | + """ |
| 1551 | + base = _get_resolved_path(".") |
| 1552 | + |
| 1553 | + for file_info in members: |
| 1554 | + if _is_bad_path(file_info.name, base): |
| 1555 | + logger.error("%s is blocked (illegal path)", file_info.name) |
| 1556 | + elif file_info.issym() and _is_bad_link(file_info, base): |
| 1557 | + logger.error("%s is blocked: Symlink to %s", file_info.name, file_info.linkname) |
| 1558 | + elif file_info.islnk() and _is_bad_link(file_info, base): |
| 1559 | + logger.error("%s is blocked: Hard link to %s", file_info.name, file_info.linkname) |
| 1560 | + else: |
| 1561 | + yield file_info |
| 1562 | + |
| 1563 | + |
| 1564 | +def custom_extractall_tarfile(tar, extract_path): |
| 1565 | + """Extract a tarfile, optionally using data_filter if available. |
| 1566 | +
|
| 1567 | + # TODO: The function and it's usages can be deprecated once SageMaker Python SDK |
| 1568 | + is upgraded to use Python 3.12+ |
| 1569 | +
|
| 1570 | + If the tarfile has a data_filter attribute, it will be used to extract the contents of the file. |
| 1571 | + Otherwise, the _get_safe_members function will be used to filter bad paths and bad links. |
| 1572 | +
|
| 1573 | + Args: |
| 1574 | + tar (tarfile.TarFile): The opened tarfile object. |
| 1575 | + extract_path (str): The path to extract the contents of the tarfile. |
| 1576 | +
|
| 1577 | + Returns: |
| 1578 | + None |
| 1579 | + """ |
| 1580 | + if hasattr(tarfile, "data_filter"): |
| 1581 | + tar.extractall(path=extract_path, filter="data") |
| 1582 | + else: |
| 1583 | + tar.extractall(path=extract_path, members=_get_safe_members(tar)) |
0 commit comments