Skip to content

Commit fb1d228

Browse files
authored
enhancement: Add timestream common attributes (#2091)
* enhancement: Add timestream common attributes
1 parent 94b076c commit fb1d228

File tree

2 files changed

+304
-62
lines changed

2 files changed

+304
-62
lines changed

awswrangler/timestream.py

Lines changed: 146 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pandas as pd
1111
from botocore.config import Config
1212

13-
from awswrangler import _data_types, _utils
13+
from awswrangler import _data_types, _utils, exceptions
1414

1515
_logger: logging.Logger = logging.getLogger(__name__)
1616

@@ -27,61 +27,111 @@ def _df2list(df: pd.DataFrame) -> List[List[Any]]:
2727
return parameters
2828

2929

30+
def _format_timestamp(timestamp: Union[int, datetime]) -> str:
31+
if isinstance(timestamp, int):
32+
return str(round(timestamp / 1_000_000))
33+
if isinstance(timestamp, datetime):
34+
return str(round(timestamp.timestamp() * 1_000))
35+
raise exceptions.InvalidArgumentType("`time_col` must be of type timestamp.")
36+
37+
3038
def _format_measure(measure_name: str, measure_value: Any, measure_type: str) -> Dict[str, str]:
3139
return {
3240
"Name": measure_name,
33-
"Value": str(round(measure_value.timestamp() * 1_000) if measure_type == "TIMESTAMP" else measure_value),
41+
"Value": _format_timestamp(measure_value) if measure_type == "TIMESTAMP" else str(measure_value),
3442
"Type": measure_type,
3543
}
3644

3745

46+
def _sanitize_common_attributes(
47+
common_attributes: Optional[Dict[str, Any]],
48+
version: int,
49+
measure_name: Optional[str],
50+
) -> Dict[str, Any]:
51+
common_attributes = {} if not common_attributes else common_attributes
52+
# Values in common_attributes take precedence
53+
common_attributes.setdefault("Version", version)
54+
55+
if "Time" not in common_attributes:
56+
# TimeUnit is MILLISECONDS by default for Timestream writes
57+
# But if a time_col is supplied (i.e. Time is not in common_attributes)
58+
# then TimeUnit must be set to MILLISECONDS explicitly
59+
common_attributes["TimeUnit"] = "MILLISECONDS"
60+
61+
if "MeasureValue" in common_attributes and "MeasureValueType" not in common_attributes:
62+
raise exceptions.InvalidArgumentCombination(
63+
"MeasureValueType must be supplied alongside MeasureValue in common_attributes."
64+
)
65+
66+
if measure_name:
67+
common_attributes.setdefault("MeasureName", measure_name)
68+
elif "MeasureName" not in common_attributes:
69+
raise exceptions.InvalidArgumentCombination(
70+
"MeasureName must be supplied with the `measure_name` argument or in common_attributes."
71+
)
72+
return common_attributes
73+
74+
3875
def _write_batch(
76+
timestream_client: boto3.client,
3977
database: str,
4078
table: str,
41-
cols_names: List[str],
42-
measure_cols_names: List[str],
79+
common_attributes: Dict[str, Any],
80+
cols_names: List[Optional[str]],
81+
measure_cols: List[Optional[str]],
4382
measure_types: List[str],
44-
version: int,
83+
dimension_cols: List[Optional[str]],
4584
batch: List[Any],
46-
timestream_client: boto3.client,
47-
measure_name: Optional[str] = None,
4885
) -> List[Dict[str, str]]:
49-
try:
50-
time_loc = 0
51-
measure_cols_loc = 1
52-
dimensions_cols_loc = 1 + len(measure_cols_names)
53-
records: List[Dict[str, Any]] = []
54-
for rec in batch:
55-
record: Dict[str, Any] = {
56-
"Dimensions": [
57-
{"Name": name, "DimensionValueType": "VARCHAR", "Value": str(value)}
58-
for name, value in zip(cols_names[dimensions_cols_loc:], rec[dimensions_cols_loc:])
59-
],
60-
"Time": str(round(rec[time_loc].timestamp() * 1_000)),
61-
"TimeUnit": "MILLISECONDS",
62-
"Version": version,
63-
}
64-
if len(measure_cols_names) == 1:
65-
measure_value = rec[measure_cols_loc]
66-
if pd.isnull(measure_value):
67-
continue
68-
record["MeasureName"] = measure_name if measure_name else measure_cols_names[0]
69-
record["MeasureValueType"] = measure_types[0]
70-
record["MeasureValue"] = str(measure_value)
71-
else:
72-
record["MeasureName"] = measure_name if measure_name else measure_cols_names[0]
73-
record["MeasureValueType"] = "MULTI"
74-
record["MeasureValues"] = [
75-
_format_measure(measure_name, measure_value, measure_value_type)
76-
for measure_name, measure_value, measure_value_type in zip(
77-
measure_cols_names, rec[measure_cols_loc:dimensions_cols_loc], measure_types
78-
)
79-
if not pd.isnull(measure_value)
80-
]
81-
if len(record["MeasureValues"]) == 0:
82-
continue
86+
records: List[Dict[str, Any]] = []
87+
scalar = bool(len(measure_cols) == 1 and "MeasureValues" not in common_attributes)
88+
time_loc = 0
89+
measure_cols_loc = 1 if cols_names[0] else 0
90+
dimensions_cols_loc = 1 if len(measure_cols) == 1 else 1 + len(measure_cols)
91+
if all(cols_names):
92+
# Time and Measures are supplied in the data frame
93+
dimensions_cols_loc = 1 + len(measure_cols)
94+
elif all(v is None for v in cols_names[:2]):
95+
# Time and Measures are supplied in common_attributes
96+
dimensions_cols_loc = 0
97+
98+
for row in batch:
99+
record: Dict[str, Any] = {}
100+
if "Time" not in common_attributes:
101+
record["Time"] = _format_timestamp(row[time_loc])
102+
if scalar and "MeasureValue" not in common_attributes:
103+
measure_value = row[measure_cols_loc]
104+
if pd.isnull(measure_value):
105+
continue
106+
record["MeasureValue"] = str(measure_value)
107+
elif not scalar and "MeasureValues" not in common_attributes:
108+
record["MeasureValues"] = [
109+
_format_measure(measure_name, measure_value, measure_value_type) # type: ignore[arg-type]
110+
for measure_name, measure_value, measure_value_type in zip(
111+
measure_cols, row[measure_cols_loc:dimensions_cols_loc], measure_types
112+
)
113+
if not pd.isnull(measure_value)
114+
]
115+
if len(record["MeasureValues"]) == 0:
116+
continue
117+
if "MeasureValueType" not in common_attributes:
118+
record["MeasureValueType"] = measure_types[0] if scalar else "MULTI"
119+
# Dimensions can be specified in both common_attributes and the data frame
120+
dimensions = (
121+
[
122+
{"Name": name, "DimensionValueType": "VARCHAR", "Value": str(value)}
123+
for name, value in zip(dimension_cols, row[dimensions_cols_loc:])
124+
]
125+
if all(dimension_cols)
126+
else []
127+
)
128+
if dimensions:
129+
record["Dimensions"] = dimensions
130+
if record:
83131
records.append(record)
84-
if len(records) > 0:
132+
133+
try:
134+
if records:
85135
_utils.try_it(
86136
f=timestream_client.write_records,
87137
ex=(
@@ -91,6 +141,7 @@ def _write_batch(
91141
max_num_tries=5,
92142
DatabaseName=database,
93143
TableName=table,
144+
CommonAttributes=common_attributes,
94145
Records=records,
95146
)
96147
except timestream_client.exceptions.RejectedRecordsException as ex:
@@ -192,12 +243,13 @@ def write(
192243
df: pd.DataFrame,
193244
database: str,
194245
table: str,
195-
time_col: str,
196-
measure_col: Union[str, List[str]],
197-
dimensions_cols: List[str],
246+
time_col: Optional[str] = None,
247+
measure_col: Union[str, List[Optional[str]], None] = None,
248+
dimensions_cols: Optional[List[Optional[str]]] = None,
198249
version: int = 1,
199250
num_threads: int = 32,
200251
measure_name: Optional[str] = None,
252+
common_attributes: Optional[Dict[str, Any]] = None,
201253
boto3_session: Optional[boto3.Session] = None,
202254
) -> List[Dict[str, str]]:
203255
"""Store a Pandas DataFrame into an Amazon Timestream table.
@@ -206,6 +258,16 @@ def write(
206258
this function will not throw a Python exception.
207259
Instead it will return the rejection information.
208260
261+
Note
262+
----
263+
Values in `common_attributes` take precedence over all other arguments and data frame values.
264+
Dimension attributes are merged with attributes in record objects.
265+
Example: common_attributes = {"Dimensions": [{"Name": "device_id", "Value": "12345"}], "MeasureValueType": "DOUBLE"}.
266+
267+
Note
268+
----
269+
If the `time_col` column is supplied it must be of type timestamp. `TimeUnit` is set to MILLISECONDS by default.
270+
209271
Parameters
210272
----------
211273
df: pandas.DataFrame
@@ -214,18 +276,21 @@ def write(
214276
Amazon Timestream database name.
215277
table : str
216278
Amazon Timestream table name.
217-
time_col : str
279+
time_col : Optional[str]
218280
DataFrame column name to be used as time. MUST be a timestamp column.
219-
measure_col : Union[str, List[str]]
281+
measure_col : Union[str, List[str], None]
220282
DataFrame column name(s) to be used as measure.
221-
dimensions_cols : List[str]
283+
dimensions_cols : Optional[List[str]]
222284
List of DataFrame column names to be used as dimensions.
223285
version : int
224286
Version number used for upserts.
225287
Documentation https://docs.aws.amazon.com/timestream/latest/developerguide/API_WriteRecords.html.
226288
measure_name : Optional[str]
227289
Name that represents the data attribute of the time series.
228290
Overrides ``measure_col`` if specified.
291+
common_attributes : Optional[Dict[str, Any]]
292+
Dictionary of attributes that is shared across all records in the request.
293+
Using common attributes can optimize the cost of writes by reducing the size of request payloads.
229294
num_threads : int
230295
Number of threads to be used for concurrent writing.
231296
boto3_session : boto3.Session(), optional
@@ -279,30 +344,50 @@ def write(
279344
session=boto3_session,
280345
botocore_config=Config(read_timeout=20, max_pool_connections=5000, retries={"max_attempts": 10}),
281346
)
282-
measure_cols_names = measure_col if isinstance(measure_col, list) else [measure_col]
283-
_logger.debug("measure_cols_names: %s", measure_cols_names)
284347

285-
measure_types: List[str] = [
286-
_data_types.timestream_type_from_pandas(df[[measure_col_name]]) for measure_col_name in measure_cols_names
348+
measure_cols = measure_col if isinstance(measure_col, list) else [measure_col]
349+
measure_types = [
350+
_data_types.timestream_type_from_pandas(df[[measure_col_name]])
351+
for measure_col_name in measure_cols
352+
if measure_col_name
287353
]
288-
_logger.debug("measure_types: %s", measure_types)
289-
cols_names: List[str] = [time_col] + measure_cols_names + dimensions_cols
290-
_logger.debug("cols_names: %s", cols_names)
291-
batches: List[List[Any]] = _utils.chunkify(lst=_df2list(df=df[cols_names]), max_length=100)
292-
_logger.debug("len(batches): %s", len(batches))
354+
dimensions_cols = dimensions_cols if dimensions_cols else [dimensions_cols] # type: ignore[list-item]
355+
cols_names: List[Optional[str]] = [time_col] + measure_cols + dimensions_cols
356+
measure_name = measure_name if measure_name else measure_cols[0]
357+
common_attributes = _sanitize_common_attributes(common_attributes, version, measure_name)
358+
359+
_logger.debug(
360+
"common_attributes: %s\n, cols_names: %s\n, measure_types: %s",
361+
common_attributes,
362+
cols_names,
363+
measure_types,
364+
)
365+
366+
# User can supply arguments in one of two ways:
367+
# 1. With the `common_attributes` dictionary which takes precedence
368+
# 2. With data frame columns
369+
# However, the data frame cannot be completely empty.
370+
# So if all values in `cols_names` are None, an exception is raised.
371+
if any(cols_names):
372+
batches: List[List[Any]] = _utils.chunkify(lst=_df2list(df=df[[c for c in cols_names if c]]), max_length=100)
373+
else:
374+
raise exceptions.InvalidArgumentCombination(
375+
"At least one of `time_col`, `measure_col` or `dimensions_cols` must be specified."
376+
)
377+
293378
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
294379
res: List[List[Any]] = list(
295380
executor.map(
296381
_write_batch,
382+
itertools.repeat(timestream_client),
297383
itertools.repeat(database),
298384
itertools.repeat(table),
385+
itertools.repeat(common_attributes),
299386
itertools.repeat(cols_names),
300-
itertools.repeat(measure_cols_names),
387+
itertools.repeat(measure_cols),
301388
itertools.repeat(measure_types),
302-
itertools.repeat(version),
389+
itertools.repeat(dimensions_cols),
303390
batches,
304-
itertools.repeat(timestream_client),
305-
itertools.repeat(measure_name),
306391
)
307392
)
308393
return [item for sublist in res for item in sublist]

0 commit comments

Comments
 (0)