Skip to content

Commit 0809b05

Browse files
authored
fix: Add SELinux label to local docker volumes (#3790)
1 parent 8da92f7 commit 0809b05

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

src/sagemaker/local/image.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@
4949
TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME"
5050
S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL"
5151

52+
# SELinux Enabled
53+
SELINUX_ENABLED = os.environ.get("SAGEMAKER_LOCAL_SELINUX_ENABLED", "False").lower() in [
54+
"1",
55+
"true",
56+
"yes",
57+
]
58+
5259
logger = logging.getLogger(__name__)
5360

5461

@@ -349,6 +356,7 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name):
349356
# Gather the artifacts from all nodes into artifacts/model and artifacts/output
350357
for host in self.hosts:
351358
volumes = compose_data["services"][str(host)]["volumes"]
359+
volumes = [v[:-2] if v.endswith(":z") else v for v in volumes]
352360
for volume in volumes:
353361
if re.search(r"^[A-Za-z]:", volume):
354362
unit, host_dir, container_dir = volume.split(":")
@@ -887,10 +895,14 @@ def __init__(self, host_dir, container_dir=None, channel=None):
887895

888896
self.container_dir = container_dir if container_dir else "/opt/ml/input/data/" + channel
889897
self.host_dir = host_dir
898+
map_format = "{}:{}"
899+
if platform.system() == "Linux" and SELINUX_ENABLED:
900+
# Support mounting shared volumes in SELinux enabled hosts
901+
map_format += ":z"
890902
if platform.system() == "Darwin" and host_dir.startswith("/var"):
891903
self.host_dir = os.path.join("/private", host_dir)
892904

893-
self.map = "{}:{}".format(self.host_dir, self.container_dir)
905+
self.map = map_format.format(self.host_dir, self.container_dir)
894906

895907

896908
def _stream_output(process):

tests/unit/sagemaker/local/test_local_image.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from mock import patch, Mock, MagicMock
3131

3232
import sagemaker
33-
from sagemaker.local.image import _SageMakerContainer, _aws_credentials
33+
from sagemaker.local.image import _SageMakerContainer, _Volume, _aws_credentials
3434

3535
REGION = "us-west-2"
3636
BUCKET_NAME = "mybucket"
@@ -513,6 +513,7 @@ def test_train_local_code(get_data_source_instance, tmpdir, sagemaker_session):
513513
assert config["services"][h]["image"] == image
514514
assert config["services"][h]["command"] == "train"
515515
volumes = config["services"][h]["volumes"]
516+
volumes = [v[:-2] if v.endswith(":z") else v for v in volumes]
516517
assert "%s:/opt/ml/code" % "/tmp/code" in volumes
517518
assert "%s:/opt/ml/shared" % shared_folder_path in volumes
518519

@@ -564,9 +565,26 @@ def test_train_local_intermediate_output(get_data_source_instance, tmpdir, sagem
564565
assert config["services"][h]["image"] == image
565566
assert config["services"][h]["command"] == "train"
566567
volumes = config["services"][h]["volumes"]
568+
volumes = [v[:-2] if v.endswith(":z") else v for v in volumes]
567569
assert "%s:/opt/ml/output/intermediate" % intermediate_folder_path in volumes
568570

569571

572+
@patch("platform.system", Mock(return_value="Linux"))
573+
@patch("sagemaker.local.image.SELINUX_ENABLED", Mock(return_value=True))
574+
def test_container_selinux_has_label(tmpdir):
575+
volume = _Volume(str(tmpdir), "/opt/ml/model")
576+
577+
assert volume.map.endswith(":z")
578+
579+
580+
@patch("platform.system", Mock(return_value="Darwin"))
581+
@patch("sagemaker.local.image.SELINUX_ENABLED", Mock(return_value=True))
582+
def test_container_has_selinux_no_label(tmpdir):
583+
volume = _Volume(str(tmpdir), "/opt/ml/model")
584+
585+
assert not volume.map.endswith(":z")
586+
587+
570588
def test_container_has_gpu_support(tmpdir, sagemaker_session):
571589
instance_count = 1
572590
image = "my-image"
@@ -650,6 +668,7 @@ def test_serve_local_code(tmpdir, sagemaker_session):
650668
assert config["services"][h]["command"] == "serve"
651669

652670
volumes = config["services"][h]["volumes"]
671+
volumes = [v[:-2] if v.endswith(":z") else v for v in volumes]
653672
assert "%s:/opt/ml/code" % "/tmp/code" in volumes
654673
assert (
655674
"SAGEMAKER_SUBMIT_DIRECTORY=/opt/ml/code"

0 commit comments

Comments
 (0)