Skip to content

feat: parallelize notebook search utils, add new operators #4342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 132 additions & 32 deletions src/sagemaker/jumpstart/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import
from ast import literal_eval
from enum import Enum
from typing import Dict, List, Union, Any
from typing import Dict, List, Optional, Union, Any

from sagemaker.jumpstart.types import JumpStartDataHolderType

Expand All @@ -38,6 +38,10 @@ class FilterOperators(str, Enum):
NOT_EQUALS = "not_equals"
IN = "in"
NOT_IN = "not_in"
INCLUDES = "includes"
NOT_INCLUDES = "not_includes"
BEGINS_WITH = "begins_with"
ENDS_WITH = "ends_with"


class SpecialSupportedFilterKeys(str, Enum):
Expand All @@ -52,6 +56,10 @@ class SpecialSupportedFilterKeys(str, Enum):
FilterOperators.NOT_EQUALS: ["!==", "!=", "not equals", "is not"],
FilterOperators.IN: ["in"],
FilterOperators.NOT_IN: ["not in"],
FilterOperators.INCLUDES: ["includes", "contains"],
FilterOperators.NOT_INCLUDES: ["not includes", "not contains"],
FilterOperators.BEGINS_WITH: ["begins with", "starts with"],
FilterOperators.ENDS_WITH: ["ends with"],
}


Expand All @@ -62,7 +70,19 @@ class SpecialSupportedFilterKeys(str, Enum):
)

ACCEPTABLE_OPERATORS_IN_PARSE_ORDER = (
list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]))
list(
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.BEGINS_WITH])
)
+ list(
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.ENDS_WITH])
)
+ list(
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_INCLUDES])
)
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.INCLUDES]))
+ list(
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS])
)
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]))
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS]))
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]))
Expand Down Expand Up @@ -428,9 +448,96 @@ def parse_filter_string(filter_string: str) -> ModelFilter:
raise ValueError(f"Cannot parse filter string: {filter_string}")


def _negate_boolean(boolean: BooleanValues) -> BooleanValues:
    """Returns the logical complement of a boolean expression value.

    ``BooleanValues.TRUE`` becomes ``BooleanValues.FALSE`` and vice versa. Any
    other value is handed back unchanged.
    """
    if boolean == BooleanValues.TRUE:
        negated = BooleanValues.FALSE
    elif boolean == BooleanValues.FALSE:
        negated = BooleanValues.TRUE
    else:
        # Neither TRUE nor FALSE: there is nothing to flip.
        negated = boolean
    return negated


def _evaluate_filter_expression_equals(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluates an ``equals`` filter against a cached model spec value.

    Both sides are compared through their ``str`` representation. When the
    cached value is a ``bool``, both sides are lower-cased first so that a
    filter value such as ``"True"`` matches the cached ``True``.

    Returns:
        BooleanValues: ``TRUE`` when the two string forms are identical,
        ``FALSE`` otherwise (including when the cached value is ``None``).
    """
    if cached_model_value is None:
        return BooleanValues.FALSE

    if isinstance(cached_model_value, bool):
        # Python renders booleans as "True"/"False"; normalise the case on both sides.
        actual = str(cached_model_value).lower()
        expected = model_filter.value.lower()
    else:
        actual = cached_model_value
        expected = model_filter.value

    return BooleanValues.TRUE if str(expected) == str(actual) else BooleanValues.FALSE


def _evaluate_filter_expression_in(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluates filter expressions for string/list in.

    The filter value is first parsed with ``ast.literal_eval``. If parsing
    fails, the raw filter value is used as-is. The cached model value is then
    tested for membership in the resulting object with Python's ``in``
    operator.

    Args:
        model_filter (ModelFilter): The filter whose ``value`` describes the
            collection (or string) to search in.
        cached_model_value: The cached model spec value to look for.

    Returns:
        BooleanValues: ``TRUE`` if ``cached_model_value`` is a member of the
        filter value. ``FALSE`` otherwise, including when the cached value is
        ``None`` or a list, when the parsed filter value is not iterable, or
        when the membership test is not supported for the two operand types.
    """
    if cached_model_value is None:
        return BooleanValues.FALSE

    py_obj = model_filter.value
    try:
        py_obj = literal_eval(py_obj)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Best effort: the filter value is not a Python literal, so keep the
        # raw value and fall through to the membership test below.
        pass
    else:
        try:
            iter(py_obj)
        except TypeError:
            # Parsed to a scalar (e.g. a number); nothing can be "in" it.
            return BooleanValues.FALSE

    if isinstance(cached_model_value, list):
        return BooleanValues.FALSE

    try:
        is_member = cached_model_value in py_obj
    except TypeError:
        # Raised e.g. for ``5 in "abc"`` or for an unhashable value tested
        # against a set/dict. Such a value cannot be a member.
        return BooleanValues.FALSE
    return BooleanValues.TRUE if is_member else BooleanValues.FALSE


def _evaluate_filter_expression_includes(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluates filter expressions for string includes.

    The filter value is converted to ``str`` and looked up in the cached
    model value with Python's ``in`` operator.

    Args:
        model_filter (ModelFilter): The filter whose ``value`` is searched for.
        cached_model_value: The cached model spec value to search in.

    Returns:
        BooleanValues: ``TRUE`` if the stringified filter value is contained
        in ``cached_model_value``. ``FALSE`` otherwise, including when the
        cached value is ``None`` or does not support membership tests
        (e.g. ``int``, ``float`` or ``bool``).
    """
    if cached_model_value is None:
        return BooleanValues.FALSE

    filter_value = str(model_filter.value)
    try:
        is_included = filter_value in cached_model_value
    except TypeError:
        # Scalars such as numbers and booleans are not containers; treat
        # them as "does not include" rather than letting the error escape.
        return BooleanValues.FALSE
    return BooleanValues.TRUE if is_included else BooleanValues.FALSE


def _evaluate_filter_expression_begins_with(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluates filter expressions for string begins with.

    Args:
        model_filter (ModelFilter): The filter whose ``value`` is the expected
            prefix (converted to ``str`` before comparing).
        cached_model_value: The cached model spec value to inspect.

    Returns:
        BooleanValues: ``TRUE`` if ``cached_model_value`` is a string that
        starts with the filter value. ``FALSE`` otherwise, including when the
        cached value is ``None`` or is not a string.
    """
    if not isinstance(cached_model_value, str):
        # Covers ``None`` as well as numbers, booleans, dicts and lists, none
        # of which provide ``startswith``.
        return BooleanValues.FALSE

    filter_value = str(model_filter.value)
    if cached_model_value.startswith(filter_value):
        return BooleanValues.TRUE
    return BooleanValues.FALSE


def _evaluate_filter_expression_ends_with(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluates filter expressions for string ends with.

    Args:
        model_filter (ModelFilter): The filter whose ``value`` is the expected
            suffix (converted to ``str`` before comparing).
        cached_model_value: The cached model spec value to inspect.

    Returns:
        BooleanValues: ``TRUE`` if ``cached_model_value`` is a string that
        ends with the filter value. ``FALSE`` otherwise, including when the
        cached value is ``None`` or is not a string.
    """
    if not isinstance(cached_model_value, str):
        # Covers ``None`` as well as numbers, booleans, dicts and lists, none
        # of which provide ``endswith``.
        return BooleanValues.FALSE

    filter_value = str(model_filter.value)
    if cached_model_value.endswith(filter_value):
        return BooleanValues.TRUE
    return BooleanValues.FALSE


def evaluate_filter_expression( # pylint: disable=too-many-return-statements
model_filter: ModelFilter,
cached_model_value: Union[str, bool, int, float, Dict[str, Any], List[Any]],
cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
"""Evaluates model filter with cached model spec value, returns boolean.

Expand All @@ -440,36 +547,29 @@ def evaluate_filter_expression( # pylint: disable=too-many-return-statements
evaluate the filter.
"""
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS]:
model_filter_value = model_filter.value
if isinstance(cached_model_value, bool):
cached_model_value = str(cached_model_value).lower()
model_filter_value = model_filter.value.lower()
if str(model_filter_value) == str(cached_model_value):
return BooleanValues.TRUE
return BooleanValues.FALSE
return _evaluate_filter_expression_equals(model_filter, cached_model_value)

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]:
if isinstance(cached_model_value, bool):
cached_model_value = str(cached_model_value).lower()
model_filter.value = model_filter.value.lower()
if str(model_filter.value) == str(cached_model_value):
return BooleanValues.FALSE
return BooleanValues.TRUE
return _negate_boolean(_evaluate_filter_expression_equals(model_filter, cached_model_value))

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]:
py_obj = literal_eval(model_filter.value)
try:
iter(py_obj)
except TypeError:
return BooleanValues.FALSE
if cached_model_value in py_obj:
return BooleanValues.TRUE
return BooleanValues.FALSE
return _evaluate_filter_expression_in(model_filter, cached_model_value)

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]:
py_obj = literal_eval(model_filter.value)
try:
iter(py_obj)
except TypeError:
return BooleanValues.TRUE
if cached_model_value in py_obj:
return BooleanValues.FALSE
return BooleanValues.TRUE
return _negate_boolean(_evaluate_filter_expression_in(model_filter, cached_model_value))

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.INCLUDES]:
return _evaluate_filter_expression_includes(model_filter, cached_model_value)

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_INCLUDES]:
return _negate_boolean(
_evaluate_filter_expression_includes(model_filter, cached_model_value)
)

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.BEGINS_WITH]:
return _evaluate_filter_expression_begins_with(model_filter, cached_model_value)

if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.ENDS_WITH]:
return _evaluate_filter_expression_ends_with(model_filter, cached_model_value)

raise RuntimeError(f"Bad operator: {model_filter.operator}")
Loading