Skip to content

fix(batch): handle early validation errors for pydantic models (poison pill) #2091 #2099

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 7, 2023
Merged
43 changes: 38 additions & 5 deletions aws_lambda_powertools/utilities/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
KinesisStreamRecord,
)
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.parser import ValidationError
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -316,21 +317,36 @@ def _get_messages_to_report(self) -> List[Dict[str, str]]:
def _collect_sqs_failures(self):
failures = []
for msg in self.fail_messages:
msg_id = msg.messageId if self.model else msg.message_id
# If a message failed due to model validation (e.g., poison pill)
# we convert to an event source data class...but self.model is still true
# therefore, we do an additional check on whether the failed message is still a model
# see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
if self.model and getattr(msg, "parse_obj", None):
msg_id = msg.messageId
else:
msg_id = msg.message_id
failures.append({"itemIdentifier": msg_id})
return failures

def _collect_kinesis_failures(self):
failures = []
for msg in self.fail_messages:
msg_id = msg.kinesis.sequenceNumber if self.model else msg.kinesis.sequence_number
# # see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
if self.model and getattr(msg, "parse_obj", None):
msg_id = msg.kinesis.sequenceNumber
else:
msg_id = msg.kinesis.sequence_number
failures.append({"itemIdentifier": msg_id})
return failures

def _collect_dynamodb_failures(self):
failures = []
for msg in self.fail_messages:
msg_id = msg.dynamodb.SequenceNumber if self.model else msg.dynamodb.sequence_number
# see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
if self.model and getattr(msg, "parse_obj", None):
msg_id = msg.dynamodb.SequenceNumber
else:
msg_id = msg.dynamodb.sequence_number
failures.append({"itemIdentifier": msg_id})
return failures

Expand All @@ -347,6 +363,17 @@ def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["B
return model.parse_obj(record)
return self._DATA_CLASS_MAPPING[event_type](record)

def _register_model_validation_error_record(self, record: dict):
    """Convert and register failure due to poison pills where model failed validation early.

    Parameters
    ----------
    record: dict
        Raw batch record that failed the customer model's validation.
    """
    # Parser will fail validation if record is a poison pill (malformed input)
    # this means we can't collect the message id if we try transforming again
    # so we convert it to the equivalent batch type model (e.g., SQS, Kinesis, DynamoDB Stream)
    # and downstream we can correctly collect the correct message id identifier and make the failed record available
    # see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
    logger.debug("Record cannot be converted to customer's model; converting without model")
    failed_record: "EventSourceDataClassTypes" = self._to_batch_type(record=record, event_type=self.event_type)
    return self.failure_handler(record=failed_record, exception=sys.exc_info())


class BatchProcessor(BasePartialBatchProcessor): # Keep old name for compatibility
"""Process native partial responses from SQS, Kinesis Data Streams, and DynamoDB.
Expand Down Expand Up @@ -471,14 +498,17 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons
record: dict
A batch record to be processed.
"""
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
data: Optional["BatchTypeModels"] = None
try:
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
if self._handler_accepts_lambda_context:
result = self.handler(record=data, lambda_context=self.lambda_context)
else:
result = self.handler(record=data)

return self.success_handler(record=record, result=result)
except ValidationError:
return self._register_model_validation_error_record(record)
except Exception:
return self.failure_handler(record=data, exception=sys.exc_info())

Expand Down Expand Up @@ -651,14 +681,17 @@ async def _async_process_record(self, record: dict) -> Union[SuccessResponse, Fa
record: dict
A batch record to be processed.
"""
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
data: Optional["BatchTypeModels"] = None
try:
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
if self._handler_accepts_lambda_context:
result = await self.handler(record=data, lambda_context=self.lambda_context)
else:
result = await self.handler(record=data)

return self.success_handler(record=record, result=result)
except ValidationError:
return self._register_model_validation_error_record(record)
except Exception:
return self.failure_handler(record=data, exception=sys.exc_info())

Expand Down
4 changes: 2 additions & 2 deletions aws_lambda_powertools/utilities/parser/models/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
class DynamoDBStreamChangedRecordModel(BaseModel):
ApproximateCreationDateTime: Optional[date]
Keys: Dict[str, Dict[str, Any]]
NewImage: Optional[Union[Dict[str, Any], Type[BaseModel]]]
OldImage: Optional[Union[Dict[str, Any], Type[BaseModel]]]
NewImage: Optional[Union[Dict[str, Any], Type[BaseModel], BaseModel]]
OldImage: Optional[Union[Dict[str, Any], Type[BaseModel], BaseModel]]
SequenceNumber: str
SizeBytes: int
StreamViewType: Literal["NEW_AND_OLD_IMAGES", "KEYS_ONLY", "NEW_IMAGE", "OLD_IMAGE"]
Expand Down
2 changes: 1 addition & 1 deletion aws_lambda_powertools/utilities/parser/models/kinesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class KinesisDataStreamRecordPayload(BaseModel):
kinesisSchemaVersion: str
partitionKey: str
sequenceNumber: str
data: Union[bytes, Type[BaseModel]] # base64 encoded str is parsed into bytes
data: Union[bytes, Type[BaseModel], BaseModel] # base64 encoded str is parsed into bytes
approximateArrivalTimestamp: float

@validator("data", pre=True, allow_reuse=True)
Expand Down
2 changes: 1 addition & 1 deletion aws_lambda_powertools/utilities/parser/models/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class SqsMsgAttributeModel(BaseModel):
class SqsRecordModel(BaseModel):
messageId: str
receiptHandle: str
body: Union[str, Type[BaseModel]]
body: Union[str, Type[BaseModel], BaseModel]
attributes: SqsAttributesModel
messageAttributes: Dict[str, SqsMsgAttributeModel]
md5OfBody: str
Expand Down
8 changes: 5 additions & 3 deletions aws_lambda_powertools/utilities/parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
import sys
from typing import Any, Dict, Type, TypeVar, Union

from pydantic import BaseModel
from pydantic import BaseModel, Json

# We only need typing_extensions for python versions <3.8
if sys.version_info >= (3, 8):
from typing import Literal # noqa: F401
from typing import Literal
else:
from typing_extensions import Literal # noqa: F401
from typing_extensions import Literal

Model = TypeVar("Model", bound=BaseModel)
EnvelopeModel = TypeVar("EnvelopeModel")
EventParserReturnType = TypeVar("EventParserReturnType")
AnyInheritedModel = Union[Type[BaseModel], BaseModel]
RawDictOrModel = Union[Dict[str, Any], AnyInheritedModel]

__all__ = ["Json", "Literal"]
Empty file.
47 changes: 47 additions & 0 deletions tests/functional/batch/sample_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import json
from typing import Dict, Optional

from aws_lambda_powertools.utilities.parser import BaseModel, validator
from aws_lambda_powertools.utilities.parser.models import (
DynamoDBStreamChangedRecordModel,
DynamoDBStreamRecordModel,
KinesisDataStreamRecord,
KinesisDataStreamRecordPayload,
SqsRecordModel,
)
from aws_lambda_powertools.utilities.parser.types import Json, Literal


class Order(BaseModel):
    """Sample order payload used as the inner model across batch test fixtures."""

    item: dict


class OrderSqs(SqsRecordModel):
    """SQS record whose JSON string body is parsed directly into an ``Order``."""

    # Json[...] makes pydantic json-decode the raw body before validating it as Order
    body: Json[Order]


class OrderKinesisPayloadRecord(KinesisDataStreamRecordPayload):
    """Kinesis payload whose decoded ``data`` field is JSON parsed into an ``Order``."""

    data: Json[Order]


class OrderKinesisRecord(KinesisDataStreamRecord):
    """Full Kinesis record carrying the order-aware payload model."""

    kinesis: OrderKinesisPayloadRecord


class OrderDynamoDB(BaseModel):
    """DynamoDB image wrapper whose ``Message`` attribute holds a serialized ``Order``."""

    Message: Order

    # auto transform json string
    # so Pydantic can auto-initialize nested Order model
    @validator("Message", pre=True)
    def transform_message_to_dict(cls, value: Dict[Literal["S"], str]):
        # DynamoDB string attributes arrive as {"S": "<json>"}; unwrap and decode
        return json.loads(value["S"])


class OrderDynamoDBChangeRecord(DynamoDBStreamChangedRecordModel):
    """Stream change record with typed images; either image may be absent on insert/remove."""

    NewImage: Optional[OrderDynamoDB]
    OldImage: Optional[OrderDynamoDB]


class OrderDynamoDBRecord(DynamoDBStreamRecordModel):
    """Full DynamoDB stream record carrying the typed change payload."""

    dynamodb: OrderDynamoDBChangeRecord
Loading