|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +from __future__ import absolute_import |
| 14 | + |
| 15 | +from enum import Enum |
| 16 | +from typing import Dict |
| 17 | +from typing import Optional |
| 18 | +from typing import Union |
| 19 | +import os |
| 20 | + |
| 21 | + |
| 22 | +def _to_s3_path(filename: str, s3_folder: Optional[str]) -> str: |
| 23 | + return filename if not s3_folder else f"{s3_folder}/{filename}" |
| 24 | + |
| 25 | + |
| 26 | +_NB_ASSETS_S3_FOLDER = "inference-notebook-assets" |
| 27 | +_TF_FLOWERS_S3_FOLDER = "training-datasets/tf_flowers" |
| 28 | + |
| 29 | +TMP_DIRECTORY_PATH = os.path.join( |
| 30 | + os.path.abspath(os.path.join(os.path.abspath(__file__), os.pardir)), "tmp" |
| 31 | +) |
| 32 | + |
| 33 | +ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID = "JUMPSTART_SDK_TEST_SUITE_ID" |
| 34 | + |
| 35 | +JUMPSTART_TAG = "JumpStart-SDK-Integ-Test-Suite-Id" |
| 36 | + |
| 37 | +HYPERPARAMETER_MODEL_DICT = { |
| 38 | + ("huggingface-spc-bert-base-cased", "*"): { |
| 39 | + "epochs": "1", |
| 40 | + "adam-learning-rate": "2e-05", |
| 41 | + "batch-size": "8", |
| 42 | + "sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz", |
| 43 | + "sagemaker_program": "transfer_learning.py", |
| 44 | + "sagemaker_container_log_level": "20", |
| 45 | + }, |
| 46 | +} |
| 47 | + |
| 48 | +TRAINING_DATASET_MODEL_DICT = { |
| 49 | + ("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"), |
| 50 | +} |
| 51 | + |
| 52 | + |
| 53 | +class ContentType(str, Enum): |
| 54 | + """Possible value for content type argument of SageMakerRuntime.invokeEndpoint.""" |
| 55 | + |
| 56 | + X_IMAGE = "application/x-image" |
| 57 | + LIST_TEXT = "application/list-text" |
| 58 | + X_TEXT = "application/x-text" |
| 59 | + TEXT_CSV = "text/csv" |
| 60 | + |
| 61 | + |
| 62 | +class InferenceImageFilename(str, Enum): |
| 63 | + """Filename of the inference asset in JumpStart distribution buckets.""" |
| 64 | + |
| 65 | + DOG = "dog.jpg" |
| 66 | + CAT = "cat.jpg" |
| 67 | + DAISY = "100080576_f52e8ee070_n.jpg" |
| 68 | + DAISY_2 = "10140303196_b88d3d6cec.jpg" |
| 69 | + ROSE = "102501987_3cdb8e5394_n.jpg" |
| 70 | + NAXOS_TAVERNA = "Naxos_Taverna.jpg" |
| 71 | + PEDESTRIAN = "img_pedestrian.png" |
| 72 | + |
| 73 | + |
| 74 | +class InferenceTabularDataname(str, Enum): |
| 75 | + """Filename of the tabular data example in JumpStart distribution buckets.""" |
| 76 | + |
| 77 | + REGRESSION_ONEHOT = "regressonehot_data.csv" |
| 78 | + REGRESSION = "regress_data.csv" |
| 79 | + MULTICLASS = "multiclass_data.csv" |
| 80 | + |
| 81 | + |
| 82 | +class ClassLabelFile(str, Enum): |
| 83 | + """Filename in JumpStart distribution buckets for the map of the class index to human readable labels.""" |
| 84 | + |
| 85 | + IMAGE_NET = "ImageNetLabels.txt" |
| 86 | + |
| 87 | + |
| 88 | +TEST_ASSETS_SPECS: Dict[ |
| 89 | + Union[InferenceImageFilename, InferenceTabularDataname, ClassLabelFile], str |
| 90 | +] = { |
| 91 | + InferenceImageFilename.DOG: _to_s3_path(InferenceImageFilename.DOG, _NB_ASSETS_S3_FOLDER), |
| 92 | + InferenceImageFilename.CAT: _to_s3_path(InferenceImageFilename.CAT, _NB_ASSETS_S3_FOLDER), |
| 93 | + InferenceImageFilename.DAISY: _to_s3_path( |
| 94 | + InferenceImageFilename.DAISY, f"{_TF_FLOWERS_S3_FOLDER}/daisy" |
| 95 | + ), |
| 96 | + InferenceImageFilename.DAISY_2: _to_s3_path( |
| 97 | + InferenceImageFilename.DAISY_2, f"{_TF_FLOWERS_S3_FOLDER}/daisy" |
| 98 | + ), |
| 99 | + InferenceImageFilename.ROSE: _to_s3_path( |
| 100 | + InferenceImageFilename.ROSE, f"{_TF_FLOWERS_S3_FOLDER}/roses" |
| 101 | + ), |
| 102 | + InferenceImageFilename.NAXOS_TAVERNA: _to_s3_path( |
| 103 | + InferenceImageFilename.NAXOS_TAVERNA, _NB_ASSETS_S3_FOLDER |
| 104 | + ), |
| 105 | + InferenceImageFilename.PEDESTRIAN: _to_s3_path( |
| 106 | + InferenceImageFilename.PEDESTRIAN, _NB_ASSETS_S3_FOLDER |
| 107 | + ), |
| 108 | + ClassLabelFile.IMAGE_NET: _to_s3_path(ClassLabelFile.IMAGE_NET, _NB_ASSETS_S3_FOLDER), |
| 109 | + InferenceTabularDataname.REGRESSION_ONEHOT: _to_s3_path( |
| 110 | + InferenceTabularDataname.REGRESSION_ONEHOT, _NB_ASSETS_S3_FOLDER |
| 111 | + ), |
| 112 | + InferenceTabularDataname.REGRESSION: _to_s3_path( |
| 113 | + InferenceTabularDataname.REGRESSION, _NB_ASSETS_S3_FOLDER |
| 114 | + ), |
| 115 | + InferenceTabularDataname.MULTICLASS: _to_s3_path( |
| 116 | + InferenceTabularDataname.MULTICLASS, _NB_ASSETS_S3_FOLDER |
| 117 | + ), |
| 118 | +} |
0 commit comments