10
10
import pandas as pd
11
11
from botocore .config import Config
12
12
13
- from awswrangler import _data_types , _utils
13
+ from awswrangler import _data_types , _utils , exceptions
14
14
15
15
_logger : logging .Logger = logging .getLogger (__name__ )
16
16
@@ -27,61 +27,111 @@ def _df2list(df: pd.DataFrame) -> List[List[Any]]:
27
27
return parameters
28
28
29
29
30
+ def _format_timestamp (timestamp : Union [int , datetime ]) -> str :
31
+ if isinstance (timestamp , int ):
32
+ return str (round (timestamp / 1_000_000 ))
33
+ if isinstance (timestamp , datetime ):
34
+ return str (round (timestamp .timestamp () * 1_000 ))
35
+ raise exceptions .InvalidArgumentType ("`time_col` must be of type timestamp." )
36
+
37
+
30
38
def _format_measure (measure_name : str , measure_value : Any , measure_type : str ) -> Dict [str , str ]:
31
39
return {
32
40
"Name" : measure_name ,
33
- "Value" : str ( round ( measure_value . timestamp () * 1_000 ) if measure_type == "TIMESTAMP" else measure_value ),
41
+ "Value" : _format_timestamp ( measure_value ) if measure_type == "TIMESTAMP" else str ( measure_value ),
34
42
"Type" : measure_type ,
35
43
}
36
44
37
45
46
+ def _sanitize_common_attributes (
47
+ common_attributes : Optional [Dict [str , Any ]],
48
+ version : int ,
49
+ measure_name : Optional [str ],
50
+ ) -> Dict [str , Any ]:
51
+ common_attributes = {} if not common_attributes else common_attributes
52
+ # Values in common_attributes take precedence
53
+ common_attributes .setdefault ("Version" , version )
54
+
55
+ if "Time" not in common_attributes :
56
+ # TimeUnit is MILLISECONDS by default for Timestream writes
57
+ # But if a time_col is supplied (i.e. Time is not in common_attributes)
58
+ # then TimeUnit must be set to MILLISECONDS explicitly
59
+ common_attributes ["TimeUnit" ] = "MILLISECONDS"
60
+
61
+ if "MeasureValue" in common_attributes and "MeasureValueType" not in common_attributes :
62
+ raise exceptions .InvalidArgumentCombination (
63
+ "MeasureValueType must be supplied alongside MeasureValue in common_attributes."
64
+ )
65
+
66
+ if measure_name :
67
+ common_attributes .setdefault ("MeasureName" , measure_name )
68
+ elif "MeasureName" not in common_attributes :
69
+ raise exceptions .InvalidArgumentCombination (
70
+ "MeasureName must be supplied with the `measure_name` argument or in common_attributes."
71
+ )
72
+ return common_attributes
73
+
74
+
38
75
def _write_batch (
76
+ timestream_client : boto3 .client ,
39
77
database : str ,
40
78
table : str ,
41
- cols_names : List [str ],
42
- measure_cols_names : List [str ],
79
+ common_attributes : Dict [str , Any ],
80
+ cols_names : List [Optional [str ]],
81
+ measure_cols : List [Optional [str ]],
43
82
measure_types : List [str ],
44
- version : int ,
83
+ dimension_cols : List [ Optional [ str ]] ,
45
84
batch : List [Any ],
46
- timestream_client : boto3 .client ,
47
- measure_name : Optional [str ] = None ,
48
85
) -> List [Dict [str , str ]]:
49
- try :
50
- time_loc = 0
51
- measure_cols_loc = 1
52
- dimensions_cols_loc = 1 + len (measure_cols_names )
53
- records : List [Dict [str , Any ]] = []
54
- for rec in batch :
55
- record : Dict [str , Any ] = {
56
- "Dimensions" : [
57
- {"Name" : name , "DimensionValueType" : "VARCHAR" , "Value" : str (value )}
58
- for name , value in zip (cols_names [dimensions_cols_loc :], rec [dimensions_cols_loc :])
59
- ],
60
- "Time" : str (round (rec [time_loc ].timestamp () * 1_000 )),
61
- "TimeUnit" : "MILLISECONDS" ,
62
- "Version" : version ,
63
- }
64
- if len (measure_cols_names ) == 1 :
65
- measure_value = rec [measure_cols_loc ]
66
- if pd .isnull (measure_value ):
67
- continue
68
- record ["MeasureName" ] = measure_name if measure_name else measure_cols_names [0 ]
69
- record ["MeasureValueType" ] = measure_types [0 ]
70
- record ["MeasureValue" ] = str (measure_value )
71
- else :
72
- record ["MeasureName" ] = measure_name if measure_name else measure_cols_names [0 ]
73
- record ["MeasureValueType" ] = "MULTI"
74
- record ["MeasureValues" ] = [
75
- _format_measure (measure_name , measure_value , measure_value_type )
76
- for measure_name , measure_value , measure_value_type in zip (
77
- measure_cols_names , rec [measure_cols_loc :dimensions_cols_loc ], measure_types
78
- )
79
- if not pd .isnull (measure_value )
80
- ]
81
- if len (record ["MeasureValues" ]) == 0 :
82
- continue
86
+ records : List [Dict [str , Any ]] = []
87
+ scalar = bool (len (measure_cols ) == 1 and "MeasureValues" not in common_attributes )
88
+ time_loc = 0
89
+ measure_cols_loc = 1 if cols_names [0 ] else 0
90
+ dimensions_cols_loc = 1 if len (measure_cols ) == 1 else 1 + len (measure_cols )
91
+ if all (cols_names ):
92
+ # Time and Measures are supplied in the data frame
93
+ dimensions_cols_loc = 1 + len (measure_cols )
94
+ elif all (v is None for v in cols_names [:2 ]):
95
+ # Time and Measures are supplied in common_attributes
96
+ dimensions_cols_loc = 0
97
+
98
+ for row in batch :
99
+ record : Dict [str , Any ] = {}
100
+ if "Time" not in common_attributes :
101
+ record ["Time" ] = _format_timestamp (row [time_loc ])
102
+ if scalar and "MeasureValue" not in common_attributes :
103
+ measure_value = row [measure_cols_loc ]
104
+ if pd .isnull (measure_value ):
105
+ continue
106
+ record ["MeasureValue" ] = str (measure_value )
107
+ elif not scalar and "MeasureValues" not in common_attributes :
108
+ record ["MeasureValues" ] = [
109
+ _format_measure (measure_name , measure_value , measure_value_type ) # type: ignore[arg-type]
110
+ for measure_name , measure_value , measure_value_type in zip (
111
+ measure_cols , row [measure_cols_loc :dimensions_cols_loc ], measure_types
112
+ )
113
+ if not pd .isnull (measure_value )
114
+ ]
115
+ if len (record ["MeasureValues" ]) == 0 :
116
+ continue
117
+ if "MeasureValueType" not in common_attributes :
118
+ record ["MeasureValueType" ] = measure_types [0 ] if scalar else "MULTI"
119
+ # Dimensions can be specified in both common_attributes and the data frame
120
+ dimensions = (
121
+ [
122
+ {"Name" : name , "DimensionValueType" : "VARCHAR" , "Value" : str (value )}
123
+ for name , value in zip (dimension_cols , row [dimensions_cols_loc :])
124
+ ]
125
+ if all (dimension_cols )
126
+ else []
127
+ )
128
+ if dimensions :
129
+ record ["Dimensions" ] = dimensions
130
+ if record :
83
131
records .append (record )
84
- if len (records ) > 0 :
132
+
133
+ try :
134
+ if records :
85
135
_utils .try_it (
86
136
f = timestream_client .write_records ,
87
137
ex = (
@@ -91,6 +141,7 @@ def _write_batch(
91
141
max_num_tries = 5 ,
92
142
DatabaseName = database ,
93
143
TableName = table ,
144
+ CommonAttributes = common_attributes ,
94
145
Records = records ,
95
146
)
96
147
except timestream_client .exceptions .RejectedRecordsException as ex :
@@ -192,12 +243,13 @@ def write(
192
243
df : pd .DataFrame ,
193
244
database : str ,
194
245
table : str ,
195
- time_col : str ,
196
- measure_col : Union [str , List [str ]],
197
- dimensions_cols : List [str ],
246
+ time_col : Optional [ str ] = None ,
247
+ measure_col : Union [str , List [Optional [ str ]], None ] = None ,
248
+ dimensions_cols : Optional [ List [Optional [ str ]]] = None ,
198
249
version : int = 1 ,
199
250
num_threads : int = 32 ,
200
251
measure_name : Optional [str ] = None ,
252
+ common_attributes : Optional [Dict [str , Any ]] = None ,
201
253
boto3_session : Optional [boto3 .Session ] = None ,
202
254
) -> List [Dict [str , str ]]:
203
255
"""Store a Pandas DataFrame into a Amazon Timestream table.
@@ -206,6 +258,16 @@ def write(
206
258
this function will not throw a Python exception.
207
259
Instead it will return the rejection information.
208
260
261
+ Note
262
+ ----
263
+ Values in `common_attributes` take precedence over all other arguments and data frame values.
264
+ Dimension attributes are merged with attributes in record objects.
265
+ Example: common_attributes = {"Dimensions": {"Name": "device_id", "Value": "12345"}, "MeasureValueType": "DOUBLE"}.
266
+
267
+ Note
268
+ ----
269
+ If the `time_col` column is supplied it must be of type timestamp. `TimeUnit` is set to MILLISECONDS by default.
270
+
209
271
Parameters
210
272
----------
211
273
df: pandas.DataFrame
@@ -214,18 +276,21 @@ def write(
214
276
Amazon Timestream database name.
215
277
table : str
216
278
Amazon Timestream table name.
217
- time_col : str
279
+ time_col : Optional[ str]
218
280
DataFrame column name to be used as time. MUST be a timestamp column.
219
- measure_col : Union[str, List[str]]
281
+ measure_col : Union[str, List[str], None ]
220
282
DataFrame column name(s) to be used as measure.
221
- dimensions_cols : List[str]
283
+ dimensions_cols : Optional[ List[str] ]
222
284
List of DataFrame column names to be used as dimensions.
223
285
version : int
224
286
Version number used for upserts.
225
287
Documentation https://docs.aws.amazon.com/timestream/latest/developerguide/API_WriteRecords.html.
226
288
measure_name : Optional[str]
227
289
Name that represents the data attribute of the time series.
228
290
Overrides ``measure_col`` if specified.
291
+ common_attributes : Optional[Dict[str, Any]]
292
+ Dictionary of attributes that is shared across all records in the request.
293
+ Using common attributes can optimize the cost of writes by reducing the size of request payloads.
229
294
num_threads : str
230
295
Number of thread to be used for concurrent writing.
231
296
boto3_session : boto3.Session(), optional
@@ -279,30 +344,50 @@ def write(
279
344
session = boto3_session ,
280
345
botocore_config = Config (read_timeout = 20 , max_pool_connections = 5000 , retries = {"max_attempts" : 10 }),
281
346
)
282
- measure_cols_names = measure_col if isinstance (measure_col , list ) else [measure_col ]
283
- _logger .debug ("measure_cols_names: %s" , measure_cols_names )
284
347
285
- measure_types : List [str ] = [
286
- _data_types .timestream_type_from_pandas (df [[measure_col_name ]]) for measure_col_name in measure_cols_names
348
+ measure_cols = measure_col if isinstance (measure_col , list ) else [measure_col ]
349
+ measure_types = [
350
+ _data_types .timestream_type_from_pandas (df [[measure_col_name ]])
351
+ for measure_col_name in measure_cols
352
+ if measure_col_name
287
353
]
288
- _logger .debug ("measure_types: %s" , measure_types )
289
- cols_names : List [str ] = [time_col ] + measure_cols_names + dimensions_cols
290
- _logger .debug ("cols_names: %s" , cols_names )
291
- batches : List [List [Any ]] = _utils .chunkify (lst = _df2list (df = df [cols_names ]), max_length = 100 )
292
- _logger .debug ("len(batches): %s" , len (batches ))
354
+ dimensions_cols = dimensions_cols if dimensions_cols else [dimensions_cols ] # type: ignore[list-item]
355
+ cols_names : List [Optional [str ]] = [time_col ] + measure_cols + dimensions_cols
356
+ measure_name = measure_name if measure_name else measure_cols [0 ]
357
+ common_attributes = _sanitize_common_attributes (common_attributes , version , measure_name )
358
+
359
+ _logger .debug (
360
+ "common_attributes: %s\n , cols_names: %s\n , measure_types: %s" ,
361
+ common_attributes ,
362
+ cols_names ,
363
+ measure_types ,
364
+ )
365
+
366
+ # User can supply arguments in one of two ways:
367
+ # 1. With the `common_attributes` dictionary which takes precedence
368
+ # 2. With data frame columns
369
+ # However, the data frame cannot be completely empty.
370
+ # So if all values in `cols_names` are None, an exception is raised.
371
+ if any (cols_names ):
372
+ batches : List [List [Any ]] = _utils .chunkify (lst = _df2list (df = df [[c for c in cols_names if c ]]), max_length = 100 )
373
+ else :
374
+ raise exceptions .InvalidArgumentCombination (
375
+ "At least one of `time_col`, `measure_col` or `dimensions_cols` must be specified."
376
+ )
377
+
293
378
with concurrent .futures .ThreadPoolExecutor (max_workers = num_threads ) as executor :
294
379
res : List [List [Any ]] = list (
295
380
executor .map (
296
381
_write_batch ,
382
+ itertools .repeat (timestream_client ),
297
383
itertools .repeat (database ),
298
384
itertools .repeat (table ),
385
+ itertools .repeat (common_attributes ),
299
386
itertools .repeat (cols_names ),
300
- itertools .repeat (measure_cols_names ),
387
+ itertools .repeat (measure_cols ),
301
388
itertools .repeat (measure_types ),
302
- itertools .repeat (version ),
389
+ itertools .repeat (dimensions_cols ),
303
390
batches ,
304
- itertools .repeat (timestream_client ),
305
- itertools .repeat (measure_name ),
306
391
)
307
392
)
308
393
return [item for sublist in res for item in sublist ]
0 commit comments