Skip to content

Commit 6c4e2e6

Browse files
authored
Change sorting to work similar to the generic filtering (#285)
* Change sorting to work similar to the generic filtering - this makes it possible to sort with different tables. * Fix typo CallableErrorHander to CallableErrorHandler * Change graphql sorting to use SortOrder instead of a new graphql enum * Fix graphql subscriptions tests
1 parent 6b9e7d8 commit 6c4e2e6

File tree

16 files changed

+318
-188
lines changed

16 files changed

+318
-188
lines changed

orchestrator/db/filters/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from orchestrator.db.filters.filters import (
2-
CallableErrorHander,
2+
CallableErrorHandler,
33
Filter,
44
generic_apply_filters,
55
generic_filter,
@@ -8,7 +8,7 @@
88

99
__all__ = [
1010
"Filter",
11-
"CallableErrorHander",
11+
"CallableErrorHandler",
1212
"generic_filter",
1313
"generic_apply_filters",
1414
"generic_filters_validate",

orchestrator/db/filters/filters.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from orchestrator.db.database import SearchQuery
2121

2222

23-
class CallableErrorHander(Protocol):
23+
class CallableErrorHandler(Protocol):
2424
def __call__(self, message: str, **kwargs: Any) -> None:
2525
...
2626

@@ -48,9 +48,9 @@ def _is_valid_filter(item: Filter) -> bool:
4848

4949
def generic_apply_filters(
5050
valid_filter_functions_by_column: ValidFilterFunctionsByColumnType,
51-
) -> Callable[[QueryType, Iterator[Filter], CallableErrorHander], QueryType]:
51+
) -> Callable[[QueryType, Iterator[Filter], CallableErrorHandler], QueryType]:
5252
def _apply_filters(
53-
query: QueryType, filter_by: Iterator[Filter], handle_filter_error: CallableErrorHander
53+
query: QueryType, filter_by: Iterator[Filter], handle_filter_error: CallableErrorHandler
5454
) -> QueryType:
5555
for item in filter_by:
5656
field = item.field
@@ -76,15 +76,15 @@ def _apply_filters(
7676

7777
def generic_filter(
7878
valid_filter_functions_by_column: ValidFilterFunctionsByColumnType,
79-
) -> Callable[[QueryType, list[Filter], CallableErrorHander], QueryType]:
79+
) -> Callable[[QueryType, list[Filter], CallableErrorHandler], QueryType]:
8080
valid_filter_functions_by_column_KEYS = list(valid_filter_functions_by_column.keys())
8181
_validate_filters = generic_filters_validate(valid_filter_functions_by_column)
8282
_apply_filters = generic_apply_filters(valid_filter_functions_by_column)
8383

8484
def _filter(
8585
query: QueryType,
8686
filter_by: list[Filter],
87-
handle_filter_error: CallableErrorHander,
87+
handle_filter_error: CallableErrorHandler,
8888
) -> QueryType:
8989
invalid_filter_items, valid_filter_items = _validate_filters(filter_by)
9090
if invalid_list := [item.dict() for item in invalid_filter_items]:

orchestrator/db/sorting/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from orchestrator.db.sorting.process import sort_processes
2-
from orchestrator.db.sorting.sorting import Sort, SortOrder, generic_apply_sorts, generic_sort, generic_sorts_validate
2+
from orchestrator.db.sorting.sorting import Sort, SortOrder, generic_apply_sorting, generic_sort, generic_sorts_validate
33
from orchestrator.db.sorting.subscription import sort_subscriptions
44

55
__all__ = [
66
"Sort",
77
"SortOrder",
88
"generic_sort",
9-
"generic_apply_sorts",
9+
"generic_apply_sorting",
1010
"generic_sorts_validate",
1111
"sort_processes",
1212
"sort_subscriptions",

orchestrator/db/sorting/process.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1+
from sqlalchemy.inspection import inspect
2+
13
from orchestrator.db import ProcessTable
2-
from orchestrator.db.sorting.sorting import generic_sort
4+
from orchestrator.db.sorting.sorting import generic_column_sort, generic_sort
5+
from orchestrator.utils.helpers import to_camel
36

47
VALID_SORT_KEY_MAP = {
5-
"creator": "created_by",
6-
"started": "started_at",
7-
"status": "last_status",
8-
"assignee": "assignee",
9-
"modified": "last_modified_at",
10-
"workflow": "workflow",
8+
"created_by": "created_by",
9+
"started_at": "started",
10+
"last_status": "status",
11+
"last_modified_at": "modified",
12+
}
13+
PROCESS_SORT_FUNCTIONS_BY_COLUMN = {
14+
to_camel(VALID_SORT_KEY_MAP.get(key, key)): generic_column_sort(value)
15+
for [key, value] in inspect(ProcessTable).columns.items()
1116
}
1217

13-
sort_processes = generic_sort(VALID_SORT_KEY_MAP, ProcessTable)
18+
sort_processes = generic_sort(PROCESS_SORT_FUNCTIONS_BY_COLUMN)

orchestrator/db/sorting/product.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1+
from sqlalchemy.inspection import inspect
2+
13
from orchestrator.db import ProductTable
2-
from orchestrator.db.sorting.sorting import generic_sort
4+
from orchestrator.db.sorting.sorting import generic_column_sort, generic_sort
5+
from orchestrator.utils.helpers import to_camel
36

4-
VALID_SORT_KEY_MAP = {
5-
"created_at": "created_at",
6-
"end_date": "end_date",
7-
"status": "status",
8-
"product_type": "product_type",
9-
"name": "name",
10-
"tag": "tag",
7+
PRODUCT_SORT_FUNCTIONS_BY_COLUMN = {
8+
to_camel(key): generic_column_sort(value) for [key, value] in inspect(ProductTable).columns.items()
119
}
1210

13-
sort_products = generic_sort(VALID_SORT_KEY_MAP, ProductTable)
11+
sort_products = generic_sort(PRODUCT_SORT_FUNCTIONS_BY_COLUMN)

orchestrator/db/sorting/sorting.py

Lines changed: 55 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,18 @@
1414
from enum import Enum
1515
from typing import Callable, Iterator, TypeVar
1616

17+
import strawberry
1718
from more_itertools import partition
1819
from pydantic import BaseModel
20+
from sqlalchemy import Column
1921
from sqlalchemy.sql import expression
2022

21-
from orchestrator.db.database import BaseModel as SqlBaseModel
23+
from orchestrator.api.error_handling import ProblemDetailException
2224
from orchestrator.db.database import SearchQuery
23-
from orchestrator.db.filters import CallableErrorHander
25+
from orchestrator.db.filters import CallableErrorHandler
2426

2527

28+
@strawberry.enum(description="Sort order (ASC or DESC)")
2629
class SortOrder(Enum):
2730
ASC = "asc"
2831
DESC = "desc"
@@ -32,16 +35,14 @@ class Sort(BaseModel):
3235
field: str
3336
order: SortOrder
3437

35-
class Config:
36-
use_enum_values = True
37-
3838

3939
GenericType = TypeVar("GenericType")
4040
QueryType = SearchQuery
41+
ValidSortFunctionsByColumnType = dict[str, Callable[[QueryType, SortOrder], QueryType]]
4142

4243

4344
def generic_sorts_validate(
44-
valid_sort_dict: dict[str, str]
45+
valid_sort_functions_by_column: ValidSortFunctionsByColumnType,
4546
) -> Callable[[list[Sort]], tuple[Iterator[Sort], Iterator[Sort]]]:
4647
"""Create generic validate sort factory that creates a validate function based on the valid sort dict.
4748
@@ -54,51 +55,69 @@ def generic_sorts_validate(
5455
"""
5556

5657
def validate_sort_items(sort_by: list[Sort]) -> tuple[Iterator[Sort], Iterator[Sort]]:
57-
return partition(lambda item: item.field in valid_sort_dict, sort_by)
58+
def _is_valid_sort(item: Sort) -> bool:
59+
return item.field in valid_sort_functions_by_column
60+
61+
return partition(_is_valid_sort, sort_by)
5862

5963
return validate_sort_items
6064

6165

62-
def generic_apply_sorts(
63-
valid_sort_dict: dict[str, str], model: SqlBaseModel
64-
) -> Callable[[QueryType, Iterator[Sort]], QueryType]:
65-
def _apply_sorts(query: QueryType, sort_by: Iterator[Sort]) -> QueryType:
66+
def generic_apply_sorting(
67+
valid_sort_functions_by_column: ValidSortFunctionsByColumnType,
68+
) -> Callable[[QueryType, Iterator[Sort], CallableErrorHandler], QueryType]:
69+
def _apply_sorting(query: QueryType, sort_by: Iterator[Sort], handle_sort_error: CallableErrorHandler) -> QueryType:
6670
for item in sort_by:
67-
field = item.field.lower()
68-
sort_key = valid_sort_dict[field]
69-
70-
if item.order == SortOrder.DESC.value: # type: ignore
71-
query = query.order_by(expression.desc(model.__dict__[sort_key]))
72-
else:
73-
query = query.order_by(expression.asc(model.__dict__[sort_key]))
71+
field = item.field
72+
sort_fn = valid_sort_functions_by_column[field]
73+
try:
74+
query = sort_fn(query, item.order)
75+
except ProblemDetailException as exception:
76+
handle_sort_error(
77+
exception.detail,
78+
field=field,
79+
order=item.order,
80+
)
81+
except ValueError as exception:
82+
handle_sort_error(
83+
str(exception),
84+
field=field,
85+
order=item.order,
86+
)
7487
return query
7588

76-
return _apply_sorts
89+
return _apply_sorting
7790

7891

7992
def generic_sort(
80-
valid_sort_dict: dict[str, str],
81-
model: SqlBaseModel,
82-
) -> Callable[[QueryType, list[Sort], CallableErrorHander], QueryType]:
83-
valid_sort_keys = list(valid_sort_dict.keys())
84-
_validate_sorts = generic_sorts_validate(valid_sort_dict)
85-
_apply_sorts = generic_apply_sorts(valid_sort_dict, model)
93+
valid_sort_functions_by_column: ValidSortFunctionsByColumnType,
94+
) -> Callable[[QueryType, list[Sort], CallableErrorHandler], QueryType]:
95+
valid_sort_functions_by_column_KEYS = list(valid_sort_functions_by_column.keys())
96+
_validate_sorts = generic_sorts_validate(valid_sort_functions_by_column)
97+
_apply_sorting = generic_apply_sorting(valid_sort_functions_by_column)
8698

8799
def _sort(
88100
query: QueryType,
89101
sort_by: list[Sort],
90-
handle_sort_error: CallableErrorHander,
102+
handle_sort_error: CallableErrorHandler,
91103
) -> QueryType:
92-
if sort_by:
93-
invalid_sort_items, valid_sort_items = _validate_sorts(sort_by)
94-
if invalid_list := [item.dict() for item in invalid_sort_items]:
95-
handle_sort_error(
96-
"Invalid sort arguments",
97-
invalid_filters=invalid_list,
98-
valid_filter_keys=valid_sort_keys,
99-
)
104+
invalid_sort_items, valid_sort_items = _validate_sorts(sort_by)
105+
if invalid_list := [{"field": item.field, "order": item.order.value.upper()} for item in invalid_sort_items]:
106+
handle_sort_error(
107+
"Invalid sort arguments",
108+
invalid_sorting=invalid_list,
109+
valid_sort_keys=valid_sort_functions_by_column_KEYS,
110+
)
100111

101-
query = _apply_sorts(query, valid_sort_items)
102-
return query
112+
return _apply_sorting(query, valid_sort_items, handle_sort_error)
103113

104114
return _sort
115+
116+
117+
def generic_column_sort(field: Column) -> Callable[[SearchQuery, SortOrder], SearchQuery]:
118+
def sort_function(query: SearchQuery, order: SortOrder) -> SearchQuery:
119+
if order == SortOrder.DESC:
120+
return query.order_by(expression.desc(field))
121+
return query.order_by(expression.asc(field))
122+
123+
return sort_function
Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1+
from sqlalchemy.inspection import inspect
2+
13
from orchestrator.db import SubscriptionTable
2-
from orchestrator.db.sorting.sorting import generic_sort
4+
from orchestrator.db.sorting.product import PRODUCT_SORT_FUNCTIONS_BY_COLUMN
5+
from orchestrator.db.sorting.sorting import generic_column_sort, generic_sort
6+
from orchestrator.utils.helpers import to_camel
7+
8+
subscription_table_sort = {
9+
to_camel(key): generic_column_sort(value) for [key, value] in inspect(SubscriptionTable).columns.items()
10+
}
311

4-
VALID_SORT_KEY_LIST = [
5-
"subscription_id",
6-
"product_id",
7-
"name",
8-
"description",
9-
"insync",
10-
"status",
11-
"note",
12-
"tag",
13-
"start_date",
14-
"end_date",
15-
]
16-
VALID_SORT_KEY_MAP = {key: key for key in VALID_SORT_KEY_LIST}
12+
SUBSCRIPTION_SORT_FUNCTIONS_BY_COLUMN = PRODUCT_SORT_FUNCTIONS_BY_COLUMN | subscription_table_sort
1713

18-
sort_subscriptions = generic_sort(VALID_SORT_KEY_MAP, SubscriptionTable)
14+
sort_subscriptions = generic_sort(SUBSCRIPTION_SORT_FUNCTIONS_BY_COLUMN)

orchestrator/graphql/resolvers/process.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from sqlalchemy.orm import defer, joinedload
1919

2020
from orchestrator.db import ProcessSubscriptionTable, ProcessTable, SubscriptionTable
21-
from orchestrator.db.filters import CallableErrorHander, Filter
21+
from orchestrator.db.filters import CallableErrorHandler, Filter
2222
from orchestrator.db.filters.process import filter_processes
2323
from orchestrator.db.range import apply_range_to_query
2424
from orchestrator.db.sorting import Sort
@@ -38,7 +38,7 @@ def enrich_process(process: ProcessTable) -> ProcessGraphqlSchema:
3838
return ProcessGraphqlSchema(**data)
3939

4040

41-
def handle_process_error(info: CustomInfo) -> CallableErrorHander:
41+
def handle_process_error(info: CustomInfo) -> CallableErrorHandler:
4242
def _handle_process_error(message: str, **kwargs) -> None: # type: ignore
4343
logger.debug(message, **kwargs)
4444
extra_values = kwargs if kwargs else {}
@@ -55,7 +55,6 @@ async def resolve_processes(
5555
after: int = 0,
5656
) -> Connection[ProcessType]:
5757
_error_handler = handle_process_error(info)
58-
5958
pydantic_filter_by: list[Filter] = [item.to_pydantic() for item in filter_by] if filter_by else []
6059
pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
6160
logger.info("resolve_processes() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by)

orchestrator/graphql/resolvers/product.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import structlog
44
from graphql import GraphQLError
55

6-
from orchestrator.db.filters import CallableErrorHander, Filter
6+
from orchestrator.db.filters import CallableErrorHandler, Filter
77
from orchestrator.db.filters.product import filter_products
88
from orchestrator.db.models import ProductTable
99
from orchestrator.db.range.range import apply_range_to_query
@@ -16,7 +16,7 @@
1616
logger = structlog.get_logger(__name__)
1717

1818

19-
def handle_product_error(info: CustomInfo) -> CallableErrorHander:
19+
def handle_product_error(info: CustomInfo) -> CallableErrorHandler:
2020
def _handle_product_error(message: str, **kwargs) -> None: # type: ignore
2121
logger.debug(message, **kwargs)
2222
extra_values = kwargs if kwargs else {}

orchestrator/graphql/resolvers/subscription.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515

1616
import structlog
1717
from graphql import GraphQLError
18-
from sqlalchemy.orm import joinedload
1918

20-
from orchestrator.db import SubscriptionTable
21-
from orchestrator.db.filters import CallableErrorHander, Filter
19+
from orchestrator.db import ProductTable, SubscriptionTable
20+
from orchestrator.db.filters import CallableErrorHandler, Filter
2221
from orchestrator.db.filters.subscription import filter_subscriptions
2322
from orchestrator.db.range import apply_range_to_query
2423
from orchestrator.db.sorting import Sort, sort_subscriptions
@@ -29,7 +28,7 @@
2928
logger = structlog.get_logger(__name__)
3029

3130

32-
def handle_subscription_error(info: CustomInfo) -> CallableErrorHander:
31+
def handle_subscription_error(info: CustomInfo) -> CallableErrorHandler:
3332
def _handle_subscription_error(message: str, **kwargs) -> None: # type: ignore
3433
logger.debug(message, **kwargs)
3534
extra_values = kwargs if kwargs else {}
@@ -51,7 +50,7 @@ async def resolve_subscriptions(
5150
pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
5251
logger.info("resolve_subscription() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by)
5352

54-
query = SubscriptionTable.query.options(joinedload(SubscriptionTable.product))
53+
query = SubscriptionTable.query.join(ProductTable)
5554

5655
query = filter_subscriptions(query, pydantic_filter_by, _error_handler)
5756
query = sort_subscriptions(query, pydantic_sort_by, _error_handler)

0 commit comments

Comments
 (0)