Skip to content

Commit 19dbb0c

Browse files
committed
added check of the s3 URIs
1 parent 4db22a0 commit 19dbb0c

File tree

2 files changed

+39
-11
lines changed

2 files changed

+39
-11
lines changed

src/sagemaker/clarify.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,26 @@
2929

3030
from schema import Schema, And, Use, Or, Optional, Regex
3131

32-
from sagemaker import image_uris, s3, utils, Session
32+
from sagemaker import image_uris, s3, utils
33+
from sagemaker.session import Session
3334
from sagemaker.network import NetworkConfig
3435
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
3536

3637
logger = logging.getLogger(__name__)
3738

3839

39-
ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])"
40+
ENDPOINT_NAME_PREFIX_PATTERN = r"^[a-zA-Z0-9](-*[a-zA-Z0-9])"
41+
MODEL_NAME_PATTERN = r"^[a-zA-Z0-9]([\-a-zA-Z0-9]*[a-zA-Z0-9])?"
42+
43+
44+
def _validate_s3_path(path: str) -> str:
45+
"""Validates s3 path is correct"""
46+
prefix = "s3://"
47+
assert path.startswith(prefix)
48+
49+
assert "//" not in path[len(prefix) :]
50+
assert not path.startswith(f"{prefix}/")
51+
return path
4052

4153

4254
ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema(
@@ -54,7 +66,7 @@
5466
"application/x-image",
5567
),
5668
),
57-
Optional("dataset_uri"): str,
69+
Optional("dataset_uri"): And(str, _validate_s3_path),
5870
Optional("headers"): [str],
5971
Optional("label"): Or(str, int),
6072
# this field indicates user provides predicted_label in dataset
@@ -65,9 +77,9 @@
6577
Optional("facet"): [
6678
{"name_or_index": Or(str, int), Optional("value_or_threshold"): [Or(int, float, str)]}
6779
],
68-
Optional("facet_dataset_uri"): str,
80+
Optional("facet_dataset_uri"): And(str, _validate_s3_path),
6981
Optional("facet_headers"): [str],
70-
Optional("predicted_label_dataset_uri"): str,
82+
Optional("predicted_label_dataset_uri"): And(str, _validate_s3_path),
7183
Optional("predicted_label_headers"): [str],
7284
Optional("excluded_columns"): [Or(int, str)],
7385
Optional("joinsource_name_or_index"): Or(str, int),
@@ -82,11 +94,11 @@
8294
Or(
8395
# CSV row
8496
[Or(int, float, str, None)],
85-
# JSON row (any JSON object). As I write this only
86-
# SageMaker JSONLines Dense Format ([1])
87-
# is supported and the validation is NOT done
88-
# by the schema but by the data loader.
89-
# [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines
97+
# JSON row (any JSON object). As I write this only
98+
# SageMaker JSONLines Dense Format ([1])
99+
# is supported and the validation is NOT done
100+
# by the schema but by the data loader.
101+
# [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines
90102
{object: object},
91103
)
92104
],
@@ -266,7 +278,7 @@
266278
Optional("predictor"): {
267279
Optional("endpoint_name"): str,
268280
Optional("endpoint_name_prefix"): And(str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)),
269-
Optional("model_name"): str,
281+
Optional("model_name"): And(str, Regex(MODEL_NAME_PATTERN)),
270282
Optional("target_model"): str,
271283
Optional("instance_type"): str,
272284
Optional("initial_instance_count"): int,

tests/unit/test_clarify.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
TextConfig,
3131
ImageConfig,
3232
_AnalysisConfigGenerator,
33+
_validate_s3_path,
3334
)
3435

3536
JOB_NAME_PREFIX = "my-prefix"
@@ -42,6 +43,21 @@ def test_uri():
4243
assert "306415355426.dkr.ecr.us-west-2.amazonaws.com/sagemaker-clarify-processing:1.0" == uri
4344

4445

46+
def test_validated_s3_path():
47+
# Success
48+
path = "s3://some_path/key"
49+
assert _validate_s3_path(path) == path
50+
51+
with pytest.raises(AssertionError):
52+
_validate_s3_path("wrong-prefix://some_path/key")
53+
54+
with pytest.raises(AssertionError):
55+
_validate_s3_path("s3://some_path//key")
56+
57+
with pytest.raises(AssertionError):
58+
_validate_s3_path("s3:///some_path/key")
59+
60+
4561
def test_data_config():
4662
# facets in input dataset
4763
s3_data_input_path = "s3://path/to/input.csv"

0 commit comments

Comments
 (0)