from datetime import datetime, timedelta

from django.db.models import Count, Q
from django.utils import timezone
from rest_framework import viewsets, status
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.pagination import PageNumberPagination
from drf_spectacular.utils import extend_schema, OpenApiParameter

from apps.analytics.models import AccessLog, Hit
from apps.analytics.views import _aggregate, _aggregate_hits
from apps.tools.models import URL

from .models import Channel, Campaign, MediaAsset, Post
from .serializers import (
    ChannelSerializer, CampaignSerializer, MediaAssetSerializer,
    PostListSerializer, PostDetailSerializer,
)
from .services import ensure_tracking_link


def _link_stats(short_codes, post_ids):
    """Build the link-analytics payload for a group of posts.

    Two independent sources are combined:

    * ``clicks`` — shortener ``AccessLog`` rows (``resource_type="shorten_url"``)
      whose ``resource_id`` is one of ``short_codes``, summarised by ``_aggregate``.
    * ``downstream`` — ``Hit`` rows whose ``utm_content`` equals one of the
      post ids with dashes removed (see post_attribution_id), summarised by
      ``_aggregate_hits``.

    An empty ``short_codes`` or ``post_ids`` falls back to a ``.none()``
    queryset for that source.
    """
    if short_codes:
        click_qs = AccessLog.objects.filter(
            resource_type="shorten_url", resource_id__in=list(short_codes)
        )
    else:
        click_qs = AccessLog.objects.none()

    utm_keys = [str(post_id).replace("-", "") for post_id in post_ids]
    if utm_keys:
        hit_qs = Hit.objects.filter(utm_content__in=utm_keys)
    else:
        hit_qs = Hit.objects.none()

    return {
        "clicks": _aggregate(click_qs),
        "downstream": _aggregate_hits(hit_qs),
    }


class StandardPagination(PageNumberPagination):
    """Page-number pagination shared by the dissemination viewsets.

    Serves 50 items per page by default. Clients may pick another size with
    ``?page_size=``, capped at 500.
    """

    max_page_size = 500
    page_size_query_param = "page_size"
    page_size = 50


class ChannelViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for dissemination channels.

    Optional query params:
        platform: exact match on ``platform``.
        active: when present, "1", "true" or "yes" (case-insensitive) keeps
            active channels; any other value, including an empty string,
            keeps inactive ones.
    """

    rbac_domain = "disseminations"
    queryset = Channel.objects.all()
    serializer_class = ChannelSerializer
    pagination_class = StandardPagination

    def get_queryset(self):
        queryset = super().get_queryset()
        params = self.request.query_params

        platform = params.get("platform")
        if platform:
            queryset = queryset.filter(platform=platform)

        active_flag = params.get("active")
        if active_flag is not None:
            wants_active = active_flag.lower() in ("1", "true", "yes")
            queryset = queryset.filter(active=wants_active)

        return queryset

    def perform_create(self, serializer):
        # Anonymous requests store no creator.
        user = self.request.user
        serializer.save(created_by=user if user.is_authenticated else None)


class CampaignViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for campaigns, plus aggregated link analytics.

    Each campaign is annotated with ``post_count_anno`` (count of related
    ``posts``). Results are ordered by descending ``start_date``, then by
    ``name``.
    """

    rbac_domain = "disseminations"
    queryset = (
        Campaign.objects
        .annotate(post_count_anno=Count("posts"))
        .order_by("-start_date", "name")
    )
    serializer_class = CampaignSerializer
    pagination_class = StandardPagination

    def perform_create(self, serializer):
        # Anonymous requests store no creator.
        author = self.request.user
        serializer.save(created_by=author if author.is_authenticated else None)

    @action(detail=True, methods=["get"])
    def link_stats(self, request, pk=None):
        """Return click and downstream stats across every post in the campaign.

        Short codes come from the distinct, non-empty ``short_code`` values of
        URLs linked to this campaign's posts.
        """
        campaign = self.get_object()
        post_ids = list(campaign.posts.values_list("id", flat=True))
        tracked_urls = URL.objects.filter(dissemination_posts__campaign=campaign)
        short_codes = list(
            tracked_urls.exclude(short_code="")
            .values_list("short_code", flat=True)
            .distinct()
        )
        return Response(_link_stats(short_codes, post_ids))


class MediaAssetViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for uploaded media assets.

    The optional ``?kind=`` query param narrows results to a single kind.
    """

    rbac_domain = "disseminations"
    queryset = MediaAsset.objects.all()
    serializer_class = MediaAssetSerializer
    pagination_class = StandardPagination

    def get_queryset(self):
        queryset = super().get_queryset()
        kind = self.request.query_params.get("kind")
        return queryset.filter(kind=kind) if kind else queryset

    def perform_create(self, serializer):
        """Save the asset with its uploader and the uploaded file's byte size.

        ``size_bytes`` is 0 when the request has no ``file`` part. The
        uploader is None for anonymous requests.
        """
        upload = self.request.FILES.get("file")
        byte_count = upload.size if upload else 0
        uploader = self.request.user
        serializer.save(
            uploaded_by=uploader if uploader.is_authenticated else None,
            size_bytes=byte_count,
        )


class PostViewSet(viewsets.ModelViewSet):
    """CRUD and workflow endpoints (schedule / publish / archive) for posts.

    The base queryset eager-loads ``campaign`` and ``tracking_url`` via
    ``select_related``, prefetches ``channels`` and ``media``, and annotates
    each post with ``channel_count_anno`` (distinct number of channels).
    """

    rbac_domain = "disseminations"
    queryset = (
        Post.objects.select_related("campaign", "tracking_url")
        .prefetch_related("channels", "media")
        .annotate(channel_count_anno=Count("channels", distinct=True))
    )
    pagination_class = StandardPagination

    def get_serializer_class(self):
        """Use the compact list serializer for collection-style responses."""
        if self.action in ("list", "calendar"):
            return PostListSerializer
        return PostDetailSerializer

    def get_queryset(self):
        """Return the base queryset narrowed by optional query params.

        ``status``: exact status. ``campaign`` / ``channel``: related ids.
        ``date_from`` / ``date_to``: inclusive bounds on the date part of
        ``scheduled_at``. ``search``: case-insensitive substring match on
        title or body.
        """
        qs = super().get_queryset()
        params = self.request.query_params
        status_param = params.get("status")
        campaign = params.get("campaign")
        channel = params.get("channel")
        date_from = params.get("date_from")
        date_to = params.get("date_to")
        search = params.get("search")
        if status_param:
            qs = qs.filter(status=status_param)
        if campaign:
            qs = qs.filter(campaign_id=campaign)
        if channel:
            qs = qs.filter(channels__id=channel)
        if date_from:
            qs = qs.filter(scheduled_at__date__gte=date_from)
        if date_to:
            qs = qs.filter(scheduled_at__date__lte=date_to)
        if search:
            qs = qs.filter(Q(title__icontains=search) | Q(body__icontains=search))
        # Deterministic ordering is required for stable pagination. "-pk" is
        # the final tie-breaker: without it, posts sharing scheduled_at and
        # created_at (e.g. unscheduled posts with NULL scheduled_at) have no
        # defined relative order, so they can repeat or be skipped across pages.
        return qs.distinct().order_by("-scheduled_at", "-created_at", "-pk")

    def perform_create(self, serializer):
        # Anonymous requests store no author.
        serializer.save(author=self.request.user if self.request.user.is_authenticated else None)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="year", type=int, required=True),
            OpenApiParameter(name="month", type=int, required=True, description="1-12"),
        ]
    )
    @action(detail=False, methods=["get"])
    def calendar(self, request):
        """Posts for a given month — for calendar grid view.

        Returns every post (unpaginated, after the regular query-param
        filters) whose ``scheduled_at`` or ``published_at`` falls in
        ``[first day of month, first day of next month)``. Responds 400 when
        ``year``/``month`` are missing, non-integer, or out of range.
        """
        from django.conf import settings

        try:
            year = int(request.query_params.get("year"))
            month = int(request.query_params.get("month"))
        except (TypeError, ValueError):
            return Response(
                {"detail": "year and month query params required"},
                status=status.HTTP_400_BAD_REQUEST,
            )
        # Validate before building datetimes: datetime() raises ValueError for
        # month 0/13 or an out-of-range year, which would otherwise surface as
        # a 500. The year must also leave room for year + 1 (December case).
        if not (1 <= month <= 12 and datetime.min.year <= year < datetime.max.year):
            return Response(
                {"detail": "month must be 1-12 and year must be between 1 and 9998"},
                status=status.HTTP_400_BAD_REQUEST,
            )
        start = datetime(year, month, 1)
        end = datetime(year + 1, 1, 1) if month == 12 else datetime(year, month + 1, 1)
        if settings.USE_TZ:
            # Interpret month boundaries in the active time zone instead of
            # handing naive datetimes to aware DateTimeFields.
            start = timezone.make_aware(start)
            end = timezone.make_aware(end)
        qs = self.get_queryset().filter(
            Q(scheduled_at__gte=start, scheduled_at__lt=end)
            | Q(published_at__gte=start, published_at__lt=end)
        )
        return Response(self.get_serializer(qs, many=True).data)

    @action(detail=True, methods=["post"])
    def schedule(self, request, pk=None):
        """Set ``scheduled_at``, move the post to ``scheduled``, then ensure its tracking link.

        Body: ``{"scheduled_at": <datetime string>}``. The value is parsed with
        DRF's ``DateTimeField`` before saving, so malformed input gets a 400
        instead of failing at save time. The response also serializes a real
        datetime rather than echoing the raw string.
        """
        from rest_framework import serializers

        post = self.get_object()
        raw_value = request.data.get("scheduled_at")
        if not raw_value:
            return Response({"detail": "scheduled_at required"}, status=status.HTTP_400_BAD_REQUEST)
        try:
            # Honours DATETIME_INPUT_FORMATS and makes the result aware/naive
            # according to settings.USE_TZ.
            scheduled_at = serializers.DateTimeField().to_internal_value(raw_value)
        except serializers.ValidationError:
            return Response(
                {"detail": "scheduled_at must be a valid datetime"},
                status=status.HTTP_400_BAD_REQUEST,
            )
        post.scheduled_at = scheduled_at
        post.status = "scheduled"
        post.save(update_fields=["scheduled_at", "status", "updated_at"])
        ensure_tracking_link(post, request.user)
        return Response(PostDetailSerializer(post, context={"request": request}).data)

    @action(detail=True, methods=["post"])
    def publish_now(self, request, pk=None):
        """Publish immediately via ``Post.mark_published()``, then ensure the tracking link."""
        post = self.get_object()
        post.mark_published()
        ensure_tracking_link(post, request.user)
        return Response(PostDetailSerializer(post, context={"request": request}).data)

    @action(detail=True, methods=["post"])
    def generate_tracking_link(self, request, pk=None):
        """Call ``ensure_tracking_link`` for the post; 400 if it has no ``link_url``."""
        post = self.get_object()
        if not post.link_url:
            return Response(
                {"detail": "Post has no link_url to track."},
                status=status.HTTP_400_BAD_REQUEST,
            )
        ensure_tracking_link(post, request.user)
        return Response(PostDetailSerializer(post, context={"request": request}).data)

    @action(detail=True, methods=["get"])
    def link_stats(self, request, pk=None):
        """Click stats for the post's tracking short code (if any) plus UTM downstream stats."""
        post = self.get_object()
        code = post.tracking_url.short_code if post.tracking_url_id else None
        return Response(_link_stats([code] if code else [], [post.id]))

    @action(detail=True, methods=["post"])
    def archive(self, request, pk=None):
        """Set the post's status to ``archived``."""
        post = self.get_object()
        post.status = "archived"
        post.save(update_fields=["status", "updated_at"])
        return Response(PostDetailSerializer(post, context={"request": request}).data)

    @action(detail=False, methods=["get"])
    def stats(self, request):
        """Summary counts over the posts matching the current query-param filters.

        Grouped counts use ``Count("id", distinct=True)``. The queryset joins
        the channels M2M (for ``channel_count_anno`` and the optional
        ``channel`` filter), so a plain ``Count("id")`` can count one post
        once per joined channel row and inflate the per-group totals.
        """
        qs = self.get_queryset()
        by_status = list(
            qs.values("status")
              .annotate(count=Count("id", distinct=True))
              .order_by("-count")
        )
        by_platform = list(
            qs.values("channels__platform")
              .annotate(count=Count("id", distinct=True))
              .order_by("-count")
        )
        by_campaign = list(
            qs.exclude(campaign__isnull=True)
              .values("campaign__id", "campaign__name", "campaign__color")
              .annotate(count=Count("id", distinct=True))
              .order_by("-count")[:10]
        )
        now = timezone.now()
        upcoming = qs.filter(status="scheduled", scheduled_at__gte=now).count()
        past_7d = qs.filter(status="published", published_at__gte=now - timedelta(days=7)).count()
        return Response({
            "by_status": by_status,
            "by_platform": [
                {"platform": row["channels__platform"] or "unassigned", "count": row["count"]}
                for row in by_platform
            ],
            "by_campaign": [
                {"id": row["campaign__id"], "name": row["campaign__name"], "color": row["campaign__color"], "count": row["count"]}
                for row in by_campaign
            ],
            "upcoming_scheduled": upcoming,
            "published_last_7d": past_7d,
            "total": qs.count(),
        })
