diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 7aecdf3dd6..a2a61d0b3c 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -43,7 +43,6 @@ DOCKER_COMPOSE_HTTP_TIMEOUT_ENV = "COMPOSE_HTTP_TIMEOUT" DOCKER_COMPOSE_HTTP_TIMEOUT = "120" - # Environment variables to be set during training REGION_ENV_NAME = "AWS_REGION" TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME" @@ -256,7 +255,11 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name): for host in self.hosts: volumes = compose_data["services"][str(host)]["volumes"] for volume in volumes: - host_dir, container_dir = volume.split(":") + if re.search(r"^[A-Za-z]:", volume): + unit, host_dir, container_dir = volume.split(":") + host_dir = unit + ":" + host_dir + else: + host_dir, container_dir = volume.split(":") if container_dir == "/opt/ml/model": sagemaker.local.utils.recursive_copy(host_dir, model_artifacts) elif container_dir == "/opt/ml/output": @@ -639,9 +642,7 @@ def __init__(self, host_dir, container_dir=None, channel=None): if container_dir and channel: raise ValueError("container_dir and channel cannot be declared together.") - self.container_dir = ( - container_dir if container_dir else os.path.join("/opt/ml/input/data", channel) - ) + self.container_dir = container_dir if container_dir else "/opt/ml/input/data/" + channel self.host_dir = host_dir if platform.system() == "Darwin" and host_dir.startswith("/var"): self.host_dir = os.path.join("/private", host_dir)