diff --git a/aws_lambda_powertools/utilities/data_classes/kafka_event.py b/aws_lambda_powertools/utilities/data_classes/kafka_event.py
index 88062f38e56..c3d549c0f49 100644
--- a/aws_lambda_powertools/utilities/data_classes/kafka_event.py
+++ b/aws_lambda_powertools/utilities/data_classes/kafka_event.py
@@ -37,14 +37,25 @@ def timestamp_type(self) -> str:
         return self["timestampType"]
 
     @property
-    def key(self) -> str:
-        """The raw (base64 encoded) Kafka record key."""
-        return self["key"]
+    def key(self) -> str | None:
+        """
+        The raw (base64 encoded) Kafka record key.
+
+        This key is optional; if not provided,
+        a round-robin algorithm will be used to determine
+        the partition for the message.
+        """
+
+        return self.get("key")
 
     @property
-    def decoded_key(self) -> bytes:
-        """Decode the base64 encoded key as bytes."""
-        return base64.b64decode(self.key)
+    def decoded_key(self) -> bytes | None:
+        """
+        Decode the base64 encoded key as bytes.
+
+        If the key is not provided, this will return None.
+        """
+        return None if self.key is None else base64.b64decode(self.key)
 
     @property
     def value(self) -> str:
diff --git a/aws_lambda_powertools/utilities/parser/models/kafka.py b/aws_lambda_powertools/utilities/parser/models/kafka.py
index 447b96c406b..c365c51c63c 100644
--- a/aws_lambda_powertools/utilities/parser/models/kafka.py
+++ b/aws_lambda_powertools/utilities/parser/models/kafka.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Dict, List, Literal, Type, Union
+from typing import Dict, List, Literal, Optional, Type, Union
 
 from pydantic import BaseModel, field_validator
 
@@ -14,12 +14,16 @@ class KafkaRecordModel(BaseModel):
     offset: int
     timestamp: datetime
     timestampType: str
-    key: bytes
+    key: Optional[bytes] = None
     value: Union[str, Type[BaseModel]]
     headers: List[Dict[str, bytes]]
 
-    # Added type ignore to keep compatibility between Pydantic v1 and v2
-    _decode_key = field_validator("key")(base64_decode)  # type: ignore[type-var, unused-ignore]
+    # key is optional; only decode if not None
+    @field_validator("key", mode="before")
+    def decode_key(cls, value):
+        if value is not None:
+            return base64_decode(value)
+        return value
 
     @field_validator("value", mode="before")
     def data_base64_decode(cls, value):
diff --git a/tests/events/kafkaEventMsk.json b/tests/events/kafkaEventMsk.json
index 5a35b89680a..f0c7d36c2cf 100644
--- a/tests/events/kafkaEventMsk.json
+++ b/tests/events/kafkaEventMsk.json
@@ -29,6 +29,57 @@
             ]
           }
         ]
+      },
+      {
+        "topic":"mytopic",
+        "partition":0,
+        "offset":15,
+        "timestamp":1545084650987,
+        "timestampType":"CREATE_TIME",
+        "value":"eyJrZXkiOiJ2YWx1ZSJ9",
+        "headers":[
+          {
+            "headerKey":[
+              104,
+              101,
+              97,
+              100,
+              101,
+              114,
+              86,
+              97,
+              108,
+              117,
+              101
+            ]
+          }
+        ]
+      },
+      {
+        "topic":"mytopic",
+        "partition":0,
+        "offset":15,
+        "timestamp":1545084650987,
+        "timestampType":"CREATE_TIME",
+        "key": null,
+        "value":"eyJrZXkiOiJ2YWx1ZSJ9",
+        "headers":[
+          {
+            "headerKey":[
+              104,
+              101,
+              97,
+              100,
+              101,
+              114,
+              86,
+              97,
+              108,
+              117,
+              101
+            ]
+          }
+        ]
       }
     ]
   }
diff --git a/tests/events/kafkaEventSelfManaged.json b/tests/events/kafkaEventSelfManaged.json
index eaf0bf34cae..f99ca35cc48 100644
--- a/tests/events/kafkaEventSelfManaged.json
+++ b/tests/events/kafkaEventSelfManaged.json
@@ -28,6 +28,57 @@
             ]
           }
         ]
+      },
+      {
+        "topic": "mytopic",
+        "partition": 0,
+        "offset": 15,
+        "timestamp": 1545084650987,
+        "timestampType": "CREATE_TIME",
+        "value": "eyJrZXkiOiJ2YWx1ZSJ9",
+        "headers": [
+          {
+            "headerKey": [
+              104,
+              101,
+              97,
+              100,
+              101,
+              114,
+              86,
+              97,
+              108,
+              117,
+              101
+            ]
+          }
+        ]
+      },
+      {
+        "topic": "mytopic",
+        "partition": 0,
+        "offset": 15,
+        "timestamp": 1545084650987,
+        "timestampType": "CREATE_TIME",
+        "key": null,
+        "value": "eyJrZXkiOiJ2YWx1ZSJ9",
+        "headers": [
+          {
+            "headerKey": [
+              104,
+              101,
+              97,
+              100,
+              101,
+              114,
+              86,
+              97,
+              108,
+              117,
+              101
+            ]
+          }
+        ]
       }
     ]
   }
diff --git a/tests/unit/data_classes/required_dependencies/test_kafka_event.py b/tests/unit/data_classes/required_dependencies/test_kafka_event.py
index b03c712f52c..8e4480a06d7 100644
--- a/tests/unit/data_classes/required_dependencies/test_kafka_event.py
+++ b/tests/unit/data_classes/required_dependencies/test_kafka_event.py
@@ -21,7 +21,7 @@ def test_kafka_msk_event():
     assert parsed_event.decoded_bootstrap_servers == bootstrap_servers_list
 
     records = list(parsed_event.records)
-    assert len(records) == 1
+    assert len(records) == 3
     record = records[0]
     raw_record = raw_event["records"]["mytopic-0"][0]
     assert record.topic == raw_record["topic"]
@@ -36,6 +36,9 @@
     assert record.decoded_headers["HeaderKey"] == b"headerValue"
 
     assert parsed_event.record == records[0]
+    for i in range(1, 3):
+        record = records[i]
+        assert record.key is None
 
 
 def test_kafka_self_managed_event():
@@ -52,7 +55,7 @@
     assert parsed_event.decoded_bootstrap_servers == bootstrap_servers_list
 
     records = list(parsed_event.records)
-    assert len(records) == 1
+    assert len(records) == 3
     record = records[0]
     raw_record = raw_event["records"]["mytopic-0"][0]
     assert record.topic == raw_record["topic"]
@@ -68,14 +71,18 @@
 
     assert parsed_event.record == records[0]
 
+    for i in range(1, 3):
+        record = records[i]
+        assert record.key is None
+
 
 def test_kafka_record_property_with_stopiteration_error():
     # GIVEN a kafka event with one record
     raw_event = load_event("kafkaEventMsk.json")
     parsed_event = KafkaEvent(raw_event)
 
-    # WHEN calling record property twice
+    # WHEN calling record property more times than there are records
     # THEN raise StopIteration
     with pytest.raises(StopIteration):
-        assert parsed_event.record.topic is not None
-        assert parsed_event.record.partition is not None
+        for _ in range(4):
+            assert parsed_event.record.topic is not None
diff --git a/tests/unit/parser/_pydantic/test_kafka.py b/tests/unit/parser/_pydantic/test_kafka.py
index 066820c2f11..aabb669b805 100644
--- a/tests/unit/parser/_pydantic/test_kafka.py
+++ b/tests/unit/parser/_pydantic/test_kafka.py
@@ -15,9 +15,9 @@ def test_kafka_msk_event_with_envelope():
         model=MyLambdaKafkaBusiness,
         envelope=envelopes.KafkaEnvelope,
     )
-
-    assert parsed_event[0].key == "value"
-    assert len(parsed_event) == 1
+    for i in range(3):
+        assert parsed_event[i].key == "value"
+    assert len(parsed_event) == 3
 
 
 def test_kafka_self_managed_event_with_envelope():
@@ -27,9 +27,9 @@ def test_kafka_self_managed_event_with_envelope():
         model=MyLambdaKafkaBusiness,
         envelope=envelopes.KafkaEnvelope,
     )
-
-    assert parsed_event[0].key == "value"
-    assert len(parsed_event) == 1
+    for i in range(3):
+        assert parsed_event[i].key == "value"
+    assert len(parsed_event) == 3
 
 
 def test_self_managed_kafka_event():
@@ -41,7 +41,7 @@ def test_self_managed_kafka_event():
     assert parsed_event.bootstrapServers == raw_event["bootstrapServers"].split(",")
 
     records = list(parsed_event.records["mytopic-0"])
-    assert len(records) == 1
+    assert len(records) == 3
     record: KafkaRecordModel = records[0]
     raw_record = raw_event["records"]["mytopic-0"][0]
     assert record.topic == raw_record["topic"]
@@ -55,6 +55,8 @@ def test_self_managed_kafka_event():
     assert record.value == '{"key":"value"}'
     assert len(record.headers) == 1
     assert record.headers[0]["headerKey"] == b"headerValue"
+    record: KafkaRecordModel = records[1]
+    assert record.key is None
 
 
 def test_kafka_msk_event():
@@ -66,7 +68,7 @@
     assert parsed_event.eventSourceArn == raw_event["eventSourceArn"]
 
     records = list(parsed_event.records["mytopic-0"])
-    assert len(records) == 1
+    assert len(records) == 3
     record: KafkaRecordModel = records[0]
     raw_record = raw_event["records"]["mytopic-0"][0]
     assert record.topic == raw_record["topic"]
@@ -80,3 +82,6 @@
     assert record.value == '{"key":"value"}'
     assert len(record.headers) == 1
     assert record.headers[0]["headerKey"] == b"headerValue"
+    for i in range(1, 3):
+        record: KafkaRecordModel = records[i]
+        assert record.key is None
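
Below is a minimal usage sketch of the behaviour this patch enables, not part of the change set itself; the inline payload is hand-written for illustration rather than taken from the repository fixtures:

from aws_lambda_powertools.utilities.data_classes import KafkaEvent

# Hand-written MSK-style payload: the record deliberately omits "key",
# as producers may do when they let Kafka assign partitions round-robin.
event = {
    "eventSource": "aws:kafka",
    "records": {
        "mytopic-0": [
            {
                "topic": "mytopic",
                "partition": 0,
                "offset": 15,
                "timestamp": 1545084650987,
                "timestampType": "CREATE_TIME",
                "value": "eyJrZXkiOiJ2YWx1ZSJ9",  # base64 of {"key":"value"}
                "headers": [],
            },
        ],
    },
}

record = next(KafkaEvent(event).records)
assert record.key is None          # previously raised KeyError on a missing key
assert record.decoded_key is None  # previously raised as well, via self.key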