from __future__ import annotations

from functools import cached_property
from typing import Any

from aws_lambda_powertools.shared.headers_serializer import (
    BaseHeadersSerializer,
    HttpApiHeadersSerializer,
    MultiValueHeadersSerializer,
)
from aws_lambda_powertools.utilities.data_classes.common import (
    BaseProxyEvent,
    BaseRequestContext,
    BaseRequestContextV2,
    CaseInsensitiveDict,
    DictWrapper,
)


class APIGatewayEventAuthorizer(DictWrapper):
    """Typed view over the `authorizer` section of a payload format 1.0 request context.

    For `claims`, `scopes` and `principal_id`, a key that is missing or explicitly
    `null` (or otherwise falsy) is normalized to an empty value of the declared type.
    """

    @property
    def claims(self) -> dict[str, Any]:
        """Claims supplied by the authorizer, or an empty dict when missing/`null`."""
        found = self.get("claims")
        return found if found else {}

    @property
    def scopes(self) -> list[str]:
        """Scopes supplied by the authorizer, or an empty list when missing/`null`."""
        found = self.get("scopes")
        return found if found else []

    @property
    def principal_id(self) -> str:
        """Principal user identification tied to the client's token, as returned by an API Gateway
        Lambda authorizer (formerly known as a custom authorizer); empty string when missing/`null`."""
        found = self.get("principalId")
        return found if found else ""

    @property
    def integration_latency(self) -> int | None:
        """Authorizer latency in milliseconds, or `None` when not reported."""
        latency = self.get("integrationLatency")
        return latency

    def get_context(self) -> dict[str, Any]:
        """Return the authorizer payload, including any context injected by a Lambda Authorizer.

        Example
        -------

        ```python
        ctx: dict = request_context.authorizer.get_context()

        tenant_id = ctx.get("tenant_id")
        ```

        Returns
        -------
        dict[str, Any]
            The raw authorizer dictionary this wrapper was constructed from.
        """
        return self._data


class APIGatewayEventRequestContext(BaseRequestContext):
    """Payload format 1.0 request context, extended with WebSocket API fields and the authorizer section.

    The WebSocket/operation properties return the raw value from the event, or `None`
    when the key is absent.
    """

    @property
    def connected_at(self) -> int | None:
        """Connection time as an Epoch-formatted timestamp. (WebSocket API)"""
        return self.get("connectedAt")

    @property
    def connection_id(self) -> str | None:
        """Unique connection identifier, usable for making callbacks to the client. (WebSocket API)"""
        return self.get("connectionId")

    @property
    def event_type(self) -> str | None:
        """One of `CONNECT`, `MESSAGE`, or `DISCONNECT`. (WebSocket API)"""
        return self.get("eventType")

    @property
    def message_direction(self) -> str | None:
        """Direction of the message. (WebSocket API)"""
        return self.get("messageDirection")

    @property
    def message_id(self) -> str | None:
        """Server-side message identifier; only present when `eventType` is `MESSAGE`."""
        return self.get("messageId")

    @property
    def operation_name(self) -> str | None:
        """Name of the operation being performed."""
        return self.get("operationName")

    @property
    def route_key(self) -> str | None:
        """Route key that was selected for this request."""
        return self.get("routeKey")

    @property
    def authorizer(self) -> APIGatewayEventAuthorizer:
        """Authorizer details; wraps an empty dict when the key is missing or `null`."""
        raw_authorizer = self.get("authorizer")
        if not raw_authorizer:
            raw_authorizer = {}
        return APIGatewayEventAuthorizer(raw_authorizer)


class APIGatewayProxyEvent(BaseProxyEvent):
    """Lambda proxy integration event using payload format version 1.0 (AWS Lambda proxy V1).

    Documentation:
    --------------
    - https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html
    - https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html
    """

    @property
    def version(self) -> str:
        """Payload format version (required key)."""
        return self["version"]

    @property
    def resource(self) -> str:
        """Resource path template that matched the request (required key)."""
        return self["resource"]

    @property
    def multi_value_headers(self) -> dict[str, list[str]]:
        """All header values keyed by header name, with case-insensitive lookup."""
        return CaseInsensitiveDict(self.get("multiValueHeaders"))

    @property
    def resolved_query_string_parameters(self) -> dict[str, list[str]]:
        """Prefer the multi-value query parameters; fall back to the base-class resolution."""
        multi = self.multi_value_query_string_parameters
        if multi:
            return multi
        return super().resolved_query_string_parameters

    @property
    def resolved_headers_field(self) -> dict[str, Any]:
        """Prefer the multi-value headers; fall back to the single-value `headers`."""
        multi = self.multi_value_headers
        return multi if multi else self.headers

    @property
    def request_context(self) -> APIGatewayEventRequestContext:
        """Typed wrapper around the required `requestContext` section."""
        context = self["requestContext"]
        return APIGatewayEventRequestContext(context)

    @property
    def path_parameters(self) -> dict[str, str]:
        """Path parameters, or an empty dict when missing/`null`."""
        params = self.get("pathParameters")
        return params if params else {}

    @property
    def stage_variables(self) -> dict[str, str]:
        """Stage variables, or an empty dict when missing/`null`."""
        variables = self.get("stageVariables")
        return variables if variables else {}

    def header_serializer(self) -> BaseHeadersSerializer:
        """Serializer for responses to this payload format (multi-value headers)."""
        return MultiValueHeadersSerializer()


class RequestContextV2AuthorizerIam(DictWrapper):
    """IAM authorizer details from a payload format 2.0 request context.

    Every accessor normalizes a missing or explicitly `null` value to an empty value
    of its declared type, so callers never receive `None`.
    """

    @property
    def access_key(self) -> str:
        """The IAM user access key associated with the request."""
        return self.get("accessKey") or ""  # key might exist but can be `null`

    @property
    def account_id(self) -> str:
        """The AWS account ID associated with the request."""
        return self.get("accountId") or ""  # key might exist but can be `null`

    @property
    def caller_id(self) -> str:
        """The principal identifier of the caller making the request."""
        return self.get("callerId") or ""  # key might exist but can be `null`

    def _cognito_identity(self) -> dict:
        """Return the `cognitoIdentity` section, or an empty dict when missing/`null`."""
        return self.get("cognitoIdentity") or {}  # not available in FunctionURL; key might exist but can be `null`

    @property
    def cognito_amr(self) -> list[str]:
        """This represents how the user was authenticated.
        AMR stands for Authentication Methods References as per the openid spec"""
        # `or` rather than a `.get` default so an explicit `null` also yields `[]`
        return self._cognito_identity().get("amr") or []

    @property
    def cognito_identity_id(self) -> str:
        """The Amazon Cognito identity ID of the caller making the request.
        Available only if the request was signed with Amazon Cognito credentials."""
        # `or` rather than a `.get` default so an explicit `null` also yields `""`
        return self._cognito_identity().get("identityId") or ""

    @property
    def cognito_identity_pool_id(self) -> str:
        """The Amazon Cognito identity pool ID of the caller making the request.
        Available only if the request was signed with Amazon Cognito credentials."""
        return self._cognito_identity().get("identityPoolId") or ""  # key might exist but can be `null`

    @property
    def principal_org_id(self) -> str:
        """The AWS organization ID."""
        return self.get("principalOrgId") or ""  # key might exist but can be `null`

    @property
    def user_arn(self) -> str:
        """The Amazon Resource Name (ARN) of the effective user identified after authentication."""
        return self.get("userArn") or ""  # key might exist but can be `null`

    @property
    def user_id(self) -> str:
        """The IAM user ID of the effective user identified after authentication."""
        return self.get("userId") or ""  # key might exist but can be `null`


class RequestContextV2Authorizer(DictWrapper):
    """Authorizer section of a payload format 2.0 request context (JWT, Lambda and IAM details)."""

    def _jwt(self) -> dict:
        """Return the `jwt` section, or an empty dict when missing/`null`."""
        return self.get("jwt") or {}  # not available in FunctionURL; key might exist but can be `null`

    @property
    def jwt_claim(self) -> dict[str, Any]:
        """JWT claims, or an empty dict when missing/`null`."""
        return self._jwt().get("claims") or {}  # key might exist but can be `null`

    @property
    def jwt_scopes(self) -> list[str]:
        """JWT scopes, or an empty list when missing/`null`."""
        # `or` rather than a `.get` default so an explicit `null` also yields `[]`
        return self._jwt().get("scopes") or []

    @property
    def get_lambda(self) -> dict[str, Any]:
        """Lambda authorization context details"""
        return self.get("lambda") or {}  # key might exist but can be `null`

    def get_context(self) -> dict[str, Any]:
        """Retrieve the authorization context details injected by a Lambda Authorizer.

        Example
        -------

        ```python
        ctx: dict = request_context.authorizer.get_context()

        tenant_id = ctx.get("tenant_id")
        ```

        Returns
        -------
        dict[str, Any]
            The `lambda` section of the authorizer, or an empty dict when missing/`null`.
        """
        return self.get_lambda

    @property
    def iam(self) -> RequestContextV2AuthorizerIam:
        """IAM authorization details used for making the request."""
        iam = self.get("iam") or {}  # key might exist but can be `null`
        return RequestContextV2AuthorizerIam(iam)


class RequestContextV2(BaseRequestContextV2):
    """Payload format 2.0 request context with typed access to the authorizer section."""

    @property
    def authorizer(self) -> RequestContextV2Authorizer:
        """Authorizer details; wraps an empty dict when the key is missing or `null`."""
        raw_authorizer = self.get("authorizer")
        if not raw_authorizer:
            raw_authorizer = {}
        return RequestContextV2Authorizer(raw_authorizer)


class APIGatewayProxyEventV2(BaseProxyEvent):
    """AWS Lambda proxy V2 event

    Notes:
    -----
    Format 2.0 doesn't have multiValueHeaders or multiValueQueryStringParameters fields. Duplicate headers
    are combined with commas and included in the headers field. Duplicate query strings are combined with
    commas and included in the queryStringParameters field.

    Format 2.0 includes a new cookies field. All cookie headers in the request are combined with commas and
    added to the cookies field. In the response to the client, each cookie becomes a set-cookie header.

    Documentation:
    --------------
    - https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html
    """

    @property
    def version(self) -> str:
        """Payload format version string (required key)."""
        return self["version"]

    @property
    def route_key(self) -> str:
        """Route key that matched the request (required key)."""
        return self["routeKey"]

    @property
    def raw_path(self) -> str:
        """Request path as received, possibly including the stage prefix (required key)."""
        return self["rawPath"]

    @property
    def raw_query_string(self) -> str:
        """Unparsed query string (required key)."""
        return self["rawQueryString"]

    @property
    def cookies(self) -> list[str]:
        """Request cookies, or an empty list when missing/`null`."""
        return self.get("cookies") or []

    @property
    def request_context(self) -> RequestContextV2:
        """Typed wrapper around the required `requestContext` section."""
        return RequestContextV2(self["requestContext"])

    @property
    def path_parameters(self) -> dict[str, str]:
        """Path parameters, or an empty dict when missing/`null`."""
        return self.get("pathParameters") or {}

    @property
    def stage_variables(self) -> dict[str, str]:
        """Stage variables, or an empty dict when missing/`null`."""
        return self.get("stageVariables") or {}

    @property
    def path(self) -> str:
        """Request path with the leading `/<stage>` segment removed.

        `rawPath` is returned unchanged when the stage is `$default`, or when it does
        not actually start with the `/<stage>` segment. That happens, e.g., when the
        stage is not part of the URL, as with a custom domain mapping. A path made up
        solely of the stage segment maps to `/`.
        """
        raw_path = self.raw_path
        stage = self.request_context.stage
        if stage == "$default":
            return raw_path

        stage_prefix = "/" + stage
        if raw_path == stage_prefix:
            # Request to the stage root; keep a valid absolute path instead of ""
            return "/"
        if raw_path.startswith(stage_prefix + "/"):
            return raw_path[len(stage_prefix) :]
        # Prefix absent (or only a partial segment match such as "/devices" for stage "dev"):
        # slicing would corrupt the path, so leave it untouched.
        return raw_path

    @property
    def http_method(self) -> str:
        """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
        return self.request_context.http.method

    def header_serializer(self) -> BaseHeadersSerializer:
        """Serializer for responses to this payload format (HTTP API headers/cookies)."""
        return HttpApiHeadersSerializer()

    @cached_property
    def resolved_headers_field(self) -> dict[str, Any]:
        """Headers with case-insensitive lookup; values containing commas are split into lists.

        Splitting is on a bare "," with no whitespace stripping, and applies to any
        value containing a comma (not only combined duplicate headers).
        """
        return CaseInsensitiveDict((k, v.split(",") if "," in v else v) for k, v in self.headers.items())