from __future__ import annotations

import sys
import textwrap
from difflib import unified_diff
from typing import TYPE_CHECKING

from docutils import nodes
from docutils.parsers.rst import directives

from sphinx import addnodes
from sphinx.directives import optional_int
from sphinx.locale import __
from sphinx.util import logging
from sphinx.util._lines import parse_line_num_spec
from sphinx.util._pathlib import _StrPath
from sphinx.util.docutils import SphinxDirective

if TYPE_CHECKING:
    import os
    from typing import Any, ClassVar

    from docutils.nodes import Element, Node

    from sphinx.application import Sphinx
    from sphinx.config import Config
    from sphinx.util.typing import ExtensionMetadata, OptionSpec

logger = logging.getLogger(__name__)


class Highlight(SphinxDirective):
    """Set the highlighting language (and optional line-number threshold)
    used by the code blocks that follow in the current document.

    The language is also stored on ``env.current_document`` so that
    ``code-block`` directives without an explicit language can pick it up.
    """

    has_content = False
    required_arguments = 1
    optional_arguments = 0
    final_argument_whitespace = False
    option_spec: ClassVar[OptionSpec] = {
        'force': directives.flag,
        'linenothreshold': directives.positive_int,
    }

    def run(self) -> list[Node]:
        opts = self.options
        lang = self.arguments[0].strip()

        # Remembered so that argument-less code blocks fall back to it.
        self.env.current_document.highlight_language = lang

        marker = addnodes.highlightlang(
            lang=lang,
            force='force' in opts,
            # sys.maxsize when no threshold is given -- presumably meaning
            # "no automatic line numbers"; confirm in the consuming transform.
            linenothreshold=opts.get('linenothreshold', sys.maxsize),
        )
        return [marker]


def dedent_lines(
    lines: list[str], dedent: int | None, location: tuple[str, int] | None = None
) -> list[str]:
    """Remove leading indentation from *lines*.

    :param lines: lines to dedent, each possibly ending in ``'\\n'``.
    :param dedent: number of leading characters to cut from every line, or
        ``None`` to strip the whitespace common to all lines
        (``textwrap.dedent``).
    :param location: source location used for the warning below.
    :returns: a new list of dedented lines.

    When a fixed *dedent* would cut non-whitespace characters, a warning is
    logged, but the lines are still cut.
    """
    if dedent is None:
        # Let textwrap figure out the common leading whitespace.
        return textwrap.dedent(''.join(lines)).splitlines(True)

    cuts_text = any(line[:dedent].strip() for line in lines)
    if cuts_text:
        logger.warning(__('non-whitespace stripped by dedent'), location=location)

    def _cut(line: str) -> str:
        remainder = line[dedent:]
        # A line swallowed entirely by the cut keeps a line break ('\n');
        # note that any '\r' before it is not restored.
        return '\n' if (not remainder and line.endswith('\n')) else remainder

    return [_cut(line) for line in lines]


def container_wrapper(
    directive: SphinxDirective, literal_node: Node, caption: str
) -> nodes.container:
    """Wrap *literal_node* in a ``literal-block-wrapper`` container that is
    headed by a caption parsed from *caption* as reST.

    The caption node inherits the source and line of *literal_node*.

    :raises ValueError: if the caption parses to a system message.
    """
    parsed = directive.parse_text_to_nodes(caption, offset=directive.content_offset)
    head = parsed[0]

    if isinstance(head, nodes.system_message):
        msg = __('Invalid caption: %s') % head.astext()
        raise ValueError(msg)  # NoQA: TRY004
    if not isinstance(head, nodes.Element):
        raise RuntimeError  # never reached

    # Re-home the parsed inline content under a caption node.
    caption_node = nodes.caption(head.rawsource, '', *head.children)
    caption_node.source = literal_node.source
    caption_node.line = literal_node.line

    wrapper = nodes.container(
        '', literal_block=True, classes=['literal-block-wrapper']
    )
    wrapper += caption_node
    wrapper += literal_node
    return wrapper


class CodeBlock(SphinxDirective):
    """Directive for a code block with special highlighting or line numbering
    settings.

    The optional argument is the highlight language.  Without it, the
    language set by the ``highlight`` directive (or the
    ``highlight_language`` config value) is used.
    """

    has_content = True
    required_arguments = 0
    optional_arguments = 1
    final_argument_whitespace = False
    option_spec: ClassVar[OptionSpec] = {
        'force': directives.flag,
        'linenos': directives.flag,
        'dedent': optional_int,
        'lineno-start': int,
        'emphasize-lines': directives.unchanged_required,
        'caption': directives.unchanged_required,
        'class': directives.class_option,
        'name': directives.unchanged,
    }

    def run(self) -> list[Node]:
        document = self.state.document
        code = '\n'.join(self.content)
        location = self.state_machine.get_source_and_line(self.lineno)

        linespec = self.options.get('emphasize-lines')
        if linespec:
            try:
                nlines = len(self.content)
                # 0-based indices into the content lines.
                hl_lines = parse_line_num_spec(linespec, nlines)
                # Check the lower bound too: a negative index (e.g. from a
                # "0" entry) is just as out of range as one past the end.
                if any(not 0 <= i < nlines for i in hl_lines):
                    logger.warning(
                        __('line number spec is out of range(1-%d): %r'),
                        nlines,
                        self.options['emphasize-lines'],
                        location=location,
                    )

                # highlight_args carry 1-based line numbers.
                hl_lines = [x + 1 for x in hl_lines if 0 <= x < nlines]
            except ValueError as err:
                return [document.reporter.warning(err, line=self.lineno)]
        else:
            hl_lines = None

        if 'dedent' in self.options:
            lines = code.splitlines(True)
            lines = dedent_lines(lines, self.options['dedent'], location=location)
            code = ''.join(lines)

        literal: Element = nodes.literal_block(code, code)
        if 'linenos' in self.options or 'lineno-start' in self.options:
            literal['linenos'] = True
        literal['classes'] += self.options.get('class', [])
        literal['force'] = 'force' in self.options
        if self.arguments:
            # highlight language specified
            literal['language'] = self.arguments[0]
        else:
            # no highlight language specified.  Then this directive refers the current
            # highlight setting via ``highlight`` directive or ``highlight_language``
            # configuration.
            literal['language'] = (
                self.env.current_document.highlight_language
                or self.config.highlight_language
            )
        extra_args = literal['highlight_args'] = {}
        if hl_lines is not None:
            extra_args['hl_lines'] = hl_lines
        if 'lineno-start' in self.options:
            extra_args['linenostart'] = self.options['lineno-start']
        self.set_source_info(literal)

        caption = self.options.get('caption')
        if caption:
            try:
                literal = container_wrapper(self, literal, caption)
            except ValueError as exc:
                return [document.reporter.warning(exc, line=self.lineno)]

        # literal will be note_implicit_target that is linked from caption and numref.
        # when options['name'] is provided, it should be primary ID.
        self.add_name(literal)

        return [literal]


class LiteralIncludeReader:
    """Read a file for ``literalinclude`` and apply its selection options.

    Mutually exclusive options are rejected on construction.  :meth:`read`
    either returns a unified diff against the ``diff`` file, or runs the
    file's lines through a fixed chain of filters.  With ``lineno-match``,
    the filters advance :attr:`lineno_start` so that displayed line numbers
    match positions in the original file.
    """

    # Option pairs that may not be combined (checked by parse_options).
    INVALID_OPTIONS_PAIR = [
        ('lineno-match', 'lineno-start'),
        ('lineno-match', 'append'),
        ('lineno-match', 'prepend'),
        ('start-after', 'start-at'),
        ('end-before', 'end-at'),
        ('diff', 'pyobject'),
        ('diff', 'lineno-start'),
        ('diff', 'lineno-match'),
        ('diff', 'lines'),
        ('diff', 'start-after'),
        ('diff', 'end-before'),
        ('diff', 'start-at'),
        ('diff', 'end-at'),
    ]

    def __init__(
        self, filename: str | os.PathLike[str], options: dict[str, Any], config: Config
    ) -> None:
        self.filename = _StrPath(filename)
        self.options = options
        # An explicit :encoding: option overrides the project's source_encoding.
        self.encoding = options.get('encoding', config.source_encoding)
        self.lineno_start = self.options.get('lineno-start', 1)

        self.parse_options()

    def parse_options(self) -> None:
        """Raise ValueError if both options of an invalid pair are given."""
        for option1, option2 in self.INVALID_OPTIONS_PAIR:
            if option1 in self.options and option2 in self.options:
                msg = __('Cannot use both "%s" and "%s" options') % (option1, option2)
                raise ValueError(msg)

    def read_file(
        self, filename: str | os.PathLike[str], location: tuple[str, int] | None = None
    ) -> list[str]:
        """Return the lines of *filename* (line endings kept).

        Tabs are expanded when ``tab-width`` is set.

        :raises OSError: if the file cannot be opened or read.
        :raises UnicodeError: if it cannot be decoded with ``self.encoding``.
        """
        filename = _StrPath(filename)
        try:
            with open(filename, encoding=self.encoding, errors='strict') as f:
                text = f.read()
            if 'tab-width' in self.options:
                text = text.expandtabs(self.options['tab-width'])

            return text.splitlines(True)
        except OSError as exc:
            msg = __("Include file '%s' not found or reading it failed") % filename
            raise OSError(msg) from exc
        except UnicodeError as exc:
            msg = __(
                "Encoding %r used for reading included file '%s' seems to "
                'be wrong, try giving an :encoding: option'
            ) % (self.encoding, filename)
            raise UnicodeError(msg) from exc

    def read(self, location: tuple[str, int] | None = None) -> tuple[str, int]:
        """Return the selected text and the number of lines it contains."""
        if 'diff' in self.options:
            lines = self.show_diff(location=location)
        else:
            # Order matters: selections first, then dedent, then additions.
            filters = [
                self.pyobject_filter,
                self.start_filter,
                self.end_filter,
                self.lines_filter,
                self.dedent_filter,
                self.prepend_filter,
                self.append_filter,
            ]
            lines = self.read_file(self.filename, location=location)
            for func in filters:
                lines = func(lines, location=location)

        return ''.join(lines), len(lines)

    def show_diff(self, location: tuple[str, int] | None = None) -> list[str]:
        """Return a unified diff from the ``diff`` file to ``self.filename``."""
        new_lines = self.read_file(self.filename, location=location)
        old_filename = self.options['diff']
        old_lines = self.read_file(old_filename, location=location)
        diff = unified_diff(old_lines, new_lines, str(old_filename), str(self.filename))
        return list(diff)

    def pyobject_filter(
        self, lines: list[str], location: tuple[str, int] | None = None
    ) -> list[str]:
        """Keep only the lines of the Python object named by ``pyobject``.

        :raises ValueError: if the object is not found in the file.
        """
        pyobject = self.options.get('pyobject')
        if pyobject:
            from sphinx.pycode import ModuleAnalyzer

            analyzer = ModuleAnalyzer.for_file(self.filename, '')
            tags = analyzer.find_tags()
            if pyobject not in tags:
                msg = __('Object named %r not found in include file %r') % (
                    pyobject,
                    self.filename,
                )
                raise ValueError(msg)
            # The tag's start/end are used as 1-based, inclusive line numbers.
            start = tags[pyobject][1]
            end = tags[pyobject][2]
            lines = lines[start - 1 : end]
            if 'lineno-match' in self.options:
                self.lineno_start = start

        return lines

    def lines_filter(
        self, lines: list[str], location: tuple[str, int] | None = None
    ) -> list[str]:
        """Keep only the lines selected by the ``lines`` spec.

        Out-of-range entries produce a warning and are dropped.

        :raises ValueError: if ``lineno-match`` is combined with a disjoint
            selection, or if nothing at all is selected.
        """
        linespec = self.options.get('lines')
        if linespec:
            nlines = len(lines)
            # 0-based indices into ``lines``.
            linelist = parse_line_num_spec(linespec, nlines)
            # Check the lower bound too: a negative index (e.g. from a "0"
            # entry) would otherwise wrap around to the end of the file.
            if any(not 0 <= i < nlines for i in linelist):
                logger.warning(
                    __('line number spec is out of range(1-%d): %r'),
                    nlines,
                    linespec,
                    location=location,
                )

            if 'lineno-match' in self.options:
                # make sure the line list is not "disjoint".
                first = linelist[0]
                if all(first + i == n for i, n in enumerate(linelist)):
                    # Negative entries are dropped below, so they must not
                    # move the start backwards.
                    self.lineno_start += max(first, 0)
                else:
                    msg = __('Cannot use "lineno-match" with a disjoint set of "lines"')
                    raise ValueError(msg)

            lines = [lines[n] for n in linelist if 0 <= n < nlines]
            if not lines:
                msg = __('Line spec %r: no lines pulled from include file %r') % (
                    linespec,
                    self.filename,
                )
                raise ValueError(msg)

        return lines

    def start_filter(
        self, lines: list[str], location: tuple[str, int] | None = None
    ) -> list[str]:
        """Drop lines before the first match of ``start-at``/``start-after``.

        ``start-at`` keeps the matching line; ``start-after`` drops it too.

        :raises ValueError: if the pattern does not occur.
        """
        if 'start-at' in self.options:
            start = self.options.get('start-at')
            skip_match = False
        elif 'start-after' in self.options:
            start = self.options.get('start-after')
            skip_match = True
        else:
            start = None

        if start:
            for lineno, line in enumerate(lines):
                if start in line:
                    first_kept = lineno + 1 if skip_match else lineno
                    if 'lineno-match' in self.options:
                        self.lineno_start += first_kept

                    return lines[first_kept:]

            if skip_match:
                raise ValueError('start-after pattern not found: %s' % start)
            else:
                raise ValueError('start-at pattern not found: %s' % start)

        return lines

    def end_filter(
        self, lines: list[str], location: tuple[str, int] | None = None
    ) -> list[str]:
        """Drop lines after the first match of ``end-at``/``end-before``.

        ``end-at`` keeps the matching line.  ``end-before`` drops it, and it
        ignores a match on the very first line.

        :raises ValueError: if the pattern does not occur.
        """
        if 'end-at' in self.options:
            end = self.options.get('end-at')
            keep_match = True
        elif 'end-before' in self.options:
            end = self.options.get('end-before')
            keep_match = False
        else:
            end = None

        if end:
            for lineno, line in enumerate(lines):
                if end in line:
                    if keep_match:
                        return lines[: lineno + 1]
                    if lineno > 0:  # end-before ignores first line
                        return lines[:lineno]
            if keep_match:
                raise ValueError('end-at pattern not found: %s' % end)
            else:
                raise ValueError('end-before pattern not found: %s' % end)

        return lines

    def prepend_filter(
        self, lines: list[str], location: tuple[str, int] | None = None
    ) -> list[str]:
        """Return a new list with the ``prepend`` text added as the first line."""
        prepend = self.options.get('prepend')
        if prepend:
            # Build a new list instead of mutating the caller's.
            return [prepend + '\n', *lines]

        return lines

    def append_filter(
        self, lines: list[str], location: tuple[str, int] | None = None
    ) -> list[str]:
        """Return a new list with the ``append`` text added as the last line."""
        append = self.options.get('append')
        if append:
            # Build a new list instead of mutating the caller's.
            return [*lines, append + '\n']

        return lines

    def dedent_filter(
        self, lines: list[str], location: tuple[str, int] | None = None
    ) -> list[str]:
        """Apply ``dedent`` (a fixed width, or automatic when None)."""
        if 'dedent' in self.options:
            return dedent_lines(lines, self.options.get('dedent'), location=location)
        else:
            return lines


class LiteralInclude(SphinxDirective):
    """Like ``.. include:: :literal:``, but only warns if the include file is
    not found, and does not raise errors.  Also has several options for
    selecting what to include.
    """

    has_content = False
    required_arguments = 1
    optional_arguments = 0
    final_argument_whitespace = True
    option_spec: ClassVar[OptionSpec] = {
        'dedent': optional_int,
        'linenos': directives.flag,
        'lineno-start': int,
        'lineno-match': directives.flag,
        'tab-width': int,
        'language': directives.unchanged_required,
        'force': directives.flag,
        'encoding': directives.encoding,
        'pyobject': directives.unchanged_required,
        'lines': directives.unchanged_required,
        'start-after': directives.unchanged_required,
        'end-before': directives.unchanged_required,
        'start-at': directives.unchanged_required,
        'end-at': directives.unchanged_required,
        'prepend': directives.unchanged_required,
        'append': directives.unchanged_required,
        'emphasize-lines': directives.unchanged_required,
        'caption': directives.unchanged,
        'class': directives.class_option,
        'name': directives.unchanged,
        'diff': directives.unchanged_required,
    }

    def run(self) -> list[Node]:
        document = self.state.document
        if not document.settings.file_insertion_enabled:
            return [
                document.reporter.warning('File insertion disabled', line=self.lineno)
            ]
        # convert options['diff'] to absolute path
        if 'diff' in self.options:
            _, path = self.env.relfn2path(self.options['diff'])
            self.options['diff'] = path

        try:
            location = self.state_machine.get_source_and_line(self.lineno)
            rel_filename, filename = self.env.relfn2path(self.arguments[0])
            self.env.note_dependency(rel_filename)

            reader = LiteralIncludeReader(filename, self.options, self.config)
            # ``lines`` is the number of lines in ``text``.
            text, lines = reader.read(location=location)

            retnode: Element = nodes.literal_block(text, text, source=filename)
            retnode['force'] = 'force' in self.options
            self.set_source_info(retnode)
            if self.options.get('diff'):  # if diff is set, set udiff
                retnode['language'] = 'udiff'
            elif 'language' in self.options:
                retnode['language'] = self.options['language']
            if (
                'linenos' in self.options
                or 'lineno-start' in self.options
                or 'lineno-match' in self.options
            ):
                retnode['linenos'] = True
            retnode['classes'] += self.options.get('class', [])
            extra_args = retnode['highlight_args'] = {}
            if 'emphasize-lines' in self.options:
                # 0-based indices into the included text.
                hl_lines = parse_line_num_spec(self.options['emphasize-lines'], lines)
                # Check the lower bound too: a negative index (e.g. from a
                # "0" entry) is as out of range as one past the end.
                if any(not 0 <= i < lines for i in hl_lines):
                    logger.warning(
                        __('line number spec is out of range(1-%d): %r'),
                        lines,
                        self.options['emphasize-lines'],
                        location=location,
                    )
                # highlight_args carry 1-based line numbers.
                extra_args['hl_lines'] = [x + 1 for x in hl_lines if 0 <= x < lines]
            extra_args['linenostart'] = reader.lineno_start

            if 'caption' in self.options:
                # An empty :caption: falls back to the included file's name.
                caption = self.options['caption'] or self.arguments[0]
                retnode = container_wrapper(self, retnode, caption)

            # retnode will be note_implicit_target that is linked from caption and numref.
            # when options['name'] is provided, it should be primary ID.
            self.add_name(retnode)

            return [retnode]
        except Exception as exc:
            # Deliberately broad: this directive reports problems as
            # warnings instead of failing the build.
            return [document.reporter.warning(exc, line=self.lineno)]


def setup(app: Sphinx) -> ExtensionMetadata:
    """Register the code-related directives with docutils.

    ``sourcecode`` is registered as an alias of ``code-block``.
    """
    registrations = (
        ('highlight', Highlight),
        ('code-block', CodeBlock),
        ('sourcecode', CodeBlock),
        ('literalinclude', LiteralInclude),
    )
    for directive_name, directive_cls in registrations:
        directives.register_directive(directive_name, directive_cls)

    metadata = {
        'version': 'builtin',
        'parallel_read_safe': True,
        'parallel_write_safe': True,
    }
    return metadata
