
Commit 212cee3

feat: support for writing pandas DataFrame (#79)

1 parent: ab3915f

2 files changed (+118 -14 lines)

2 files changed

+118
-14
lines changed

influxdb_client/client/write_api.py (+82 -14)
```diff
@@ -183,21 +183,24 @@ def __init__(self, influxdb_client, write_options: WriteOptions = WriteOptions()
     def write(self, bucket: str, org: str = None,
               record: Union[
                   str, List['str'], Point, List['Point'], dict, List['dict'], bytes, List['bytes'], Observable] = None,
-              write_precision: WritePrecision = DEFAULT_WRITE_PRECISION) -> None:
+              write_precision: WritePrecision = DEFAULT_WRITE_PRECISION, data_frame_measurement_name: str = None,
+              data_frame_tag_columns: List['str'] = None) -> None:
         """
         Writes time-series data into InfluxDB.

         :param str org: specifies the destination organization for writes; takes either the ID or the name interchangeably; if both orgID and org are specified, org takes precedence. (required)
         :param str bucket: specifies the destination bucket for writes (required)
         :param WritePrecision write_precision: specifies the precision for the unix timestamps within the body line-protocol
-        :param record: Points, line protocol, RxPY Observable to write
+        :param record: Points, line protocol, Pandas DataFrame, RxPY Observable to write
+        :param data_frame_measurement_name: name of the measurement for writing a Pandas DataFrame
+        :param data_frame_tag_columns: list of DataFrame columns which are tags; the remaining columns will be fields

         """

         if org is None:
             org = self._influxdb_client.org

-        if self._point_settings.defaultTags and record:
+        if self._point_settings.defaultTags and record is not None:
             for key, val in self._point_settings.defaultTags.items():
                 if isinstance(record, dict):
                     record.get("tags")[key] = val
```
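Note the switch from `and record` to `record is not None`: a DataFrame's truth value is ambiguous (pandas raises `ValueError` on `bool(df)`), so the explicit `None` check is needed once DataFrames are accepted. With the two new parameters, a DataFrame can be handed straight to `write()`. A minimal usage sketch — the URL, token, org, and bucket are placeholder values, and synchronous write options are chosen only to keep the example self-contained:

```python
import pandas as pd

from influxdb_client import InfluxDBClient
from influxdb_client.client.write_api import WriteOptions, WriteType

# Placeholder connection values; substitute your own instance.
client = InfluxDBClient(url="http://localhost:9999", token="my-token", org="my-org")
write_api = client.write_api(write_options=WriteOptions(write_type=WriteType.synchronous))

# The DataFrame must be indexed by time; each row becomes one point.
data_frame = pd.DataFrame(
    data=[["coyote_creek", 1.0], ["coyote_creek", 2.0]],
    index=pd.to_datetime(["2020-01-01 01:00:00", "2020-01-01 02:00:00"]),
    columns=["location", "water_level"],
)

# 'location' is written as a tag; the remaining column becomes a field.
write_api.write("my-bucket", "my-org", record=data_frame,
                data_frame_measurement_name="h2o_feet",
                data_frame_tag_columns=["location"])
```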
```diff
@@ -211,7 +214,9 @@ def write(self, bucket: str, org: str = None,
         if self._write_options.write_type is WriteType.batching:
             return self._write_batching(bucket, org, record, write_precision)

-        final_string = self._serialize(record, write_precision)
+        final_string = self._serialize(record, write_precision,
+                                       data_frame_measurement_name,
+                                       data_frame_tag_columns)

         _async_req = True if self._write_options.write_type == WriteType.asynchronous else False
```
```diff
@@ -235,7 +240,7 @@ def __del__(self):
         self._disposable = None
         pass

-    def _serialize(self, record, write_precision) -> bytes:
+    def _serialize(self, record, write_precision, data_frame_measurement_name, data_frame_tag_columns) -> bytes:
         _result = b''
         if isinstance(record, bytes):
             _result = record
```
```diff
@@ -244,40 +249,103 @@ def _serialize(self, record, write_precision) -> bytes:
             _result = record.encode("utf-8")

         elif isinstance(record, Point):
-            _result = self._serialize(record.to_line_protocol(), write_precision=write_precision)
+            _result = self._serialize(record.to_line_protocol(), write_precision,
+                                      data_frame_measurement_name, data_frame_tag_columns)

         elif isinstance(record, dict):
             _result = self._serialize(Point.from_dict(record, write_precision=write_precision),
-                                      write_precision=write_precision)
+                                      write_precision,
+                                      data_frame_measurement_name, data_frame_tag_columns)
+
+        elif 'DataFrame' in type(record).__name__:
+            _result = self._serialize(self._data_frame_to_list_of_points(record, data_frame_measurement_name,
+                                                                         data_frame_tag_columns,
+                                                                         precision=write_precision),
+                                      write_precision,
+                                      data_frame_measurement_name, data_frame_tag_columns)
+
         elif isinstance(record, list):
-            _result = b'\n'.join([self._serialize(item, write_precision=write_precision) for item in record])
+            _result = b'\n'.join([self._serialize(item, write_precision,
+                                                  data_frame_measurement_name, data_frame_tag_columns) for item in record])

         return _result
```
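The new `_serialize` branch dispatches on the class name, `'DataFrame' in type(record).__name__`, rather than an `isinstance` check. This keeps pandas an optional dependency: it is only imported inside `_data_frame_to_list_of_points`, never just to decide that a record is *not* a DataFrame. A standalone sketch of the idea (names are illustrative, not the library's code):

```python
def looks_like_data_frame(record) -> bool:
    # Checking the class name means pandas never has to be imported
    # just to classify a record that isn't a DataFrame.
    return 'DataFrame' in type(record).__name__

class DataFrame:  # stand-in for pandas.DataFrame
    pass

assert looks_like_data_frame(DataFrame())
assert not looks_like_data_frame("h2o_feet water_level=1.0 1")
```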

```diff
-    def _write_batching(self, bucket, org, data, precision=DEFAULT_WRITE_PRECISION):
+    def _write_batching(self, bucket, org, data,
+                        data_frame_measurement_name, data_frame_tag_columns,
+                        precision=DEFAULT_WRITE_PRECISION):
         _key = _BatchItemKey(bucket, org, precision)
         if isinstance(data, bytes):
             self._subject.on_next(_BatchItem(key=_key, data=data))

         elif isinstance(data, str):
-            self._write_batching(bucket, org, data.encode("utf-8"), precision)
+            self._write_batching(bucket, org, data.encode("utf-8"),
+                                 data_frame_measurement_name, data_frame_tag_columns, precision)

         elif isinstance(data, Point):
-            self._write_batching(bucket, org, data.to_line_protocol(), precision)
+            self._write_batching(bucket, org, data.to_line_protocol(),
+                                 data_frame_measurement_name, data_frame_tag_columns, precision)

         elif isinstance(data, dict):
-            self._write_batching(bucket, org, Point.from_dict(data, write_precision=precision), precision)
+            self._write_batching(bucket, org, Point.from_dict(data, write_precision=precision),
+                                 data_frame_measurement_name, data_frame_tag_columns, precision)
+
+        elif 'DataFrame' in type(data).__name__:
+            self._write_batching(bucket, org, self._data_frame_to_list_of_points(data, data_frame_measurement_name,
+                                                                                 data_frame_tag_columns, precision),
+                                 data_frame_measurement_name, data_frame_tag_columns, precision)

         elif isinstance(data, list):
             for item in data:
-                self._write_batching(bucket, org, item, precision)
+                self._write_batching(bucket, org, item,
+                                     data_frame_measurement_name, data_frame_tag_columns, precision)

         elif isinstance(data, Observable):
-            data.subscribe(lambda it: self._write_batching(bucket, org, it, precision))
+            data.subscribe(lambda it: self._write_batching(bucket, org, it,
+                                                           data_frame_measurement_name, data_frame_tag_columns,
+                                                           precision))
             pass

         return None
```
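`_write_batching` follows the same normalize-by-recursion pattern as `_serialize`: every branch reduces its input one step (str to bytes, Point to line protocol, DataFrame to a list of Points) and re-enters, so the two DataFrame parameters must be threaded through each recursive call. A stripped-down illustration of the pattern (a hypothetical helper, not the library's API):

```python
def normalize(data, measurement_name=None):
    """Recursively reduce strings, bytes, and lists to a flat list of byte records."""
    if isinstance(data, bytes):
        return [data]
    if isinstance(data, str):
        # Re-enter with bytes; extra parameters ride along unchanged.
        return normalize(data.encode("utf-8"), measurement_name)
    if isinstance(data, list):
        out = []
        for item in data:
            out.extend(normalize(item, measurement_name))
        return out
    raise TypeError("unsupported record type: {0}".format(type(data)))

assert normalize(["a f=1 1", b"b f=2 2"]) == [b"a f=1 1", b"b f=2 2"]
```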

```diff
+    def _data_frame_to_list_of_points(self, dataframe, data_frame_measurement_name, data_frame_tag_columns, precision='s'):
+        from ..extras import pd
+        if not isinstance(dataframe, pd.DataFrame):
+            raise TypeError('Must be DataFrame, but type was: {0}.'
+                            .format(type(dataframe)))
+        if not (isinstance(dataframe.index, pd.PeriodIndex) or
+                isinstance(dataframe.index, pd.DatetimeIndex)):
+            raise TypeError('Must be DataFrame with DatetimeIndex or PeriodIndex.')
+
+        if isinstance(dataframe.index, pd.PeriodIndex):
+            dataframe.index = dataframe.index.to_timestamp()
+        else:
+            dataframe.index = pd.to_datetime(dataframe.index)
+
+        if dataframe.index.tzinfo is None:
+            dataframe.index = dataframe.index.tz_localize('UTC')
+
+        data = []
+
+        c = 0
+        for v in dataframe.values:
+            point = Point(measurement_name=data_frame_measurement_name)
+
+            count = 0
+            for f in v:
+                column = dataframe.columns[count]
+                if data_frame_tag_columns and column in data_frame_tag_columns:
+                    point.tag(column, f)
+                else:
+                    point.field(column, f)
+                count += 1
+
+            point.time(dataframe.index[c], precision)
+            c += 1
+
+            data.append(point)
+
+        return data
+
     def _http(self, batch_item: _BatchItem):

         logger.debug("Write time series data into InfluxDB: %s", batch_item)
```
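`_data_frame_to_list_of_points` first normalizes the index (a `PeriodIndex` is converted to timestamps, and a timezone-naive index is localized to UTC), then emits one `Point` per row: columns named in `data_frame_tag_columns` become tags, every other column becomes a field, and the index entry supplies the timestamp. A small sketch of the equivalent transformation using the same `Point` API (the expected output in the comment assumes second precision; exact float formatting may vary by client version):

```python
import pandas as pd

from influxdb_client import Point

data_frame = pd.DataFrame(
    data=[["coyote_creek", 1.0]],
    index=pd.to_datetime(["1970-01-01 01:00:00+00:00"]),
    columns=["location", "water_level"],
)

for timestamp, row in data_frame.iterrows():
    point = Point(measurement_name="h2o_feet")
    point.tag("location", row["location"])          # tag column
    point.field("water_level", row["water_level"])  # remaining column -> field
    point.time(timestamp, "s")
    # Expected line protocol, roughly:
    # h2o_feet,location=coyote_creek water_level=1.0 3600
    print(point.to_line_protocol())
```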

tests/test_WriteApi.py (+36)
```diff
@@ -6,6 +6,7 @@
 import os
 import unittest
 import time
+from datetime import timedelta
 from multiprocessing.pool import ApplyResult

 from influxdb_client import Point, WritePrecision, InfluxDBClient
```
```diff
@@ -224,6 +225,41 @@ def test_write_bytes(self):

         self.delete_test_bucket(_bucket)

+    def test_write_data_frame(self):
+        from influxdb_client.extras import pd
+
+        bucket = self.create_test_bucket()
+
+        now = pd.Timestamp('1970-01-01 00:00+00:00')
+        data_frame = pd.DataFrame(data=[["coyote_creek", 1.0], ["coyote_creek", 2.0]],
+                                  index=[now + timedelta(hours=1), now + timedelta(hours=2)],
+                                  columns=["location", "water_level"])
+
+        self.write_client.write(bucket.name, record=data_frame, data_frame_measurement_name='h2o_feet',
+                                data_frame_tag_columns=['location'])
+
+        result = self.query_api.query(
+            "from(bucket:\"" + bucket.name + "\") |> range(start: 1970-01-01T00:00:00.000000001Z)", self.org)
+
+        self.assertEqual(1, len(result))
+        self.assertEqual(2, len(result[0].records))
+
+        self.assertEqual(result[0].records[0].get_measurement(), "h2o_feet")
+        self.assertEqual(result[0].records[0].get_value(), 1.0)
+        self.assertEqual(result[0].records[0].values.get("location"), "coyote_creek")
+        self.assertEqual(result[0].records[0].get_field(), "water_level")
+        self.assertEqual(result[0].records[0].get_time(),
+                         datetime.datetime(1970, 1, 1, 1, 0, tzinfo=datetime.timezone.utc))
+
+        self.assertEqual(result[0].records[1].get_measurement(), "h2o_feet")
+        self.assertEqual(result[0].records[1].get_value(), 2.0)
+        self.assertEqual(result[0].records[1].values.get("location"), "coyote_creek")
+        self.assertEqual(result[0].records[1].get_field(), "water_level")
+        self.assertEqual(result[0].records[1].get_time(),
+                         datetime.datetime(1970, 1, 1, 2, 0, tzinfo=datetime.timezone.utc))
+
+        self.delete_test_bucket(bucket)
+
     def test_use_default_org(self):
         bucket = self.create_test_bucket()
```
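For reference, the test's DataFrame should reach the bucket as roughly the following line protocol (nanosecond precision is the write default; float formatting may differ by client version):

```python
# Approximate line protocol produced from the test's two rows:
expected = (
    "h2o_feet,location=coyote_creek water_level=1.0 3600000000000\n"
    "h2o_feet,location=coyote_creek water_level=2.0 7200000000000"
)
```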
