from __future__ import annotations

import asyncio
from typing import Any, Awaitable, Callable, DefaultDict, Sequence, Union, cast

from channels.db import database_sync_to_async as _database_sync_to_async
from django.db.models.base import Model
from django.db.models.query import QuerySet
from idom import use_callback, use_ref
from idom.backend.types import Location
from idom.core.hooks import Context, create_context, use_context, use_effect, use_state

from django_idom.types import IdomWebsocket, Mutation, Query, _Params, _Result


# Re-type channels' `database_sync_to_async` so that it can be used to decorate
# any callable, with the decorated function typed as returning an awaitable.
database_sync_to_async = cast(
    Callable[..., Callable[..., Awaitable[Any]]],
    _database_sync_to_async,
)
# Context through which the hooks below obtain the active websocket. It defaults
# to None, which `use_websocket` reports as a RuntimeError.
WebsocketContext: Context[Union[IdomWebsocket, None]] = create_context(None)
# Maps each query function to the set of refetch callbacks registered by
# `use_query`; `use_mutation` calls them to make those queries run again.
_REFETCH_CALLBACKS: DefaultDict[
    Callable[..., Any], set[Callable[[], None]]
] = DefaultDict(set)


def use_location() -> Location:
    """Get the current route as a `Location` built from the ASGI scope.

    The first `Location` argument is the scope's ``"path"``; the second is the
    decoded ``"query_string"`` prefixed with ``?``, or an empty string when the
    query string is empty.
    """
    # TODO: Use the browser's current page, rather than the WS route
    scope = use_scope()
    query_string = scope["query_string"].decode()
    search = f"?{query_string}" if query_string else ""
    return Location(scope["path"], search)


def use_scope() -> dict[str, Any]:
    """Get the ASGI scope dictionary of the current websocket"""
    websocket = use_websocket()
    return websocket.scope


def use_websocket() -> IdomWebsocket:
    """Get the current IdomWebsocket object

    Raises:
        RuntimeError: If `WebsocketContext` holds None, i.e. no websocket has
            been provided to the component tree.
    """
    current = use_context(WebsocketContext)
    if current is not None:
        return current
    raise RuntimeError("No websocket. Are you running with a Django server?")


def use_query(
    query: Callable[_Params, Union[_Result, None]],
    *args: _Params.args,
    **kwargs: _Params.kwargs,
) -> Query[Union[_Result, None]]:
    """Hook that runs a database query and tracks its state.

    Args:
        query: The function to execute. Its result is passed to
            `_fetch_deferred`, so it must be a Django `Model` or `QuerySet`;
            anything else (including None) ends up as a `ValueError` in the
            returned `error` rather than as `data`.
        *args: Positional arguments forwarded to `query`.
        **kwargs: Keyword arguments forwarded to `query`.

    Returns:
        A `Query` built from ``(data, loading, error, refetch)``. `data` is
        None until a query succeeds and is reset to None on failure, `loading`
        starts as True and is cleared once the query finishes, `error` holds
        the exception raised by the last attempt (or None), and calling
        `refetch` requests that the query be executed again.

    Raises:
        ValueError: If `query` is not the same object as the one held in this
            hook's ref.
    """
    query_ref = use_ref(query)
    if query_ref.current is not query:
        raise ValueError(f"Query function changed from {query_ref.current} to {query}.")

    should_execute, set_should_execute = use_state(True)
    data, set_data = use_state(cast(Union[_Result, None], None))
    loading, set_loading = use_state(True)
    error, set_error = use_state(cast(Union[Exception, None], None))

    @use_callback
    def refetch() -> None:
        set_should_execute(True)
        set_loading(True)
        set_error(None)

    @use_effect(dependencies=[])
    def add_refetch_callback() -> Callable[[], None]:
        # By tracking callbacks globally, any usage of the query function will be re-run
        # if the user has told a mutation to refetch it.
        _REFETCH_CALLBACKS[query].add(refetch)

        def remove_refetch_callback() -> None:
            callbacks = _REFETCH_CALLBACKS.get(query)
            if callbacks is None:
                return
            # `discard` keeps the cleanup harmless even if the callback is already gone.
            callbacks.discard(refetch)
            if not callbacks:
                # Drop the key once nobody subscribes to this query, so the
                # module-level registry does not accumulate empty sets.
                _REFETCH_CALLBACKS.pop(query, None)

        return remove_refetch_callback

    # With `dependencies=None` the effect is not tied to a dependency list (see the
    # idom `use_effect` docs); `should_execute` is what gates the actual work.
    @use_effect(dependencies=None)
    @database_sync_to_async
    def execute_query() -> None:
        if not should_execute:
            return

        try:
            new_data = query(*args, **kwargs)
            # Raises ValueError for anything that is not a Model or QuerySet.
            _fetch_deferred(new_data)
        except Exception as e:
            set_data(None)
            set_loading(False)
            set_error(e)
            return
        finally:
            # Success or failure, do not run again until `refetch` is called.
            set_should_execute(False)

        set_data(new_data)
        set_loading(False)
        set_error(None)

    return Query(data, loading, error, refetch)


def use_mutation(
    mutate: Callable[_Params, Union[bool, None]],
    refetch: Union[Callable[..., Any], Sequence[Callable[..., Any]], None] = None,
) -> Mutation[_Params]:
    """Hook that runs a database mutation and tracks its state.

    Args:
        mutate: The function that performs the mutation. If it returns False,
            no queries are refetched afterwards.
        refetch: A query function, or a sequence of query functions, whose
            `use_query` refetch callbacks are invoked after a successful
            mutation. None (the default) means nothing is refetched.

    Returns:
        A `Mutation` built from ``(call, loading, error, reset)``. `call`
        forwards its arguments to `mutate` and schedules it without waiting for
        the result, `loading` is True while a mutation is in flight, `error`
        holds the exception raised by the last mutation (or None), and `reset`
        clears both `loading` and `error`.
    """
    loading, set_loading = use_state(False)
    error, set_error = use_state(cast(Union[Exception, None], None))

    @use_callback
    def call(*args: _Params.args, **kwargs: _Params.kwargs) -> None:
        set_loading(True)

        @database_sync_to_async
        def execute_mutation() -> None:
            try:
                should_refetch = mutate(*args, **kwargs)
            except Exception as e:
                set_loading(False)
                set_error(e)
            else:
                set_loading(False)
                set_error(None)
                if should_refetch is False or refetch is None:
                    return
                queries = (refetch,) if callable(refetch) else refetch
                for query in queries:
                    # Iterate over a snapshot: `use_query` effects add and remove
                    # entries of the registered set, and changing a set while it
                    # is being iterated raises RuntimeError.
                    for callback in tuple(_REFETCH_CALLBACKS.get(query) or ()):
                        callback()

        # Fire and forget: the outcome is reported through `loading` and `error`.
        asyncio.ensure_future(execute_mutation())

    @use_callback
    def reset() -> None:
        set_loading(False)
        set_error(None)

    return Mutation(call, loading, error, reset)


def _fetch_deferred(data: Any) -> None:
    """Fetch the deferred fields of a `Model`, or of every item of a `QuerySet`.

    Raises:
        ValueError: If `data` is neither a `Model` nor a `QuerySet`.
    """
    # The `type: ignore` works around typeddjango/django-stubs issue #704:
    # https://github.com/typeddjango/django-stubs/issues/704
    if isinstance(data, QuerySet):  # type: ignore[misc]
        for instance in data:
            _fetch_deferred_model_fields(instance)
        return

    if isinstance(data, Model):
        _fetch_deferred_model_fields(data)
        return

    raise ValueError(f"Expected a Model or QuerySet, got {data!r}")


def _fetch_deferred_model_fields(model: Any) -> None:
    """Read every deferred field of `model`, recursing into `Model` values.

    Each name reported by ``model.get_deferred_fields()`` is read with
    `getattr`, presumably so that Django loads the value now rather than on a
    later access. Values that are themselves `Model` instances are processed
    the same way.
    """
    for field_name in model.get_deferred_fields():
        loaded = getattr(model, field_name)
        if not isinstance(loaded, Model):
            continue
        _fetch_deferred_model_fields(loaded)