diff --git a/src/server/_query.py b/src/server/_query.py index 1029c5e2c..bd3905a76 100644 --- a/src/server/_query.py +++ b/src/server/_query.py @@ -10,8 +10,9 @@ Tuple, Union, cast, - Mapping, -) + Mapping, + Type, + ) from sqlalchemy import text from sqlalchemy.engine import Row @@ -65,6 +66,20 @@ def filter_values( conditions = [to_condition(field, v, f"{param_key}_{i}", params, formatter) for i, v in enumerate(values)] return f"({' OR '.join(conditions)})" +def alternative_filter_values( + field: str, + values: Optional[Sequence[Union[Tuple[str, str], str, Tuple[int, int], int]]], + param_key: str, + params: Dict[str, Any], + formatter=lambda x: x, +): + if not values: + return "FALSE" + # builds a SQL expression to filter strings (ex: locations) + # $field: name of the field to filter + # $values: array of values + conditions = [to_condition(field, v, f"{param_key}_{i}", params, formatter) for i, v in enumerate(values)] + return conditions def filter_strings( field: str, @@ -74,6 +89,14 @@ def filter_strings( ): return filter_values(field, values, param_key, params) +def alternative_filter_strings( + field: str, + values: Optional[Sequence[Union[Tuple[str, str], str]]], + param_key: str, + params: Dict[str, Any], +): + return alternative_filter_values(field, values, param_key, params) + def filter_integers( field: str, @@ -150,25 +173,37 @@ def filter_source_signal_pairs( values: Sequence[SourceSignalPair], param_key: str, params: Dict[str, Any], -) -> str: +) -> Union[str, list]: """ - returns the SQL sub query to filter by the given source signal pairs + returns an array of SQL sub queries to filter by the given source signal pairs """ - def filter_pair(pair: SourceSignalPair, i) -> str: + def filter_pair(pair: SourceSignalPair, i) -> Union[str, list]: source_param = f"{param_key}_{i}t" params[source_param] = pair.source if isinstance(pair.signal, bool) and pair.signal: return f"{source_field} = :{source_param}" - return f"({source_field} = :{source_param} AND {filter_strings(signal_field, cast(Sequence[str], pair.signal), source_param, params)})" - - parts = [filter_pair(p, i) for i, p in enumerate(values)] + conditions = alternative_filter_strings(signal_field, cast(Sequence[str], pair.signal), source_param, params) + condition_array = [] + if conditions: + for condition in conditions: + condition_array.append(f"({source_field} = :{source_param} AND {condition})") + return condition_array + + parts = [] + for i, p in enumerate(values): + array_or_str = filter_pair(p, i) + if isinstance(array_or_str, str): + parts.append(array_or_str) + else: + for x in array_or_str: + parts.append(x) if not parts: # something has to be selected return "FALSE" - return f"({' OR '.join(parts)})" + return parts def filter_time_pairs( @@ -344,6 +379,7 @@ def __init__(self, table: str, alias: str): self.order: Union[str, List[str]] = "" self.fields: Union[str, List[str]] = "*" self.conditions: List[str] = [] + self.signal_array: List[str] = [] self.params: Dict[str, Any] = {} self.subquery: str = "" self.index: Optional[str] = None @@ -375,7 +411,18 @@ def __str__(self): group_by = f"GROUP BY {_join_l(self.group_by)}" if self.group_by else "" index = f"USE INDEX ({self.index})" if self.index else "" - return f"SELECT {self.fields_clause} FROM {self.table} {index} {self.subquery} {where} {group_by} {order}" + # if no signal array, assemble the sql and return + if not self.signal_array: + command = f"SELECT {self.fields_clause} FROM {self.table} {index} {self.subquery} {where}" + else: + # if there is a signal array, concatenate signals with UNION ALL + command = " UNION ALL ".join( + f"SELECT {self.fields_clause} FROM {self.table} {index} {self.subquery} {where} AND {signal_clause}" + for signal_clause in self.signal_array + ) + + command += f" {group_by} {order}" + return command @property def query(self) -> str: @@ -452,7 +499,7 @@ def where_source_signal_pairs( ) -> "QueryBuilder": fq_type_field = self._fq_field(type_field) fq_value_field = self._fq_field(value_field) - self.conditions.append( + self.signal_array.extend( filter_source_signal_pairs( fq_type_field, fq_value_field, @@ -499,8 +546,9 @@ def to_asc(v: Union[str, bool]) -> str: return "DESC" return cast(str, v) - args_order = [f"{self.alias}.{k} ASC" for k in args] - kw_order = [f"{self.alias}.{k} {to_asc(v)}" for k, v in kwargs.items()] + # Use the column name without their table name for the Order By clause since Union All is used + args_order = [f"`{k}` ASC" for k in args] + kw_order = [f"{k} {to_asc(v)}" for k, v in kwargs.items()] self.order = args_order + kw_order return self diff --git a/tests/server/test_query.py b/tests/server/test_query.py index a59030b75..59113a6dc 100644 --- a/tests/server/test_query.py +++ b/tests/server/test_query.py @@ -195,21 +195,21 @@ def test_filter_source_signal_pairs(self): params = {} self.assertEqual( filter_source_signal_pairs("t", "v", [SourceSignalPair("src1", True)], "p", params), - "(t = :p_0t)", + ["t = :p_0t"] ) self.assertEqual(params, {"p_0t": "src1"}) with self.subTest("single"): params = {} self.assertEqual( filter_source_signal_pairs("t", "v", [SourceSignalPair("src1", ["sig1"])], "p", params), - "((t = :p_0t AND (v = :p_0t_0)))", + ["(t = :p_0t AND v = :p_0t_0)"] ) self.assertEqual(params, {"p_0t": "src1", "p_0t_0": "sig1"}) with self.subTest("multi"): params = {} self.assertEqual( filter_source_signal_pairs("t", "v", [SourceSignalPair("src1", ["sig1", "sig2"])], "p", params), - "((t = :p_0t AND (v = :p_0t_0 OR v = :p_0t_1)))", + ["(t = :p_0t AND v = :p_0t_0)", "(t = :p_0t AND v = :p_0t_1)"] ) self.assertEqual(params, {"p_0t": "src1", "p_0t_0": "sig1", "p_0t_1": "sig2"}) with self.subTest("multiple pairs"): @@ -222,7 +222,7 @@ def test_filter_source_signal_pairs(self): "p", params, ), - "(t = :p_0t OR t = :p_1t)", + ["t = :p_0t", "t = :p_1t"] ) self.assertEqual(params, {"p_0t": "src1", "p_1t": "src2"}) with self.subTest("multiple pairs with value"): @@ -238,7 +238,7 @@ def test_filter_source_signal_pairs(self): "p", params, ), - "((t = :p_0t AND (v = :p_0t_0)) OR (t = :p_1t AND (v = :p_1t_0)))", + ["(t = :p_0t AND v = :p_0t_0)", "(t = :p_1t AND v = :p_1t_0)"] ) self.assertEqual( params,