
Commit 4ecce15

feat: DataFrame optimization (#97)
1 parent cd50aa3 commit 4ecce15

File tree (7 files changed: +135 -28 lines)

.gitignore
CHANGELOG.md
extra-requirements.txt
influxdb_client/client/write_api.py
influxdb_client/extras.py
tests/test_WriteApi.py
tests/test_WriteApiDataFrame.py

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -114,3 +114,4 @@ sandbox
 # OpenAPI-generator
 /.openapi-generator*
 /tests/writer.pickle
+/tests/data_frame_file.csv

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 
 ### Features
 1. [#79](https://github.com/influxdata/influxdb-client-python/issues/79): Added support for writing Pandas DataFrame
+2. [#92](https://github.com/influxdata/influxdb-client-python/issues/92): Optimize serializing Pandas DataFrame for writing
 
 ### Bug Fixes
 1. [#85](https://github.com/influxdata/influxdb-client-python/issues/85): Fixed a possibility to generate empty write batch
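
For reference, #92 changes only the serializer internals; the caller-facing DataFrame write from #79 is unchanged. A minimal sketch, reusing the placeholder URL, token, bucket and org from this commit's tests:

from datetime import timedelta

import pandas as pd

from influxdb_client import InfluxDBClient
from influxdb_client.client.write_api import SYNCHRONOUS

# Placeholders taken from this commit's tests; point them at a real instance.
client = InfluxDBClient(url="http://localhost:9999", token="my-token", org="my-org")
write_api = client.write_api(write_options=SYNCHRONOUS)

now = pd.Timestamp('1970-01-01 00:00+00:00')
data_frame = pd.DataFrame(data=[["coyote_creek", 1.0], ["coyote_creek", 2.0]],
                          index=[now + timedelta(hours=1), now + timedelta(hours=2)],
                          columns=["location", "water_level"])

# The index supplies timestamps, 'location' becomes a tag, other columns fields.
write_api.write("my-bucket", "my-org", record=data_frame,
                data_frame_measurement_name='h2o_feet',
                data_frame_tag_columns=['location'])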

extra-requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 pandas>=0.25.3
+numpy

influxdb_client/client/write_api.py

Lines changed: 34 additions & 22 deletions
@@ -1,8 +1,11 @@
 # coding: utf-8
 import logging
 import os
+import re
 from datetime import timedelta
 from enum import Enum
+from functools import reduce
+from itertools import chain
 from random import random
 from time import sleep
 from typing import Union, List
@@ -14,7 +17,7 @@
 
 from influxdb_client import WritePrecision, WriteService
 from influxdb_client.client.abstract_client import AbstractClient
-from influxdb_client.client.write.point import Point, DEFAULT_WRITE_PRECISION
+from influxdb_client.client.write.point import Point, DEFAULT_WRITE_PRECISION, _ESCAPE_KEY
 from influxdb_client.rest import ApiException
 
 logger = logging.getLogger(__name__)
@@ -253,10 +256,8 @@ def _serialize(self, record, write_precision, **kwargs) -> bytes:
             _result = self._serialize(Point.from_dict(record, write_precision=write_precision),
                                       write_precision, **kwargs)
         elif 'DataFrame' in type(record).__name__:
-            _result = self._serialize(self._data_frame_to_list_of_points(record,
-                                                                         precision=write_precision, **kwargs),
-                                      write_precision,
-                                      **kwargs)
+            _data = self._data_frame_to_list_of_points(record, precision=write_precision, **kwargs)
+            _result = self._serialize(_data, write_precision, **kwargs)
 
         elif isinstance(record, list):
             _result = b'\n'.join([self._serialize(item, write_precision,
@@ -297,8 +298,12 @@ def _write_batching(self, bucket, org, data,
 
         return None
 
+    def _itertuples(self, data_frame):
+        cols = [data_frame.iloc[:, k] for k in range(len(data_frame.columns))]
+        return zip(data_frame.index, *cols)
+
     def _data_frame_to_list_of_points(self, data_frame, precision, **kwargs):
-        from ..extras import pd
+        from ..extras import pd, np
         if not isinstance(data_frame, pd.DataFrame):
             raise TypeError('Must be DataFrame, but type was: {0}.'
                             .format(type(data_frame)))
@@ -314,28 +319,35 @@ def _data_frame_to_list_of_points(self, data_frame, precision, **kwargs):
         if data_frame.index.tzinfo is None:
             data_frame.index = data_frame.index.tz_localize('UTC')
 
-        data = []
+        measurement_name = kwargs.get('data_frame_measurement_name')
+        data_frame_tag_columns = kwargs.get('data_frame_tag_columns')
+        data_frame_tag_columns = set(data_frame_tag_columns or [])
 
-        for c, (row) in enumerate(data_frame.values):
-            point = Point(measurement_name=kwargs.get('data_frame_measurement_name'))
+        tags = []
+        fields = []
 
-            for count, (value) in enumerate(row):
-                column = data_frame.columns[count]
-                data_frame_tag_columns = kwargs.get('data_frame_tag_columns')
-                if data_frame_tag_columns and column in data_frame_tag_columns:
-                    point.tag(column, value)
-                else:
-                    point.field(column, value)
+        if self._point_settings.defaultTags:
+            for key, value in self._point_settings.defaultTags.items():
+                data_frame[key] = value
+                data_frame_tag_columns.add(key)
 
-            point.time(data_frame.index[c], precision)
+        for index, (key, value) in enumerate(data_frame.dtypes.items()):
+            key = str(key).translate(_ESCAPE_KEY)
 
-            if self._point_settings.defaultTags:
-                for key, val in self._point_settings.defaultTags.items():
-                    point.tag(key, val)
+            if key in data_frame_tag_columns:
+                tags.append(f"{key}={{p[{index + 1}].translate(_ESCAPE_KEY)}}")
+            elif issubclass(value.type, np.integer):
+                fields.append(f"{key}={{p[{index + 1}]}}i")
+            elif issubclass(value.type, (np.float, np.bool_)):
+                fields.append(f"{key}={{p[{index + 1}]}}")
+            else:
+                fields.append(f"{key}=\"{{p[{index + 1}].translate(_ESCAPE_KEY)}}\"")
 
-            data.append(point)
+        fmt = (f'{measurement_name}', f'{"," if tags else ""}', ','.join(tags),
+               ' ', ','.join(fields), ' {p[0].value}')
+        f = eval("lambda p: f'{}'".format(''.join(fmt)))
 
-        return data
+        return list(map(f, self._itertuples(data_frame)))
 
     def _http(self, batch_item: _BatchItem):
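
Note on the approach: instead of building a Point per row (and a method call per cell), the new code inspects data_frame.dtypes once, assembles a line-protocol f-string template with positional row accessors (p[0] is the index timestamp, p[i] the i-th column), compiles it with eval into a lambda, and maps that lambda over _itertuples(data_frame). Below is a minimal standalone sketch of the same technique; ESCAPE_KEY, data_frame_to_lines and the pd.api.types dtype checks are illustrative stand-ins, and default-tag handling is omitted:

import pandas as pd

# Simplified stand-in for the client's _ESCAPE_KEY translation table:
# line protocol backslash-escapes commas, spaces and equals signs in keys.
ESCAPE_KEY = str.maketrans({',': r'\,', ' ': r'\ ', '=': r'\='})


def data_frame_to_lines(data_frame, measurement, tag_columns=()):
    tags, fields = [], []

    # Inspect dtypes once and build a line-protocol template, not per-row Points.
    for index, (key, dtype) in enumerate(data_frame.dtypes.items()):
        key = str(key).translate(ESCAPE_KEY)
        if key in tag_columns:
            tags.append(f"{key}={{p[{index + 1}]}}")
        elif pd.api.types.is_integer_dtype(dtype):
            fields.append(f"{key}={{p[{index + 1}]}}i")
        elif pd.api.types.is_float_dtype(dtype) or pd.api.types.is_bool_dtype(dtype):
            fields.append(f"{key}={{p[{index + 1}]}}")
        else:
            fields.append(f'{key}="{{p[{index + 1}]}}"')

    sep = ',' if tags else ''
    template = f"{measurement}{sep}{','.join(tags)} {','.join(fields)} {{p[0].value}}"

    # Compile the template into one f-string lambda; evaluating it per row is cheap.
    line = eval("lambda p: f'{}'".format(template))

    # Same row iteration as the commit's _itertuples helper: (index, col1, col2, ...).
    cols = [data_frame.iloc[:, k] for k in range(len(data_frame.columns))]
    return [line(p) for p in zip(data_frame.index, *cols)]


df = pd.DataFrame({"location": ["coyote_creek"], "water_level": [1.5]},
                  index=pd.to_datetime(["1970-01-01 01:00+00:00"]))
print(data_frame_to_lines(df, "h2o_feet", tag_columns={"location"}))
# ['h2o_feet,location=coyote_creek water_level=1.5 3600000000000']

The eval call is what buys the speed: one compiled format expression instead of millions of Point.tag/Point.field calls. The trade-off is that column names are interpolated into executable code, so they must come from a trusted source; the escape translation at template-build time is the guard for key characters.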

influxdb_client/extras.py

Lines changed: 6 additions & 1 deletion
@@ -3,4 +3,9 @@
 except ModuleNotFoundError as err:
     raise ImportError(f"`query_data_frame` requires Pandas which couldn't be imported due: {err}")
 
-__all__ = ['pd']
+try:
+    import numpy as np
+except ModuleNotFoundError as err:
+    raise ImportError(f"`data_frame` requires numpy which couldn't be imported due: {err}")
+
+__all__ = ['pd', 'np']
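
With numpy now guarded alongside pandas, callers import both lazily through this module, so users without the optional dependencies can still use the rest of the client. A small illustration of the failure mode (a hypothetical script, not part of this commit):

try:
    # Raises ImportError with an actionable message when pandas or numpy
    # (both listed in extra-requirements.txt) are missing.
    from influxdb_client.extras import pd, np
except ImportError as err:
    print(f"DataFrame support unavailable: {err}")
else:
    print("pandas", pd.__version__, "/ numpy", np.__version__)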

tests/test_WriteApi.py

Lines changed: 4 additions & 5 deletions
@@ -5,7 +5,6 @@
 import datetime
 import os
 import unittest
-import time
 from datetime import timedelta
 from multiprocessing.pool import ApplyResult
 
@@ -231,9 +230,9 @@ def test_write_data_frame(self):
         bucket = self.create_test_bucket()
 
         now = pd.Timestamp('1970-01-01 00:00+00:00')
-        data_frame = pd.DataFrame(data=[["coyote_creek", 1.0], ["coyote_creek", 2.0]],
+        data_frame = pd.DataFrame(data=[["coyote_creek", 1], ["coyote_creek", 2]],
                                   index=[now + timedelta(hours=1), now + timedelta(hours=2)],
-                                  columns=["location", "water_level"])
+                                  columns=["location", "water water_level"])
 
         self.write_client.write(bucket.name, record=data_frame, data_frame_measurement_name='h2o_feet',
                                 data_frame_tag_columns=['location'])
@@ -247,14 +246,14 @@ def test_write_data_frame(self):
         self.assertEqual(result[0].records[0].get_measurement(), "h2o_feet")
         self.assertEqual(result[0].records[0].get_value(), 1.0)
         self.assertEqual(result[0].records[0].values.get("location"), "coyote_creek")
-        self.assertEqual(result[0].records[0].get_field(), "water_level")
+        self.assertEqual(result[0].records[0].get_field(), "water water_level")
         self.assertEqual(result[0].records[0].get_time(),
                          datetime.datetime(1970, 1, 1, 1, 0, tzinfo=datetime.timezone.utc))
 
         self.assertEqual(result[0].records[1].get_measurement(), "h2o_feet")
         self.assertEqual(result[0].records[1].get_value(), 2.0)
         self.assertEqual(result[0].records[1].values.get("location"), "coyote_creek")
-        self.assertEqual(result[0].records[1].get_field(), "water_level")
+        self.assertEqual(result[0].records[1].get_field(), "water water_level")
         self.assertEqual(result[0].records[1].get_time(),
                          datetime.datetime(1970, 1, 1, 2, 0, tzinfo=datetime.timezone.utc))
 
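
The column rename from "water_level" to "water water_level" is deliberate: it exercises the new _ESCAPE_KEY handling, since line protocol requires spaces (as well as commas and equals signs) in keys to be backslash-escaped. A sketch of what that translation does; the table contents here are an assumption modeled on the line-protocol rules, the real one lives in influxdb_client/client/write/point.py:

# Assumed contents of the _ESCAPE_KEY table: line protocol backslash-escapes
# commas, spaces and equals signs in measurement/tag/field keys.
ESCAPE_KEY = str.maketrans({',': r'\,', ' ': r'\ ', '=': r'\='})

print("water water_level".translate(ESCAPE_KEY))  # -> water\ water_level

On the wire the field key goes out as water\ water_level; the assertions still compare against the unescaped name because the query result returns the logical key.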

tests/test_WriteApiDataFrame.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+import csv
+import os
+import time
+import unittest
+from datetime import timedelta
+
+from influxdb_client import InfluxDBClient, WriteOptions, WriteApi
+from influxdb_client.client.write_api import SYNCHRONOUS
+from tests.base_test import BaseTest
+
+
+class DataFrameWriteTest(BaseTest):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.influxDb_client = InfluxDBClient(url="http://localhost:9999", token="my-token", debug=False)
+
+        self.write_options = WriteOptions(batch_size=10_000, flush_interval=5_000, retry_interval=3_000)
+        self._write_client = WriteApi(influxdb_client=self.influxDb_client, write_options=self.write_options)
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        self._write_client.__del__()
+
+    @unittest.skip('Test big file')
+    def test_write_data_frame(self):
+        import random
+        from influxdb_client.extras import pd
+
+        if not os.path.isfile("data_frame_file.csv"):
+            with open('data_frame_file.csv', mode='w+') as csv_file:
+                _writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
+                _writer.writerow(['time', 'col1', 'col2', 'col3', 'col4', 'col5', 'col6', 'col7', 'col8'])
+
+                for i in range(1, 1500000):
+                    choice = ['test_a', 'test_b', 'test_c']
+                    _writer.writerow([i, random.choice(choice), 'test', random.random(), random.random(),
+                                      random.random(), random.random(), random.random(), random.random()])
+
+            csv_file.close()
+
+        with open('data_frame_file.csv', mode='rb') as csv_file:
+
+            data_frame = pd.read_csv(csv_file, index_col='time')
+            print(data_frame)
+
+            print('Writing...')
+
+            start = time.time()
+
+            self._write_client.write("my-bucket", "my-org", record=data_frame,
+                                     data_frame_measurement_name='h2o_feet',
+                                     data_frame_tag_columns=['location'])
+
+            self._write_client.__del__()
+
+            print("Time elapsed: ", (time.time() - start))
+
+        csv_file.close()
+
+    def test_write_num_py(self):
+        from influxdb_client.extras import pd, np
+
+        bucket = self.create_test_bucket()
+
+        now = pd.Timestamp('2020-04-05 00:00+00:00')
+
+        data_frame = pd.DataFrame(data=[["coyote_creek", np.int64(100.5)], ["coyote_creek", np.int64(200)]],
+                                  index=[now + timedelta(hours=1), now + timedelta(hours=2)],
+                                  columns=["location", "water_level"])
+
+        write_api = self.client.write_api(write_options=SYNCHRONOUS)
+        write_api.write(bucket.name, record=data_frame, data_frame_measurement_name='h2o_feet',
+                        data_frame_tag_columns=['location'])
+
+        write_api.__del__()
+
+        result = self.query_api.query(
+            "from(bucket:\"" + bucket.name + "\") |> range(start: 1970-01-01T00:00:00.000000001Z)",
+            self.my_organization.id)
+
+        self.assertEqual(1, len(result))
+        self.assertEqual(2, len(result[0].records))
+
+        self.assertEqual(result[0].records[0].get_value(), 100.0)
+        self.assertEqual(result[0].records[1].get_value(), 200.0)
+
+        pass
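
test_write_num_py pins down the dtype-driven branch of the serializer: np.int64(100.5) truncates to 100 at construction, the column dtype becomes int64, so the value is written as an i-suffixed integer field and read back as 100.0. A short illustration, with the expected wire format shown as a comment:

import numpy as np
import pandas as pd

# np.int64 truncates at construction time, so the DataFrame never holds 100.5:
assert np.int64(100.5) == 100

# The int64 column dtype is what makes the serializer emit an 'i' suffix:
df = pd.DataFrame({"water_level": [np.int64(100.5), np.int64(200)]})
assert pd.api.types.is_integer_dtype(df.dtypes["water_level"])

# Expected wire format for the first row of the test's frame (timestamp elided):
#   h2o_feet,location=coyote_creek water_level=100i <ns-timestamp>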
