Commit f4be9f4

chore: influxdb_client/client/write: fix data_frame_to_list_of_points
Fix possible data corruption by using a much simpler regular expression to clean up the results.
1 parent 91dcafb commit f4be9f4
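
For context, a minimal sketch (not part of the commit; column names invented) of the corruption the old approach risked: the previous code stripped NaN fields by regex-substituting patterns like 'col=nan' over the finished line, so a string field whose value happened to contain such text was mangled too.

import re

# Old-style clean-up: blindly delete every 'col=nan' match in the rendered line.
line = 'm col1="note: col2=nan here",col2=1.5 1586044800000000000'
pattern = 'col1=nan|col2=nani?'  # per-column patterns, as the old _replace() built them
print(re.sub(pattern, '', line))
# -> 'm col1="note:  here",col2=1.5 1586044800000000000'
#    The quoted string payload has been silently corrupted.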

File tree

4 files changed: +143 −96 lines

.gitignore (−1 line)

@@ -114,4 +114,3 @@ sandbox
 # OpenAPI-generator
 /.openapi-generator*
 **/writer.pickle
-/tests/data_frame_file.csv

influxdb_client/client/write/dataframe_serializer.py (+97 −57)

@@ -5,115 +5,155 @@
 """
 
 import re
-from functools import reduce
-from itertools import chain
+import math
 
 from influxdb_client.client.write.point import _ESCAPE_KEY, _ESCAPE_STRING, _ESCAPE_MEASUREMENT
 
 
-def _replace(data_frame):
-    from ...extras import np
-
-    # string columns
-    obj_cols = {k for k, v in dict(data_frame.dtypes).items() if v is np.dtype('O')}
-
-    # number columns
-    other_cols = set(data_frame.columns) - obj_cols
-
-    obj_nans = (f'{k}=nan' for k in obj_cols)
-    other_nans = (f'{k}=nani?' for k in other_cols)
-
-    replacements = [
-        ('|'.join(chain(obj_nans, other_nans)), ''),
-        (',{2,}', ','),
-        ('|'.join([', ,', ', ', ' ,']), ' '),
-    ]
-
-    return replacements
-
-
 def _itertuples(data_frame):
     cols = [data_frame.iloc[:, k] for k in range(len(data_frame.columns))]
     return zip(data_frame.index, *cols)
 
 
-def _is_nan(x):
-    return x != x
+def _not_nan(x):
+    return x == x
 
 
 def _any_not_nan(p, indexes):
-    return any(map(lambda inx: not _is_nan(p[inx]), indexes))
+    return any(map(lambda x: _not_nan(p[x]), indexes))
 
 
 def data_frame_to_list_of_points(data_frame, point_settings, **kwargs):
     """Serialize DataFrame into LineProtocols."""
+    # This function is hard to understand but for good reason:
+    # the approach used here is considerably more efficient
+    # than the alternatives.
+    #
+    # We build up a Python expression that very efficiently converts a data point
+    # tuple into a line-protocol entry, and then evaluate the expression
+    # as a lambda so that we can call it. This avoids the overhead of
+    # invoking a function on every data value - we only have one function
+    # call per row instead.
     from ...extras import pd, np
     if not isinstance(data_frame, pd.DataFrame):
         raise TypeError('Must be DataFrame, but type was: {0}.'
                         .format(type(data_frame)))
 
-    if 'data_frame_measurement_name' not in kwargs:
+    data_frame_measurement_name = kwargs.get('data_frame_measurement_name')
+    if data_frame_measurement_name is None:
         raise TypeError('"data_frame_measurement_name" is a Required Argument')
 
+    data_frame = data_frame.copy(deep=False)
     if isinstance(data_frame.index, pd.PeriodIndex):
         data_frame.index = data_frame.index.to_timestamp()
     else:
+        # TODO: this is almost certainly not what you want
+        # when the index is the default RangeIndex.
+        # Instead, it would probably be better to leave
+        # out the timestamp unless a time column is explicitly
+        # enabled.
        data_frame.index = pd.to_datetime(data_frame.index)
 
     if data_frame.index.tzinfo is None:
         data_frame.index = data_frame.index.tz_localize('UTC')
 
-    measurement_name = str(kwargs.get('data_frame_measurement_name')).translate(_ESCAPE_MEASUREMENT)
     data_frame_tag_columns = kwargs.get('data_frame_tag_columns')
     data_frame_tag_columns = set(data_frame_tag_columns or [])
 
     tags = []
-    fields = []
-    fields_indexes = []
     keys = []
+    fields = []
+    field_indexes = []
 
     if point_settings.defaultTags:
         for key, value in point_settings.defaultTags.items():
+            # TODO: this overrides any values for the column
+            # which is probably not what a "default" tag value
+            # is meant to do. It might be better to add the
+            # column only when it doesn't already exist,
+            # and to fill out any NaN values with the default
+            # value otherwise.
             data_frame[key] = value
             data_frame_tag_columns.add(key)
 
-    for index, (key, value) in enumerate(data_frame.dtypes.items()):
+    # Get a list of all the columns sorted by field/tag key.
+    columns = sorted(enumerate(data_frame.dtypes.items()), key=lambda col: col[1][0])
+
+    null_columns = data_frame.isnull().any()
+    for index, (key, value) in columns:
         key = str(key)
+        key_format = f'{{keys[{len(keys)}]}}'
         keys.append(key.translate(_ESCAPE_KEY))
-        key_format = f'{{keys[{index}]}}'
+        # The field index is one more than the column index because the
+        # time index is at column zero in the finally zipped-together
+        # result columns.
+        field_index = index + 1
+        val_format = f'p[{field_index}]'
 
-        index_value = index + 1
         if key in data_frame_tag_columns:
-            tags.append({'key': key, 'value': f"{key_format}={{str(p[{index_value}]).translate(_ESCAPE_KEY)}}"})
-        elif issubclass(value.type, np.integer):
-            fields.append(f"{key_format}={{p[{index_value}]}}i")
-            fields_indexes.append(index_value)
-        elif issubclass(value.type, (np.float, np.bool_)):
-            fields.append(f"{key_format}={{p[{index_value}]}}")
-            fields_indexes.append(index_value)
+            if null_columns[index]:
+                key_value = f"""{{
+                        '' if {val_format} == '' or type({val_format}) == float and math.isnan({val_format}) else
+                        f',{key_format}={{str({val_format}).translate(_ESCAPE_STRING)}}'
+                    }}"""
+            else:
+                key_value = f',{key_format}={{str({val_format}).translate(_ESCAPE_KEY)}}'
+            tags.append(key_value)
+            continue
+        # Note: no comma separator needed for the first field.
+        # It's important to omit it because when the first
+        # field column has no nulls, we don't run the comma-removal
+        # regexp substitution step.
+        sep = '' if len(field_indexes) == 0 else ','
+        if issubclass(value.type, np.integer):
+            field_value = f"{sep}{key_format}={{{val_format}}}i"
+        elif issubclass(value.type, np.bool_):
+            field_value = f'{sep}{key_format}={{{val_format}}}'
+        elif issubclass(value.type, np.float):
+            if null_columns[index]:
+                field_value = f"""{{"" if math.isnan({val_format}) else f"{sep}{key_format}={{{val_format}}}"}}"""
+            else:
+                field_value = f'{sep}{key_format}={{{val_format}}}'
         else:
-            fields.append(f"{key_format}=\"{{str(p[{index_value}]).translate(_ESCAPE_STRING)}}\"")
-            fields_indexes.append(index_value)
-
-    tags.sort(key=lambda x: x['key'])
-    tags = ','.join(map(lambda y: y['value'], tags))
-
-    fmt = ('{measurement_name}', f'{"," if tags else ""}', tags,
-           ' ', ','.join(fields), ' {p[0].value}')
-    f = eval("lambda p: f'{}'".format(''.join(fmt)),
-             {'measurement_name': measurement_name, '_ESCAPE_KEY': _ESCAPE_KEY, '_ESCAPE_STRING': _ESCAPE_STRING,
-              'keys': keys})
+            if null_columns[index]:
+                field_value = f"""{{
+                        '' if type({val_format}) == float64 and math.isnan({val_format}) else
+                        f'{sep}{key_format}="{{str({val_format}).translate(_ESCAPE_STRING)}}"'
+                    }}"""
+            else:
+                field_value = f'''{sep}{key_format}="{{str({val_format}).translate(_ESCAPE_STRING)}}"'''
+        field_indexes.append(field_index)
+        fields.append(field_value)
+
+    measurement_name = str(data_frame_measurement_name).translate(_ESCAPE_MEASUREMENT)
+
+    tags = ''.join(tags)
+    fields = ''.join(fields)
+    timestamp = '{p[0].value}'
+
+    print(f'measurement_name: {measurement_name}')
+    print(f'keys: {keys}')
+    print(f'tag columns: {data_frame_tag_columns}')
+    print(f'lambda p: f"""{{measurement_name}}{tags} {fields} {timestamp}"""')
+    f = eval(f'lambda p: f"""{{measurement_name}}{tags} {fields} {timestamp}"""', {
+        'measurement_name': measurement_name,
+        '_ESCAPE_KEY': _ESCAPE_KEY,
+        '_ESCAPE_STRING': _ESCAPE_STRING,
+        'keys': keys,
+        'math': math,
+    })
 
     for k, v in dict(data_frame.dtypes).items():
         if k in data_frame_tag_columns:
             data_frame[k].replace('', np.nan, inplace=True)
 
-    isnull = data_frame.isnull().any(axis=1)
-
-    if isnull.any():
-        rep = _replace(data_frame)
-        lp = (reduce(lambda a, b: re.sub(*b, a), rep, f(p))
-              for p in filter(lambda x: _any_not_nan(x, fields_indexes), _itertuples(data_frame)))
+    first_field_maybe_null = null_columns[field_indexes[0] - 1]
+    if first_field_maybe_null:
+        # When the first field is null (None/NaN), we'll have
+        # a spurious leading comma which needs to be removed.
+        lp = (re.sub('^((\\ |[^ ])* ),', '\\1', f(p))
+              for p in filter(lambda x: _any_not_nan(x, field_indexes), _itertuples(data_frame)))
         return list(lp)
     else:
         return list(map(f, _itertuples(data_frame)))
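
How the generated serializer works, as a minimal standalone sketch (mine, not from the commit; the measurement name and three-column layout are invented): the function compiles one f-string lambda per DataFrame via eval(), so formatting costs one function call per row rather than one per value.

# Sketch of the eval-compiled row formatter (hypothetical layout:
# p[0]=timestamp, p[1]=tag value, p[2]=field value).
src = 'lambda p: f"""{measurement_name},tag={p[1]} value={p[2]} {p[0]}"""'
f = eval(src, {'measurement_name': 'm'})
print(f((1586044800000000000, 'a', 1.5)))
# -> 'm,tag=a value=1.5 1586044800000000000'

And the new clean-up step: instead of the old battery of substitutions, a single anchored regex removes the spurious leading comma that appears only when the first field of a row is null. A small demonstration with an invented line:

import re

line = 'm,tag=a ,f2=1.5 1586044800000000000'  # first field was NaN, leaving ' ,'
print(re.sub('^((\\ |[^ ])* ),', '\\1', line))
# -> 'm,tag=a f2=1.5 1586044800000000000'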

influxdb_client/client/write/point.py (+25 −3)

@@ -14,10 +14,32 @@
 from influxdb_client.domain.write_precision import WritePrecision
 
 EPOCH = UTC.localize(datetime.utcfromtimestamp(0))
+
 DEFAULT_WRITE_PRECISION = WritePrecision.NS
-_ESCAPE_MEASUREMENT = str.maketrans({'\\': '\\\\', ',': r'\,', ' ': r'\ ', '\n': '\\n', '\t': '\\t', '\r': '\\r'})
-_ESCAPE_KEY = str.maketrans({'\\': '\\\\', ',': r'\,', ' ': r'\ ', '=': r'\=', '\n': '\\n', '\t': '\\t', '\r': '\\r'})
-_ESCAPE_STRING = str.maketrans({'\"': r"\"", "\\": r"\\"})
+
+_ESCAPE_MEASUREMENT = str.maketrans({
+    '\\': r'\\',  # Note: this is wrong. Backslashes are not escaped like this in measurements.
+    ',': r'\,',
+    ' ': r'\ ',
+    '\n': r'\n',
+    '\t': r'\t',
+    '\r': r'\r',
+})
+
+_ESCAPE_KEY = str.maketrans({
+    '\\': r'\\',  # Note: this is wrong. Backslashes are not escaped like this in keys.
+    ',': r'\,',
+    '=': r'\=',
+    ' ': r'\ ',
+    '\n': r'\n',
+    '\t': r'\t',
+    '\r': r'\r',
+})
+
+_ESCAPE_STRING = str.maketrans({
+    '"': r'\"',
+    '\\': r'\\',
+})
 
 
 class Point(object):
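
A quick usage sketch of the reformatted translation tables (my example; the import matches the one the serializer above already uses): str.translate applies all the escapes in a single pass over the string.

from influxdb_client.client.write.point import (
    _ESCAPE_MEASUREMENT, _ESCAPE_KEY, _ESCAPE_STRING)

print('cpu load,1'.translate(_ESCAPE_MEASUREMENT))  # -> cpu\ load\,1
print('host name'.translate(_ESCAPE_KEY))           # -> host\ name
print('say "hi"'.translate(_ESCAPE_STRING))         # -> say \"hi\"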

tests/test_WriteApiDataFrame.py (+21 −35)

@@ -23,41 +23,27 @@ def tearDown(self) -> None:
         super().tearDown()
         self._write_client.__del__()
 
-    @unittest.skip('Test big file')
-    def test_write_data_frame(self):
-        import random
+    @unittest.skip('Test big data')
+    def test_convert_data_frame(self):
         from influxdb_client.extras import pd
 
-        if not os.path.isfile("data_frame_file.csv"):
-            with open('data_frame_file.csv', mode='w+') as csv_file:
-                _writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
-                _writer.writerow(['time', 'col1', 'col2', 'col3', 'col4', 'col5', 'col6', 'col7', 'col8'])
+        num_rows = 1500000
+        col_data = {
+            'time': np.arange(0, num_rows, 1, dtype=int),
+            'col1': np.random.choice(['test_a', 'test_b', 'test_c'], size=(num_rows,)),
+        }
+        for n in range(2, 9):
+            col_data[f'col{n}'] = np.random.rand(num_rows)
 
-            for i in range(1, 1500000):
-                choice = ['test_a', 'test_b', 'test_c']
-                _writer.writerow([i, random.choice(choice), 'test', random.random(), random.random(),
-                                  random.random(), random.random(), random.random(), random.random()])
+        data_frame = pd.DataFrame(data=col_data)
+        print(data_frame)
 
-            csv_file.close()
+        start = time.time()
+        data_frame_to_list_of_points(record, PointSettings(),
+                                     data_frame_measurement_name='h2o_feet',
+                                     data_frame_tag_columns=['location'])
 
-        with open('data_frame_file.csv', mode='rb') as csv_file:
-
-            data_frame = pd.read_csv(csv_file, index_col='time')
-            print(data_frame)
-
-            print('Writing...')
-
-            start = time.time()
-
-            self._write_client.write("my-bucket", "my-org", record=data_frame,
-                                     data_frame_measurement_name='h2o_feet',
-                                     data_frame_tag_columns=['location'])
-
-            self._write_client.__del__()
-
-            print("Time elapsed: ", (time.time() - start))
-
-            csv_file.close()
+        print("Time elapsed: ", (time.time() - start))
 
     def test_write_num_py(self):
         from influxdb_client.extras import pd, np

@@ -110,14 +96,14 @@ def test_write_nan(self):
                                              data_frame_measurement_name='measurement')
 
         self.assertEqual(4, len(points))
-        self.assertEqual("measurement actual_kw_price=3.1955,actual_general_use=20.514305 1586044800000000000",
+        self.assertEqual("measurement actual_general_use=20.514305,actual_kw_price=3.1955 1586044800000000000",
                          points[0])
-        self.assertEqual("measurement actual_kw_price=5.731,actual_general_use=23.32871 1586046600000000000",
+        self.assertEqual("measurement actual_general_use=23.32871,actual_kw_price=5.731 1586046600000000000",
                          points[1])
-        self.assertEqual("measurement forecast_kw_price=3.138664,forecast_general_use=20.755026 1586048400000000000",
+        self.assertEqual("measurement forecast_general_use=20.755026,forecast_kw_price=3.138664 1586048400000000000",
                          points[2])
-        self.assertEqual("measurement actual_kw_price=5.731,forecast_kw_price=5.139563,actual_general_use=23.32871,"
-                         "forecast_general_use=19.79124 1586050200000000000",
+        self.assertEqual("measurement actual_general_use=23.32871,actual_kw_price=5.731,forecast_general_use=19.79124"
+                         ",forecast_kw_price=5.139563 1586050200000000000",
                          points[3])
 
     def test_write_tag_nan(self):
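
The updated expectations reflect the new deterministic ordering: columns are now serialized sorted by key, so fields appear alphabetically regardless of DataFrame column order. A minimal sketch of that behavior (my example, run against the library at this commit; column names invented, and assuming PointSettings is exported at the package top level, as the tests use it):

import pandas as pd
from influxdb_client import PointSettings
from influxdb_client.client.write.dataframe_serializer import data_frame_to_list_of_points

# 'b_col' comes first in the DataFrame but 'a_col' sorts first in the output.
df = pd.DataFrame({'b_col': [1.0], 'a_col': [2.0]},
                  index=[pd.Timestamp('2020-04-05 00:00:00+00:00')])
print(data_frame_to_list_of_points(df, PointSettings(),
                                   data_frame_measurement_name='m'))
# Expected: ['m a_col=2.0,b_col=1.0 1586044800000000000']
# (At this commit the call also emits the leftover debug print lines.)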
