Skip to content

Commit fd28f22

Browse files
committed
chore: new implementation for listing jumpstart models
1 parent 881273b commit fd28f22

File tree

5 files changed

+1163
-352
lines changed

5 files changed

+1163
-352
lines changed

src/sagemaker/jumpstart/cache.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _get_manifest_key_from_model_id_semantic_version(
229229
)
230230

231231
else:
232-
possible_model_ids = [header.model_id for header in manifest.values()]
232+
possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore
233233
closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0]
234234
error_msg += f"Did you mean to use model ID '{closest_model_id}'?"
235235

src/sagemaker/jumpstart/filters.py

+389
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,389 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores filters related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
15+
from ast import literal_eval
16+
from enum import Enum
17+
from typing import List, Union, Any
18+
19+
from sagemaker.jumpstart.types import JumpStartDataHolderType
20+
21+
22+
class BooleanValues(str, Enum):
23+
"""Enum class for boolean values."""
24+
25+
TRUE = "true"
26+
FALSE = "false"
27+
UNKNOWN = "unknown"
28+
UNEVALUATED = "unevaluated"
29+
30+
31+
class FilterOperators(str, Enum):
32+
"""Enum class for filter operators for JumpStart models."""
33+
34+
EQUALS = "equals"
35+
NOT_EQUALS = "not_equals"
36+
IN = "in"
37+
NOT_IN = "not_in"
38+
39+
40+
FILTER_OPERATOR_STRING_MAPPINGS = {
41+
FilterOperators.EQUALS: ["===", "==", "equals", "is"],
42+
FilterOperators.NOT_EQUALS: ["!==", "!=", "not equals", "is not"],
43+
FilterOperators.IN: ["in"],
44+
FilterOperators.NOT_IN: ["not in"],
45+
}
46+
47+
48+
class Operand:
49+
"""Operand class for filtering JumpStart content."""
50+
51+
def __init__(
52+
self, unresolved_value: Any, resolved_value: BooleanValues = BooleanValues.UNEVALUATED
53+
):
54+
self.unresolved_value = unresolved_value
55+
self.resolved_value = resolved_value
56+
57+
def __iter__(self) -> Any:
58+
"""Returns an iterator."""
59+
yield self
60+
61+
def eval(self) -> None:
62+
"""Evaluates operand."""
63+
return
64+
65+
@staticmethod
66+
def validate_operand(operand: Any) -> Any:
67+
"""Validate operand and return ``Operand`` object."""
68+
if isinstance(operand, str):
69+
if operand.lower() == BooleanValues.TRUE.lower():
70+
operand = Operand(operand, resolved_value=BooleanValues.TRUE)
71+
elif operand.lower() == BooleanValues.FALSE.lower():
72+
operand = Operand(operand, resolved_value=BooleanValues.FALSE)
73+
elif operand.lower() == BooleanValues.UNKNOWN.lower():
74+
operand = Operand(operand, resolved_value=BooleanValues.UNKNOWN)
75+
else:
76+
operand = Operand(parse_filter_string(operand))
77+
elif not issubclass(type(operand), Operand):
78+
raise RuntimeError()
79+
return operand
80+
81+
82+
class Operator(Operand):
83+
"""Operator class for filtering JumpStart content."""
84+
85+
def __init__(
86+
self,
87+
resolved_value: BooleanValues = BooleanValues.UNEVALUATED,
88+
unresolved_value: Any = None,
89+
):
90+
"""Initializes ``Operator`` instance.
91+
92+
Args:
93+
resolved_value (BooleanValues): Optional. The resolved value of the operator.
94+
(Default: BooleanValues.UNEVALUATED).
95+
unresolved_value (Any): Optional. The unresolved value of the operator.
96+
(Default: None).
97+
"""
98+
super().__init__(unresolved_value=unresolved_value, resolved_value=resolved_value)
99+
100+
def eval(self) -> None:
101+
"""Evaluates operator."""
102+
return
103+
104+
def __iter__(self) -> Any:
105+
"""Returns an iterator."""
106+
yield self
107+
108+
109+
class And(Operator):
110+
"""And operator class for filtering JumpStart content."""
111+
112+
def __init__(
113+
self,
114+
*operands: Union[Operand, str],
115+
) -> None:
116+
"""Instantiates And object.
117+
118+
Args:
119+
operand (Operand): Operand for And-ing.
120+
"""
121+
self.operands: List[Operand] = list(operands) # type: ignore
122+
for i in range(len(self.operands)):
123+
self.operands[i] = Operand.validate_operand(self.operands[i])
124+
super().__init__()
125+
126+
def eval(self) -> None:
127+
"""Evaluates operator."""
128+
incomplete_expression = False
129+
for operand in self.operands:
130+
if not issubclass(type(operand), Operand):
131+
raise RuntimeError()
132+
if operand.resolved_value == BooleanValues.UNEVALUATED:
133+
operand.eval()
134+
if operand.resolved_value == BooleanValues.UNEVALUATED:
135+
raise RuntimeError()
136+
if not isinstance(operand.resolved_value, BooleanValues):
137+
raise RuntimeError()
138+
if operand.resolved_value == BooleanValues.FALSE:
139+
self.resolved_value = BooleanValues.FALSE
140+
return
141+
if operand.resolved_value == BooleanValues.UNKNOWN:
142+
incomplete_expression = True
143+
if not incomplete_expression:
144+
self.resolved_value = BooleanValues.TRUE
145+
else:
146+
self.resolved_value = BooleanValues.UNKNOWN
147+
148+
def __iter__(self) -> Any:
149+
"""Returns an iterator."""
150+
for operand in self.operands:
151+
yield from operand
152+
yield self
153+
154+
155+
class Constant(Operator):
156+
"""Constant operator class for filtering JumpStart content."""
157+
158+
def __init__(
159+
self,
160+
constant: BooleanValues,
161+
):
162+
"""Instantiates Constant operator object.
163+
164+
Args:
165+
constant (BooleanValues]): Value of constant.
166+
"""
167+
super().__init__(constant)
168+
169+
def eval(self) -> None:
170+
"""Evaluates constant"""
171+
return
172+
173+
def __iter__(self) -> Any:
174+
"""Returns an iterator."""
175+
yield self
176+
177+
178+
class Identity(Operator):
179+
"""Identity operator class for filtering JumpStart content."""
180+
181+
def __init__(
182+
self,
183+
operand: Union[Operand, str],
184+
):
185+
"""Instantiates Identity object.
186+
187+
Args:
188+
operand (Union[Operand, str]): Operand for identity operation.
189+
"""
190+
super().__init__()
191+
self.operand = Operand.validate_operand(operand)
192+
193+
def __iter__(self) -> Any:
194+
"""Returns an iterator."""
195+
yield self
196+
yield from self.operand
197+
198+
def eval(self) -> Any:
199+
"""Evaluates operator."""
200+
if not issubclass(type(self.operand), Operand):
201+
raise RuntimeError()
202+
if self.operand.resolved_value == BooleanValues.UNEVALUATED:
203+
self.operand.eval()
204+
if self.operand.resolved_value == BooleanValues.UNEVALUATED:
205+
raise RuntimeError()
206+
if not isinstance(self.operand.resolved_value, BooleanValues):
207+
raise RuntimeError(self.operand.resolved_value)
208+
self.resolved_value = self.operand.resolved_value
209+
210+
211+
class Or(Operator):
212+
"""Or operator class for filtering JumpStart content."""
213+
214+
def __init__(
215+
self,
216+
*operands: Union[Operand, str],
217+
) -> None:
218+
"""Instantiates Or object.
219+
220+
Args:
221+
operands (Operand): Operand for Or-ing.
222+
"""
223+
self.operands: List[Operand] = list(operands) # type: ignore
224+
for i in range(len(self.operands)):
225+
self.operands[i] = Operand.validate_operand(self.operands[i])
226+
super().__init__()
227+
228+
def eval(self) -> None:
229+
"""Evaluates operator."""
230+
incomplete_expression = False
231+
for operand in self.operands:
232+
if not issubclass(type(operand), Operand):
233+
raise RuntimeError()
234+
if operand.resolved_value == BooleanValues.UNEVALUATED:
235+
operand.eval()
236+
if operand.resolved_value == BooleanValues.UNEVALUATED:
237+
raise RuntimeError()
238+
if not isinstance(operand.resolved_value, BooleanValues):
239+
raise RuntimeError()
240+
if operand.resolved_value == BooleanValues.TRUE:
241+
self.resolved_value = BooleanValues.TRUE
242+
return
243+
if operand.resolved_value == BooleanValues.UNKNOWN:
244+
incomplete_expression = True
245+
if not incomplete_expression:
246+
self.resolved_value = BooleanValues.FALSE
247+
else:
248+
self.resolved_value = BooleanValues.UNKNOWN
249+
250+
def __iter__(self) -> Any:
251+
"""Returns an iterator."""
252+
for operand in self.operands:
253+
yield from operand
254+
yield self
255+
256+
257+
class Not(Operator):
258+
"""Not operator class for filtering JumpStart content."""
259+
260+
def __init__(
261+
self,
262+
operand: Union[Operand, str],
263+
) -> None:
264+
"""Instantiates Not object.
265+
266+
Args:
267+
operand (Operand): Operand for Not-ing.
268+
"""
269+
self.operand: Operand = Operand.validate_operand(operand)
270+
super().__init__()
271+
272+
def eval(self) -> None:
273+
"""Evaluates operator."""
274+
275+
if not issubclass(type(self.operand), Operand):
276+
raise RuntimeError()
277+
if self.operand.resolved_value == BooleanValues.UNEVALUATED:
278+
self.operand.eval()
279+
if self.operand.resolved_value == BooleanValues.UNEVALUATED:
280+
raise RuntimeError()
281+
if not isinstance(self.operand.resolved_value, BooleanValues):
282+
raise RuntimeError()
283+
if self.operand.resolved_value == BooleanValues.TRUE:
284+
self.resolved_value = BooleanValues.FALSE
285+
return
286+
if self.operand.resolved_value == BooleanValues.FALSE:
287+
self.resolved_value = BooleanValues.TRUE
288+
return
289+
self.resolved_value = BooleanValues.UNKNOWN
290+
291+
def __iter__(self) -> Any:
292+
"""Returns an iterator."""
293+
yield from self.operand
294+
yield self
295+
296+
297+
class ModelFilter(JumpStartDataHolderType):
298+
"""Data holder class to store model filters.
299+
300+
For a given filter string "task == ic", the key corresponds to
301+
"task" and the value corresponds to "ic", with the operation being
302+
"==".
303+
"""
304+
305+
__slots__ = ["key", "value", "operator"]
306+
307+
def __init__(self, key: str, value: str, operator: str):
308+
"""Instantiates ``ModelFilter`` object.
309+
310+
Args:
311+
key (str): The key in metadata for the model filter.
312+
value (str): The value of the metadata for the model filter.
313+
operator (str): The operator used in the model filter.
314+
"""
315+
self.key = key
316+
self.value = value
317+
self.operator = operator
318+
319+
320+
def parse_filter_string(filter_string: str) -> ModelFilter:
321+
"""Parse filter string and return a serialized ``ModelFilter`` object.
322+
323+
Args:
324+
filter_string (str): The filter string to be serialized to an object.
325+
"""
326+
327+
pad_alphabetic_operator = (
328+
lambda operator: " " + operator + " "
329+
if any(character.isalpha() for character in operator)
330+
else operator
331+
)
332+
333+
acceptable_operators_in_parse_order = (
334+
list(
335+
map(
336+
pad_alphabetic_operator, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]
337+
)
338+
)
339+
+ list(
340+
map(pad_alphabetic_operator, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN])
341+
)
342+
+ list(
343+
map(pad_alphabetic_operator, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS])
344+
)
345+
+ list(map(pad_alphabetic_operator, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]))
346+
)
347+
for operator in acceptable_operators_in_parse_order:
348+
split_filter_string = filter_string.split(operator)
349+
if len(split_filter_string) == 2:
350+
return ModelFilter(
351+
split_filter_string[0].strip(), split_filter_string[1].strip(), operator.strip()
352+
)
353+
raise RuntimeError(f"Cannot parse filter string: {filter_string}")
354+
355+
356+
def evaluate_filter_expression( # pylint: disable=too-many-return-statements
357+
model_filter: ModelFilter, cached_model_value: Any
358+
) -> BooleanValues:
359+
"""Evaluates model filter with cached model spec value, returns boolean.
360+
361+
Args:
362+
model_filter (ModelFilter): The model filter for evaluation.
363+
cached_model_value (Any): The value in the model manifest/spec that should be used to
364+
evaluate the filter.
365+
"""
366+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS]:
367+
model_filter_value = model_filter.value
368+
if isinstance(cached_model_value, bool):
369+
cached_model_value = str(cached_model_value).lower()
370+
model_filter_value = model_filter.value.lower()
371+
if str(model_filter_value) == str(cached_model_value):
372+
return BooleanValues.TRUE
373+
return BooleanValues.FALSE
374+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]:
375+
if isinstance(cached_model_value, bool):
376+
cached_model_value = str(cached_model_value).lower()
377+
model_filter.value = model_filter.value.lower()
378+
if str(model_filter.value) == str(cached_model_value):
379+
return BooleanValues.FALSE
380+
return BooleanValues.TRUE
381+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]:
382+
if cached_model_value in literal_eval(model_filter.value):
383+
return BooleanValues.TRUE
384+
return BooleanValues.FALSE
385+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]:
386+
if cached_model_value in literal_eval(model_filter.value):
387+
return BooleanValues.FALSE
388+
return BooleanValues.TRUE
389+
raise RuntimeError(f"Bad operator: {model_filter.operator}")

0 commit comments

Comments
 (0)