"""The citation domain."""

from __future__ import annotations

from typing import TYPE_CHECKING, cast

from docutils import nodes

from sphinx.addnodes import pending_xref
from sphinx.domains import Domain
from sphinx.locale import __
from sphinx.transforms import SphinxTransform
from sphinx.util import logging
from sphinx.util.nodes import copy_source_info, make_refnode

if TYPE_CHECKING:
    from collections.abc import Set
    from typing import Any

    from docutils.nodes import Element

    from sphinx.application import Sphinx
    from sphinx.builders import Builder
    from sphinx.environment import BuildEnvironment
    from sphinx.util.typing import ExtensionMetadata


# Logger from sphinx.util.logging; used for this module's citation warnings.
logger = logging.getLogger(__name__)


class CitationDomain(Domain):
    """Domain for citations.

    Keeps two tables in the domain data:

    * ``citations``: where each citation label is defined.
    * ``citation_refs``: which documents reference each label.

    They are used to resolve citation cross-references and to report
    citations that are never referenced.
    """

    name = 'citation'
    label = 'citation'

    # Warning text used when a 'ref' cross-reference to this domain is dangling.
    dangling_warnings = {
        'ref': 'citation not found: %(target)s',
    }

    @property
    def citations(self) -> dict[str, tuple[str, str, int]]:
        """Map of citation label -> (docname, node id, line number)."""
        return self.data.setdefault('citations', {})

    @property
    def citation_refs(self) -> dict[str, set[str]]:
        """Map of citation label -> set of docnames referencing that label."""
        return self.data.setdefault('citation_refs', {})

    def clear_doc(self, docname: str) -> None:
        """Forget citations defined in, and references made from, *docname*."""
        for key, (fn, _l, _lineno) in list(self.citations.items()):
            if fn == docname:
                del self.citations[key]
        for key, docnames in list(self.citation_refs.items()):
            # Remove the entry entirely once nothing references the label, so
            # check_consistency() sees the label as unreferenced.
            if docnames == {docname}:
                del self.citation_refs[key]
            elif docname in docnames:
                docnames.remove(docname)

    def merge_domaindata(self, docnames: Set[str], otherdata: dict[str, Any]) -> None:
        """Merge citation data from *otherdata* restricted to *docnames*.

        Only citations defined in, and references made from, documents in
        *docnames* are copied. A reference entry is created only when at
        least one referencing document is in *docnames*. This avoids leaving
        an empty set behind, which would make check_consistency() treat the
        label as referenced.
        """
        # NOTE(review): a label already present here is silently overwritten
        # (last merge wins); note_citation()'s duplicate warning does not run
        # on this path.
        for key, data in otherdata['citations'].items():
            if data[0] in docnames:
                self.citations[key] = data
        for key, data in otherdata['citation_refs'].items():
            referrers = {docname for docname in data if docname in docnames}
            if referrers:
                self.citation_refs.setdefault(key, set()).update(referrers)

    def note_citation(self, node: nodes.citation) -> None:
        """Record the citation *node*, warning if its label is already known.

        The label is the text of the node's first child. The stored location
        is the node's ``docname`` attribute, its first id and its line.
        """
        label = node[0].astext()
        if label in self.citations:
            path = self.env.doc2path(self.citations[label][0])
            logger.warning(
                __('duplicate citation %s, other instance in %s'),
                label,
                path,
                location=node,
                type='ref',
                subtype='citation',
            )
        # node.line may be None, hence the type: ignore.
        self.citations[label] = (node['docname'], node['ids'][0], node.line)  # type: ignore[assignment]

    def note_citation_reference(self, node: pending_xref) -> None:
        """Record that the current document references ``node['reftarget']``."""
        docnames = self.citation_refs.setdefault(node['reftarget'], set())
        docnames.add(self.env.docname)

    def check_consistency(self) -> None:
        """Warn about every citation that has no entry in ``citation_refs``."""
        for name, (docname, _labelid, lineno) in self.citations.items():
            if name not in self.citation_refs:
                logger.warning(
                    __('Citation [%s] is not referenced.'),
                    name,
                    type='ref',
                    subtype='citation',
                    location=(docname, lineno),
                )

    def resolve_xref(
        self,
        env: BuildEnvironment,
        fromdocname: str,
        builder: Builder,
        typ: str,
        target: str,
        node: pending_xref,
        contnode: Element,
    ) -> nodes.reference | None:
        """Return a reference node to citation *target*, or None if unknown."""
        docname, labelid, _lineno = self.citations.get(target, ('', '', 0))
        if not docname:
            return None

        return make_refnode(builder, fromdocname, docname, labelid, contnode)

    def resolve_any_xref(
        self,
        env: BuildEnvironment,
        fromdocname: str,
        builder: Builder,
        target: str,
        node: pending_xref,
        contnode: Element,
    ) -> list[tuple[str, nodes.reference]]:
        """Resolve *target* as a 'ref' citation.

        Returns a one-element list of ('ref', node) on success, otherwise an
        empty list.
        """
        refnode = self.resolve_xref(
            env, fromdocname, builder, 'ref', target, node, contnode
        )
        if refnode is None:
            return []
        else:
            return [('ref', refnode)]


class CitationDefinitionTransform(SphinxTransform):
    """Register citation definitions and protect their labels.

    For every citation node in the document:

    * stamp it with the current docname;
    * hand it to the citation domain;
    * flag its label so smart quotes are not applied to it.
    """

    default_priority = 619

    def apply(self, **kwargs: Any) -> None:
        citation_domain = self.env.domains.citation_domain
        for citation in self.document.findall(nodes.citation):
            # The domain reads 'docname' off the node when recording it.
            citation['docname'] = self.env.docname
            citation_domain.note_citation(citation)

            # The first child is the citation's label; keep it verbatim.
            label_node = cast('nodes.label', citation[0])
            label_node['support_smartquotes'] = False


class CitationReferenceTransform(SphinxTransform):
    """Convert citation references into pending_xref nodes.

    This runs before the default docutils transform would try to resolve
    them, so the 'citation' domain resolves them instead.
    """

    default_priority = 619

    def apply(self, **kwargs: Any) -> None:
        citation_domain = self.env.domains.citation_domain
        for old_ref in self.document.findall(nodes.citation_reference):
            label_text = old_ref.astext()
            xref = pending_xref(
                label_text,
                refdomain='citation',
                reftype='ref',
                reftarget=label_text,
                refwarn=True,
                support_smartquotes=False,
                ids=old_ref['ids'],
                classes=old_ref.get('classes', []),
            )
            # Displayed text is the label wrapped in square brackets.
            xref += nodes.inline(label_text, '[%s]' % label_text)
            copy_source_info(old_ref, xref)
            old_ref.replace_self(xref)

            # Let the domain know the current document cites this label.
            citation_domain.note_citation_reference(xref)


def setup(app: Sphinx) -> ExtensionMetadata:
    """Register the citation domain and its two transforms with *app*.

    Returns extension metadata declaring the extension built-in and safe for
    both parallel reading and parallel writing.
    """
    app.add_domain(CitationDomain)
    for transform_cls in (CitationDefinitionTransform, CitationReferenceTransform):
        app.add_transform(transform_cls)

    metadata: ExtensionMetadata = {
        'version': 'builtin',
        'env_version': 1,
        'parallel_read_safe': True,
        'parallel_write_safe': True,
    }
    return metadata
