Skip to content

Commit 2eb6727

Browse files
committed
feat: add supports for array expressions in query parameters
1 parent eb7d52e commit 2eb6727

File tree

2 files changed

+71
-31
lines changed

2 files changed

+71
-31
lines changed

influxdb_client/client/query_api.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import codecs
88
import csv
99
from datetime import datetime, timedelta
10-
from typing import List, Generator, Any
10+
from typing import List, Generator, Any, Union, Iterable
1111

1212
from influxdb_client import Dialect, IntegerLiteral, BooleanLiteral, FloatLiteral, DateTimeLiteral, StringLiteral, \
13-
VariableAssignment, Identifier, OptionStatement, File, DurationLiteral, Duration, UnaryExpression, \
13+
VariableAssignment, Identifier, OptionStatement, File, DurationLiteral, Duration, UnaryExpression, Expression, \
1414
ImportDeclaration, MemberAssignment, MemberExpression, ArrayExpression
1515
from influxdb_client import Query, QueryService
1616
from influxdb_client.client.flux_csv_parser import FluxCsvParser, FluxSerializationMode
@@ -203,35 +203,43 @@ def _params_to_extern_ast(params: dict) -> List['OptionStatement']:
203203

204204
statements = []
205205
for key, value in params.items():
206-
if value is None:
206+
expression = QueryApi._parm_to_extern_ast(value)
207+
if expression is None:
207208
continue
208209

209-
if isinstance(value, bool):
210-
literal = BooleanLiteral("BooleanLiteral", value)
211-
elif isinstance(value, int):
212-
literal = IntegerLiteral("IntegerLiteral", str(value))
213-
elif isinstance(value, float):
214-
literal = FloatLiteral("FloatLiteral", value)
215-
elif isinstance(value, datetime):
216-
value = get_date_helper().to_utc(value)
217-
literal = DateTimeLiteral("DateTimeLiteral", value.strftime('%Y-%m-%dT%H:%M:%S.%fZ'))
218-
elif isinstance(value, timedelta):
219-
_micro_delta = int(value / timedelta(microseconds=1))
220-
if _micro_delta < 0:
221-
literal = UnaryExpression("UnaryExpression", argument=DurationLiteral("DurationLiteral", [
222-
Duration(magnitude=-_micro_delta, unit="us")]), operator="-")
223-
else:
224-
literal = DurationLiteral("DurationLiteral", [Duration(magnitude=_micro_delta, unit="us")])
225-
elif isinstance(value, str):
226-
literal = StringLiteral("StringLiteral", str(value))
227-
else:
228-
literal = value
229-
230210
statements.append(OptionStatement("OptionStatement",
231211
VariableAssignment("VariableAssignment", Identifier("Identifier", key),
232-
literal)))
212+
expression)))
233213
return statements
234214

215+
@staticmethod
216+
def _parm_to_extern_ast(value) -> Union[Expression, None]:
217+
if value is None:
218+
return None
219+
if isinstance(value, bool):
220+
return BooleanLiteral("BooleanLiteral", value)
221+
elif isinstance(value, int):
222+
return IntegerLiteral("IntegerLiteral", str(value))
223+
elif isinstance(value, float):
224+
return FloatLiteral("FloatLiteral", value)
225+
elif isinstance(value, datetime):
226+
value = get_date_helper().to_utc(value)
227+
return DateTimeLiteral("DateTimeLiteral", value.strftime('%Y-%m-%dT%H:%M:%S.%fZ'))
228+
elif isinstance(value, timedelta):
229+
_micro_delta = int(value / timedelta(microseconds=1))
230+
if _micro_delta < 0:
231+
return UnaryExpression("UnaryExpression", argument=DurationLiteral("DurationLiteral", [
232+
Duration(magnitude=-_micro_delta, unit="us")]), operator="-")
233+
else:
234+
return DurationLiteral("DurationLiteral", [Duration(magnitude=_micro_delta, unit="us")])
235+
elif isinstance(value, str):
236+
return StringLiteral("StringLiteral", str(value))
237+
elif isinstance(value, Iterable):
238+
return ArrayExpression("ArrayExpression",
239+
elements=list(map(lambda it: QueryApi._parm_to_extern_ast(it), value)))
240+
else:
241+
return value
242+
235243
@staticmethod
236244
def _build_flux_ast(params: dict = None, profilers: List[str] = None):
237245

tests/test_QueryApi.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dateutil.tz import tzutc
66
from httpretty import httpretty
77

8-
from influxdb_client import QueryApi, DurationLiteral, Duration, CallExpression, Expression, UnaryExpression, \
8+
from influxdb_client import QueryApi, DurationLiteral, Duration, CallExpression, UnaryExpression, \
99
Identifier, InfluxDBClient
1010
from influxdb_client.client.query_api import QueryOptions
1111
from influxdb_client.client.util.date_utils import get_date_helper
@@ -47,7 +47,7 @@ def test_query_flux_table(self):
4747
val_count = 0
4848
for table in tables:
4949
for row in table:
50-
for cell in row.values:
50+
for _ in row.values:
5151
val_count += 1
5252

5353
print("Values count: ", val_count)
@@ -61,7 +61,7 @@ def test_query_flux_csv(self):
6161

6262
val_count = 0
6363
for row in csv_result:
64-
for cell in row:
64+
for _ in row:
6565
val_count += 1
6666

6767
print("Values count: ", val_count)
@@ -98,7 +98,7 @@ def test_query_ast(self):
9898

9999
val_count = 0
100100
for row in csv_result:
101-
for cell in row:
101+
for _ in row:
102102
val_count += 1
103103

104104
print("Values count: ", val_count)
@@ -263,7 +263,39 @@ def test_parameter_ast(self):
263263
}
264264
],
265265
"imports": []
266-
}]]
266+
}], ["arrayParam", ["bar1", "bar2", "bar3"],
267+
{
268+
"body": [
269+
{
270+
"assignment": {
271+
"id": {
272+
"name": "arrayParam",
273+
"type": "Identifier"
274+
},
275+
"init": {
276+
"elements": [
277+
{
278+
"type": "StringLiteral",
279+
"value": "bar1"
280+
},
281+
{
282+
"type": "StringLiteral",
283+
"value": "bar2"
284+
},
285+
{
286+
"type": "StringLiteral",
287+
"value": "bar3"
288+
}
289+
],
290+
"type": "ArrayExpression"
291+
},
292+
"type": "VariableAssignment"
293+
},
294+
"type": "OptionStatement"
295+
}
296+
],
297+
"imports": []
298+
}]]
267299

268300
for data in test_data:
269301
param = {data[0]: data[1]}
@@ -290,7 +322,7 @@ def test_query_profiler_enabled(self):
290322
for table in csv_result:
291323
self.assertFalse(any(filter(lambda column: (column.default_value == "_profiler"), table.columns)))
292324
for flux_record in table:
293-
self.assertFalse( flux_record["_measurement"].startswith("profiler/"))
325+
self.assertFalse(flux_record["_measurement"].startswith("profiler/"))
294326

295327
records = self.client.query_api().query_stream(query=q, params=p)
296328

0 commit comments

Comments
 (0)