Skip to content

Commit 80b3e08

Browse files
authored
feat: parallelize notebook search utils, add new operators (#4342)
* feat: parallelize notebook search utils
* chore: raise exception in notebook utils if thread has error
* chore: improve variable name
* fix: not passing region to get jumpstart bucket
* chore: add sagemaker session to notebook utils
* chore: address PR comments
* feat: add support for includes, begins with, ends with
* fix: pylint
* feat: private util for model eula key
* fix: unit tests, use verify_model_region_and_return_specs in notebook utils
* Revert "feat: private util for model eula key" (this reverts commit e2daefc)
* chore: add search keywords to header
1 parent ae50026 commit 80b3e08

File tree

5 files changed

+464
-264
lines changed

5 files changed

+464
-264
lines changed

src/sagemaker/jumpstart/filters.py

+132-32
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515
from ast import literal_eval
1616
from enum import Enum
17-
from typing import Dict, List, Union, Any
17+
from typing import Dict, List, Optional, Union, Any
1818

1919
from sagemaker.jumpstart.types import JumpStartDataHolderType
2020

@@ -38,6 +38,10 @@ class FilterOperators(str, Enum):
3838
NOT_EQUALS = "not_equals"
3939
IN = "in"
4040
NOT_IN = "not_in"
41+
INCLUDES = "includes"
42+
NOT_INCLUDES = "not_includes"
43+
BEGINS_WITH = "begins_with"
44+
ENDS_WITH = "ends_with"
4145

4246

4347
class SpecialSupportedFilterKeys(str, Enum):
@@ -52,6 +56,10 @@ class SpecialSupportedFilterKeys(str, Enum):
5256
FilterOperators.NOT_EQUALS: ["!==", "!=", "not equals", "is not"],
5357
FilterOperators.IN: ["in"],
5458
FilterOperators.NOT_IN: ["not in"],
59+
FilterOperators.INCLUDES: ["includes", "contains"],
60+
FilterOperators.NOT_INCLUDES: ["not includes", "not contains"],
61+
FilterOperators.BEGINS_WITH: ["begins with", "starts with"],
62+
FilterOperators.ENDS_WITH: ["ends with"],
5563
}
5664

5765

@@ -62,7 +70,19 @@ class SpecialSupportedFilterKeys(str, Enum):
6270
)
6371

6472
ACCEPTABLE_OPERATORS_IN_PARSE_ORDER = (
65-
list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]))
73+
list(
74+
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.BEGINS_WITH])
75+
)
76+
+ list(
77+
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.ENDS_WITH])
78+
)
79+
+ list(
80+
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_INCLUDES])
81+
)
82+
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.INCLUDES]))
83+
+ list(
84+
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS])
85+
)
6686
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]))
6787
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS]))
6888
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]))
@@ -428,9 +448,96 @@ def parse_filter_string(filter_string: str) -> ModelFilter:
428448
raise ValueError(f"Cannot parse filter string: {filter_string}")
429449

430450

451+
def _negate_boolean(boolean: BooleanValues) -> BooleanValues:
    """Return the logical complement of a ``BooleanValues`` member.

    ``TRUE`` maps to ``FALSE`` and ``FALSE`` maps to ``TRUE``. Any other
    value is handed back unchanged.
    """
    if boolean == BooleanValues.TRUE:
        return BooleanValues.FALSE
    # Anything that is neither TRUE nor FALSE has no complement and passes through.
    return BooleanValues.TRUE if boolean == BooleanValues.FALSE else boolean
458+
459+
460+
def _evaluate_filter_expression_equals(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Compare the filter value with the cached model value for equality.

    Both sides are compared by their ``str()`` representation. When the
    cached value is a ``bool``, both sides are lower-cased first so that a
    filter value such as ``"True"`` matches the cached ``True``.

    Returns ``BooleanValues.FALSE`` when the cached value is ``None``.
    """
    if cached_model_value is None:
        return BooleanValues.FALSE

    if isinstance(cached_model_value, bool):
        # str(True) is "True"; lower-case both sides to compare case-insensitively.
        actual = str(cached_model_value).lower()
        expected = model_filter.value.lower()
    else:
        actual = cached_model_value
        expected = model_filter.value

    return BooleanValues.TRUE if str(expected) == str(actual) else BooleanValues.FALSE
474+
475+
476+
def _evaluate_filter_expression_in(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluate an ``in`` filter: is the cached model value inside the filter value?

    The filter value is first parsed with ``ast.literal_eval``. If it parses
    to an iterable (e.g. a list literal), a membership test is done against
    it. If it does not parse, the raw filter value is used as-is, which for
    two strings is a substring test.

    Args:
        model_filter: Filter whose ``value`` holds the candidate collection.
        cached_model_value: Value read from the cached model spec.

    Returns:
        ``BooleanValues.TRUE`` when the cached value is found, otherwise
        ``BooleanValues.FALSE``. ``FALSE`` is also returned when the cached
        value is ``None`` or a list, when the parsed filter value is not
        iterable, or when the membership test cannot be applied to the two
        operand types.
    """
    if cached_model_value is None or isinstance(cached_model_value, list):
        return BooleanValues.FALSE

    candidates = model_filter.value
    try:
        candidates = literal_eval(candidates)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a Python literal: fall back to the raw filter value.
        pass
    else:
        try:
            iter(candidates)
        except TypeError:
            # Parsed to a scalar (e.g. a number), which cannot contain anything.
            return BooleanValues.FALSE

    try:
        found = cached_model_value in candidates
    except TypeError:
        # E.g. a non-string tested against a raw string, or an unhashable
        # value tested against a set: membership is undefined, so no match.
        return BooleanValues.FALSE

    return BooleanValues.TRUE if found else BooleanValues.FALSE
497+
498+
499+
def _evaluate_filter_expression_includes(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluate an ``includes`` filter: does the cached value contain the filter value?

    The filter value is converted with ``str()`` and tested with ``in``
    against the cached value (substring test for strings, element test for
    lists, key test for dicts).

    Args:
        model_filter: Filter whose ``value`` is the text to look for.
        cached_model_value: Value read from the cached model spec.

    Returns:
        ``BooleanValues.TRUE`` when the filter value is contained in the
        cached value, otherwise ``BooleanValues.FALSE``. ``FALSE`` is also
        returned when the cached value is ``None`` or does not support
        containment (bool, int, float).
    """
    if cached_model_value is None:
        return BooleanValues.FALSE

    needle = str(model_filter.value)
    try:
        found = needle in cached_model_value
    except TypeError:
        # Scalars such as bool/int/float are not containers, so nothing is included.
        return BooleanValues.FALSE

    return BooleanValues.TRUE if found else BooleanValues.FALSE
510+
511+
512+
def _evaluate_filter_expression_begins_with(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluate a ``begins with`` filter against a string cached value.

    Args:
        model_filter: Filter whose ``value`` (converted with ``str()``) is the
            expected prefix.
        cached_model_value: Value read from the cached model spec.

    Returns:
        ``BooleanValues.TRUE`` when the cached value is a string starting with
        the filter value, otherwise ``BooleanValues.FALSE``. ``FALSE`` is also
        returned when the cached value is ``None`` or not a string.
    """
    # Only strings have a prefix; None and other types (bool, int, float,
    # dict, list) would raise AttributeError on .startswith, so they never match.
    if not isinstance(cached_model_value, str):
        return BooleanValues.FALSE

    prefix = str(model_filter.value)
    if cached_model_value.startswith(prefix):
        return BooleanValues.TRUE
    return BooleanValues.FALSE
523+
524+
525+
def _evaluate_filter_expression_ends_with(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluate an ``ends with`` filter against a string cached value.

    Args:
        model_filter: Filter whose ``value`` (converted with ``str()``) is the
            expected suffix.
        cached_model_value: Value read from the cached model spec.

    Returns:
        ``BooleanValues.TRUE`` when the cached value is a string ending with
        the filter value, otherwise ``BooleanValues.FALSE``. ``FALSE`` is also
        returned when the cached value is ``None`` or not a string.
    """
    # Only strings have a suffix; None and other types (bool, int, float,
    # dict, list) would raise AttributeError on .endswith, so they never match.
    if not isinstance(cached_model_value, str):
        return BooleanValues.FALSE

    suffix = str(model_filter.value)
    if cached_model_value.endswith(suffix):
        return BooleanValues.TRUE
    return BooleanValues.FALSE
536+
537+
431538
def evaluate_filter_expression( # pylint: disable=too-many-return-statements
432539
model_filter: ModelFilter,
433-
cached_model_value: Union[str, bool, int, float, Dict[str, Any], List[Any]],
540+
cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
434541
) -> BooleanValues:
435542
"""Evaluates model filter with cached model spec value, returns boolean.
436543
@@ -440,36 +547,29 @@ def evaluate_filter_expression( # pylint: disable=too-many-return-statements
440547
evaluate the filter.
441548
"""
442549
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS]:
443-
model_filter_value = model_filter.value
444-
if isinstance(cached_model_value, bool):
445-
cached_model_value = str(cached_model_value).lower()
446-
model_filter_value = model_filter.value.lower()
447-
if str(model_filter_value) == str(cached_model_value):
448-
return BooleanValues.TRUE
449-
return BooleanValues.FALSE
550+
return _evaluate_filter_expression_equals(model_filter, cached_model_value)
551+
450552
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]:
451-
if isinstance(cached_model_value, bool):
452-
cached_model_value = str(cached_model_value).lower()
453-
model_filter.value = model_filter.value.lower()
454-
if str(model_filter.value) == str(cached_model_value):
455-
return BooleanValues.FALSE
456-
return BooleanValues.TRUE
553+
return _negate_boolean(_evaluate_filter_expression_equals(model_filter, cached_model_value))
554+
457555
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]:
458-
py_obj = literal_eval(model_filter.value)
459-
try:
460-
iter(py_obj)
461-
except TypeError:
462-
return BooleanValues.FALSE
463-
if cached_model_value in py_obj:
464-
return BooleanValues.TRUE
465-
return BooleanValues.FALSE
556+
return _evaluate_filter_expression_in(model_filter, cached_model_value)
557+
466558
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]:
467-
py_obj = literal_eval(model_filter.value)
468-
try:
469-
iter(py_obj)
470-
except TypeError:
471-
return BooleanValues.TRUE
472-
if cached_model_value in py_obj:
473-
return BooleanValues.FALSE
474-
return BooleanValues.TRUE
559+
return _negate_boolean(_evaluate_filter_expression_in(model_filter, cached_model_value))
560+
561+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.INCLUDES]:
562+
return _evaluate_filter_expression_includes(model_filter, cached_model_value)
563+
564+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_INCLUDES]:
565+
return _negate_boolean(
566+
_evaluate_filter_expression_includes(model_filter, cached_model_value)
567+
)
568+
569+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.BEGINS_WITH]:
570+
return _evaluate_filter_expression_begins_with(model_filter, cached_model_value)
571+
572+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.ENDS_WITH]:
573+
return _evaluate_filter_expression_ends_with(model_filter, cached_model_value)
574+
475575
raise RuntimeError(f"Bad operator: {model_filter.operator}")

0 commit comments

Comments
 (0)