
Commit e920f08

feat(kafka): New Kafka utility (#6821)
Squashed commit messages:

* Adding support for Kafka Consumer - first commit
* Adding exports
* Refactoring functions
* Adding docstring
* Fix mypy stuff
* Adding protobuf tests
* Adding json tests
* Adding docs
* Internal refactoring
* Cleaning up the PR
* Renaming namespace
* Refactoring tests
* Make mypy happy
1 parent 3a9a0e8 commit e920f08

39 files changed: +2768, -427 lines

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 1 addition & 1 deletion
@@ -407,7 +407,7 @@ def __init__(
 
         # OpenAPI spec only understands paths with { }. So we'll have to convert Powertools' < >.
         # https://swagger.io/specification/#path-templating
-        self.openapi_path = re.sub(r"<(.*?)>", lambda m: f"{{{''.join(m.group(1))}}}", self.path)
+        self.openapi_path = re.sub(r"<(.*?)>", lambda m: f"{{{''.join(m.group(1))}}}", self.path)  # type: ignore[arg-type]
 
         self.rule = rule
         self.func = func
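For reference, the path-templating conversion described in the comment above can be exercised in isolation. This is a minimal sketch reusing the same regular expression on a hypothetical route; it is not part of this commit.

import re

# Convert Powertools' "<param>" path syntax into OpenAPI's "{param}" syntax,
# using the same substitution as Route.__init__ above.
path = "/todos/<todo_id>"
openapi_path = re.sub(r"<(.*?)>", lambda m: f"{{{''.join(m.group(1))}}}", path)
print(openapi_path)  # /todos/{todo_id}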

aws_lambda_powertools/utilities/data_classes/kafka_event.py

Lines changed: 45 additions & 13 deletions
@@ -10,7 +10,19 @@
     from collections.abc import Iterator
 
 
-class KafkaEventRecord(DictWrapper):
+class KafkaEventRecordSchemaMetadata(DictWrapper):
+    @property
+    def data_format(self) -> str | None:
+        """The data format of the Kafka record."""
+        return self.get("dataFormat", None)
+
+    @property
+    def schema_id(self) -> str | None:
+        """The schema id of the Kafka record."""
+        return self.get("schemaId", None)
+
+
+class KafkaEventRecordBase(DictWrapper):
     @property
     def topic(self) -> str:
         """The Kafka topic."""

@@ -36,6 +48,24 @@ def timestamp_type(self) -> str:
         """The Kafka record timestamp type."""
         return self["timestampType"]
 
+    @property
+    def key_schema_metadata(self) -> KafkaEventRecordSchemaMetadata | None:
+        """The metadata of the Key Kafka record."""
+        return (
+            None if self.get("keySchemaMetadata") is None else KafkaEventRecordSchemaMetadata(self["keySchemaMetadata"])
+        )
+
+    @property
+    def value_schema_metadata(self) -> KafkaEventRecordSchemaMetadata | None:
+        """The metadata of the Value Kafka record."""
+        return (
+            None
+            if self.get("valueSchemaMetadata") is None
+            else KafkaEventRecordSchemaMetadata(self["valueSchemaMetadata"])
+        )
+
+
+class KafkaEventRecord(KafkaEventRecordBase):
     @property
     def key(self) -> str | None:
         """

@@ -83,18 +113,7 @@ def decoded_headers(self) -> dict[str, bytes]:
         return CaseInsensitiveDict((k, bytes(v)) for chunk in self.headers for k, v in chunk.items())
 
 
-class KafkaEvent(DictWrapper):
-    """Self-managed or MSK Apache Kafka event trigger
-    Documentation:
-    --------------
-    - https://docs.aws.amazon.com/lambda/latest/dg/with-kafka.html
-    - https://docs.aws.amazon.com/lambda/latest/dg/with-msk.html
-    """
-
-    def __init__(self, data: dict[str, Any]):
-        super().__init__(data)
-        self._records: Iterator[KafkaEventRecord] | None = None
-
+class KafkaEventBase(DictWrapper):
     @property
     def event_source(self) -> str:
         """The AWS service from which the Kafka event record originated."""

@@ -115,6 +134,19 @@ def decoded_bootstrap_servers(self) -> list[str]:
         """The decoded Kafka bootstrap URL."""
         return self.bootstrap_servers.split(",")
 
+
+class KafkaEvent(KafkaEventBase):
+    """Self-managed or MSK Apache Kafka event trigger
+    Documentation:
+    --------------
+    - https://docs.aws.amazon.com/lambda/latest/dg/with-kafka.html
+    - https://docs.aws.amazon.com/lambda/latest/dg/with-msk.html
+    """
+
+    def __init__(self, data: dict[str, Any]):
+        super().__init__(data)
+        self._records: Iterator[KafkaEventRecord] | None = None
+
     @property
     def records(self) -> Iterator[KafkaEventRecord]:
         """The Kafka records."""
aws_lambda_powertools/utilities/kafka/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
from aws_lambda_powertools.utilities.kafka.consumer_records import ConsumerRecords
from aws_lambda_powertools.utilities.kafka.kafka_consumer import kafka_consumer
from aws_lambda_powertools.utilities.kafka.schema_config import SchemaConfig

__all__ = [
    "kafka_consumer",
    "ConsumerRecords",
    "SchemaConfig",
]
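With these exports, the utility is importable from the package root. A minimal usage sketch; the SchemaConfig constructor arguments and the kafka_consumer decorator signature are assumptions, since neither is shown in this diff:

from aws_lambda_powertools.utilities.kafka import ConsumerRecords, SchemaConfig, kafka_consumer

# Assumption: SchemaConfig accepts a value_schema_type keyword matching the
# attribute names used in consumer_records.py below.
schema_config = SchemaConfig(value_schema_type="JSON")


@kafka_consumer(schema_config=schema_config)
def lambda_handler(event: ConsumerRecords, context) -> None:
    for record in event.records:
        print(record.value)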
aws_lambda_powertools/utilities/kafka/consumer_records.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict
from aws_lambda_powertools.utilities.data_classes.kafka_event import KafkaEventBase, KafkaEventRecordBase
from aws_lambda_powertools.utilities.kafka.deserializer.deserializer import get_deserializer
from aws_lambda_powertools.utilities.kafka.serialization.serialization import serialize_to_output_type

if TYPE_CHECKING:
    from collections.abc import Iterator

    from aws_lambda_powertools.utilities.kafka.schema_config import SchemaConfig


class ConsumerRecordRecords(KafkaEventRecordBase):
    """
    A Kafka Consumer Record
    """

    def __init__(self, data: dict[str, Any], schema_config: SchemaConfig | None = None):
        super().__init__(data)
        self.schema_config = schema_config

    @cached_property
    def key(self) -> Any:
        key = self.get("key")

        # Return None if key doesn't exist
        if not key:
            return None

        # Determine schema type and schema string
        schema_type = None
        schema_str = None
        output_serializer = None

        if self.schema_config and self.schema_config.key_schema_type:
            schema_type = self.schema_config.key_schema_type
            schema_str = self.schema_config.key_schema
            output_serializer = self.schema_config.key_output_serializer

        # Always use get_deserializer; if schema_type is None it defaults to DEFAULT
        deserializer = get_deserializer(schema_type, schema_str)
        deserialized_value = deserializer.deserialize(key)

        # Apply output serializer if specified
        if output_serializer:
            return serialize_to_output_type(deserialized_value, output_serializer)

        return deserialized_value

    @cached_property
    def value(self) -> Any:
        value = self["value"]

        # Determine schema type and schema string
        schema_type = None
        schema_str = None
        output_serializer = None

        if self.schema_config and self.schema_config.value_schema_type:
            schema_type = self.schema_config.value_schema_type
            schema_str = self.schema_config.value_schema
            output_serializer = self.schema_config.value_output_serializer

        # Always use get_deserializer; if schema_type is None it defaults to DEFAULT
        deserializer = get_deserializer(schema_type, schema_str)
        deserialized_value = deserializer.deserialize(value)

        # Apply output serializer if specified
        if output_serializer:
            return serialize_to_output_type(deserialized_value, output_serializer)

        return deserialized_value

    @property
    def original_value(self) -> str:
        """The original (base64 encoded) Kafka record value."""
        return self["value"]

    @property
    def original_key(self) -> str | None:
        """
        The original (base64 encoded) Kafka record key.

        This key is optional; if not provided,
        a round-robin algorithm will be used to determine
        the partition for the message.
        """

        return self.get("key")

    @property
    def original_headers(self) -> list[dict[str, list[int]]]:
        """The raw Kafka record headers."""
        return self["headers"]

    @cached_property
    def headers(self) -> dict[str, bytes]:
        """Decodes the headers as a single dictionary."""
        return CaseInsensitiveDict((k, bytes(v)) for chunk in self.original_headers for k, v in chunk.items())


class ConsumerRecords(KafkaEventBase):
    """Self-managed or MSK Apache Kafka event trigger
    Documentation:
    --------------
    - https://docs.aws.amazon.com/lambda/latest/dg/with-kafka.html
    - https://docs.aws.amazon.com/lambda/latest/dg/with-msk.html
    """

    def __init__(self, data: dict[str, Any], schema_config: SchemaConfig | None = None):
        super().__init__(data)
        self._records: Iterator[ConsumerRecordRecords] | None = None
        self.schema_config = schema_config

    @property
    def records(self) -> Iterator[ConsumerRecordRecords]:
        """The Kafka records."""
        for chunk in self["records"].values():
            for record in chunk:
                yield ConsumerRecordRecords(data=record, schema_config=self.schema_config)

    @property
    def record(self) -> ConsumerRecordRecords:
        """
        Returns the next Kafka record using an iterator.

        Returns
        -------
        ConsumerRecordRecords
            The next Kafka record.

        Raises
        ------
        StopIteration
            If there are no more records available.

        """
        if self._records is None:
            self._records = self.records
        return next(self._records)
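A minimal sketch of driving ConsumerRecords directly with a hand-built event. The payload shape mirrors the documented MSK/self-managed Kafka event; with no SchemaConfig the default deserializer is used, whose exact output type is not shown in this diff:

import base64

from aws_lambda_powertools.utilities.kafka import ConsumerRecords

event = {
    "eventSource": "aws:kafka",
    "bootstrapServers": "b-1.example-cluster.kafka.us-east-1.amazonaws.com:9092",
    "records": {
        "orders-0": [
            {
                "topic": "orders",
                "partition": 0,
                "offset": 15,
                "timestamp": 1545084650987,
                "timestampType": "CREATE_TIME",
                "key": base64.b64encode(b"order-1").decode(),
                "value": base64.b64encode(b'{"total": 10}').decode(),
                "headers": [],
            },
        ],
    },
}

records = ConsumerRecords(data=event, schema_config=None)
for record in records.records:
    # key/value go through get_deserializer(None, None); the default
    # deserializer's return type (str vs bytes) is an assumption here.
    print(record.topic, record.key, record.value)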

aws_lambda_powertools/utilities/kafka/deserializer/__init__.py

Whitespace-only changes.
aws_lambda_powertools/utilities/kafka/deserializer/avro.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
from __future__ import annotations

import io

from avro.io import BinaryDecoder, DatumReader
from avro.schema import parse as parse_schema

from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase
from aws_lambda_powertools.utilities.kafka.exceptions import (
    KafkaConsumerAvroSchemaParserError,
    KafkaConsumerDeserializationError,
)


class AvroDeserializer(DeserializerBase):
    """
    Deserializer for Apache Avro formatted data.

    This class provides functionality to deserialize Avro binary data using
    a provided Avro schema definition.
    """

    def __init__(self, schema_str: str):
        try:
            self.parsed_schema = parse_schema(schema_str)
            self.reader = DatumReader(self.parsed_schema)
        except Exception as e:
            raise KafkaConsumerAvroSchemaParserError(
                f"Invalid Avro schema. Please ensure the provided avro schema is valid: {type(e).__name__}: {str(e)}",
            ) from e

    def deserialize(self, data: bytes | str) -> object:
        """
        Deserialize Avro binary data to a Python dictionary.

        Parameters
        ----------
        data : bytes or str
            The Avro binary data to deserialize. If provided as a string,
            it will be decoded to bytes first.

        Returns
        -------
        dict[str, Any]
            Deserialized data as a dictionary.

        Raises
        ------
        KafkaConsumerDeserializationError
            When the data cannot be deserialized according to the schema,
            typically due to data format incompatibility.

        Examples
        --------
        >>> deserializer = AvroDeserializer(schema_str)
        >>> avro_data = b'...'  # binary Avro data
        >>> try:
        ...     result = deserializer.deserialize(avro_data)
        ...     # Process the deserialized data
        ... except KafkaConsumerDeserializationError as e:
        ...     print(f"Failed to deserialize: {e}")
        """
        try:
            value = self._decode_input(data)
            bytes_reader = io.BytesIO(value)
            decoder = BinaryDecoder(bytes_reader)
            return self.reader.read(decoder)
        except Exception as e:
            raise KafkaConsumerDeserializationError(
                f"Error trying to deserialize avro data - {type(e).__name__}: {str(e)}",
            ) from e
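Since the docstring example elides the payload, here is a runnable round-trip sketch. The Avro writer side is test scaffolding only, and the AvroDeserializer import path is assumed from this commit's module layout:

import base64
import io

from avro.io import BinaryEncoder, DatumWriter
from avro.schema import parse as parse_schema

from aws_lambda_powertools.utilities.kafka.deserializer.avro import AvroDeserializer

schema_str = '{"type": "record", "name": "Order", "fields": [{"name": "total", "type": "int"}]}'

# Produce binary Avro bytes for the record {"total": 10} (writer side, test-only).
buffer = io.BytesIO()
DatumWriter(parse_schema(schema_str)).write({"total": 10}, BinaryEncoder(buffer))

# deserialize() first calls _decode_input(), which base64-decodes its input,
# matching how Kafka record values arrive in the Lambda event.
deserializer = AvroDeserializer(schema_str)
print(deserializer.deserialize(base64.b64encode(buffer.getvalue())))  # {'total': 10}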
aws_lambda_powertools/utilities/kafka/deserializer/base.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
from __future__ import annotations

import base64
from abc import ABC, abstractmethod
from typing import Any


class DeserializerBase(ABC):
    """
    Abstract base class for deserializers.

    This class defines the interface for all deserializers in the Kafka consumer utility
    and provides a common method for decoding input data.

    Methods
    -------
    deserialize(data)
        Abstract method that must be implemented by subclasses to deserialize data.
    _decode_input(data)
        Helper method to decode input data to bytes.

    Examples
    --------
    >>> class MyDeserializer(DeserializerBase):
    ...     def deserialize(self, data: bytes | str) -> dict[str, Any]:
    ...         value = self._decode_input(data)
    ...         # Custom deserialization logic here
    ...         return {"key": "value"}
    """

    @abstractmethod
    def deserialize(self, data: str) -> dict[str, Any] | str | object:
        """
        Deserialize input data to a Python dictionary.

        This abstract method must be implemented by subclasses to provide
        specific deserialization logic.

        Parameters
        ----------
        data : str
            The data to deserialize; it is always a base64 encoded string.

        Returns
        -------
        dict[str, Any]
            The deserialized data as a dictionary.
        """
        raise NotImplementedError("Subclasses must implement the deserialize method")

    def _decode_input(self, data: bytes | str) -> bytes:
        return base64.b64decode(data)
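Following the docstring example above, a custom deserializer only needs to implement deserialize(). A minimal JSON-based sketch; this is illustrative and not the utility's built-in JSON deserializer:

from __future__ import annotations

import base64
import json
from typing import Any

from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase


class MyJsonDeserializer(DeserializerBase):
    def deserialize(self, data: bytes | str) -> dict[str, Any]:
        # _decode_input() base64-decodes the string Lambda provides into raw bytes
        return json.loads(self._decode_input(data))


deserializer = MyJsonDeserializer()
payload = base64.b64encode(b'{"total": 10}').decode()
print(deserializer.deserialize(payload))  # {'total': 10}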
