Skip to content

Commit 9e34237

Browse files
staubhpPayton Staub
and
Payton Staub
authored
fix: Support file URIs in ProcessingStep's code parameter (#3051)
* fix: Support file URIs in ProcessingStep's code parameter * Don't strip leading slash from file uri Co-authored-by: Payton Staub <[email protected]>
1 parent 8ed18fe commit 9e34237

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

src/sagemaker/workflow/utilities.py

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from typing import List, Sequence, Union
1717
import hashlib
18+
from urllib.parse import unquote, urlparse
1819

1920
from sagemaker.workflow.entities import (
2021
Entity,
@@ -50,6 +51,8 @@ def hash_file(path: str) -> str:
5051
"""
5152
BUF_SIZE = 65536 # read in 64KiB chunks
5253
md5 = hashlib.md5()
54+
if path.lower().startswith("file://"):
55+
path = unquote(urlparse(path).path)
5356
with open(path, "rb") as f:
5457
while True:
5558
data = f.read(BUF_SIZE)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
14+
from __future__ import absolute_import
15+
16+
import tempfile
17+
from sagemaker.workflow.utilities import hash_file
18+
19+
20+
def test_hash_file():
    """Check that ``hash_file`` hashes the real contents of a file given by path.

    The expected value is computed with ``hashlib.md5`` over the same payload,
    matching the MD5 hasher that ``hash_file`` uses.
    """
    import hashlib

    payload = "hashme".encode()
    with tempfile.NamedTemporaryFile() as tmp:
        tmp.write(payload)
        # Flush so the bytes are on disk before hash_file reopens the file by
        # name. Without this the write stays in the Python-side buffer, the
        # file reads back empty, and the test only ever compared against the
        # MD5 of zero bytes (d41d8cd98f00b204e9800998ecf8427e).
        tmp.flush()
        # Must run inside the ``with``: the temp file is deleted on close.
        digest = hash_file(tmp.name)

    assert digest == hashlib.md5(payload).hexdigest()
25+
26+
27+
def test_hash_file_uri():
    """Check that ``hash_file`` accepts a ``file://`` URI and hashes the file it names.

    The expected value is computed with ``hashlib.md5`` over the same payload,
    matching the MD5 hasher that ``hash_file`` uses.
    """
    import hashlib
    from pathlib import Path

    payload = "hashme".encode()
    with tempfile.NamedTemporaryFile() as tmp:
        tmp.write(payload)
        # Flush so the bytes are on disk before hash_file reopens the file by
        # name; otherwise the file reads back empty and the test degenerates
        # into comparing against the MD5 of zero bytes.
        tmp.flush()
        # as_uri() yields a well-formed URI (file:///tmp/...). Prefixing an
        # absolute path with "file:///" by hand produced four slashes, which
        # only resolved because POSIX tolerates a leading "//".
        file_uri = Path(tmp.name).as_uri()
        # Must run inside the ``with``: the temp file is deleted on close.
        digest = hash_file(file_uri)

    assert digest == hashlib.md5(payload).hexdigest()

0 commit comments

Comments
 (0)