diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index 3e77465ff6..16a832a14f 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -15,6 +15,7 @@ from typing import List, Sequence, Union import hashlib +from urllib.parse import unquote, urlparse from sagemaker.workflow.entities import ( Entity, @@ -50,6 +51,8 @@ def hash_file(path: str) -> str: """ BUF_SIZE = 65536 # read in 64KiB chunks md5 = hashlib.md5() + if path.lower().startswith("file://"): + path = unquote(urlparse(path).path) with open(path, "rb") as f: while True: data = f.read(BUF_SIZE) diff --git a/tests/unit/sagemaker/workflow/test_utilities.py b/tests/unit/sagemaker/workflow/test_utilities.py new file mode 100644 index 0000000000..d128d8b31d --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_utilities.py @@ -0,0 +1,31 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import tempfile +from sagemaker.workflow.utilities import hash_file + + +def test_hash_file(): + with tempfile.NamedTemporaryFile() as tmp: + tmp.write("hashme".encode()) + hash = hash_file(tmp.name) + assert hash == "d41d8cd98f00b204e9800998ecf8427e" + + +def test_hash_file_uri(): + with tempfile.NamedTemporaryFile() as tmp: + tmp.write("hashme".encode()) + hash = hash_file(f"file:///{tmp.name}") + assert hash == "d41d8cd98f00b204e9800998ecf8427e"