"""Check Python modules and C API for coverage.

Mostly written by Josip Dzolonga for the Google Highly Open Participation
contest.
"""

from __future__ import annotations

import glob
import inspect
import os.path
import pickle
import pkgutil
import re
import sys
from importlib import import_module
from typing import TYPE_CHECKING

import sphinx
from sphinx._cli.util.colour import red
from sphinx.builders import Builder
from sphinx.locale import __
from sphinx.util import logging
from sphinx.util.inspect import safe_getattr

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Sequence, Set
    from typing import IO, Any, TextIO

    from sphinx.application import Sphinx
    from sphinx.util.typing import ExtensionMetadata

logger = logging.getLogger(__name__)


# utility
def write_header(f: IO[str], text: str, char: str = '-') -> None:
    """Write ``text`` to ``f`` as a heading underlined with ``char``.

    The underline has the same length as ``text`` and is followed by one
    blank line.
    """
    underline = char * len(text)
    f.write(f'{text}\n')
    f.write(f'{underline}\n\n')


def compile_regex_list(name: str, exps: Iterable[str]) -> list[re.Pattern[str]]:
    """Compile every expression in ``exps``, skipping the ones that fail.

    :param name: Label used in the warning message to say where an invalid
        expression came from (callers pass a configuration value name).
    :param exps: The regular expressions to compile.
    :returns: The compiled patterns in input order; expressions that could
        not be compiled are left out after a warning is logged.
    """
    compiled: list[re.Pattern[str]] = []
    for exp in exps:
        try:
            pattern = re.compile(exp)
        except Exception:
            # Deliberately broad: besides ``re.error`` a non-string entry
            # raises ``TypeError``, and either way the entry is reported
            # and skipped rather than aborting the build.
            logger.warning(__('invalid regex %r in %s'), exp, name)
        else:
            compiled.append(pattern)
    return compiled


def _write_table(table: list[list[str]]) -> Iterator[str]:
    """Yield the lines of ``table`` rendered as a bordered text table.

    The first row is treated as the header: it is preceded by a ``-`` border
    and closed with a ``=`` border, every following row is closed with a
    ``-`` border.  Each column is one character wider than its longest cell.
    """
    header = table[0]
    widths = []
    for index in range(len(header)):
        longest = max(len(row[index]) for row in table)
        widths.append(longest + 1)

    yield _add_line(widths, '-')
    yield from _add_row(widths, header, '=')

    for body_row in table[1:]:
        yield from _add_row(widths, body_row, '-')


def _add_line(sizes: list[int], separator: str) -> str:
    return '+' + ''.join((separator * (size + 1)) + '+' for size in sizes)


def _add_row(
    col_widths: list[int], columns: list[str], separator: str
) -> Iterator[str]:
    row = ''.join(f'| {column: <{col_widths[i]}}' for i, column in enumerate(columns))
    yield f'{row}|'
    yield _add_line(col_widths, separator)


def _load_modules(
    mod_name: str, ignored_module_exps: Iterable[re.Pattern[str]]
) -> Set[str]:
    """Recursively load all submodules.

    :param mod_name: The name of a module to load submodules for.
    :param ignored_module_exps: A list of regexes for modules to ignore.
    :returns: A set of modules names including the provided module name,
        ``mod_name``
    :raises ImportError: If the module indicated by ``mod_name`` could not be
        loaded.
    """
    if any(exp.match(mod_name) for exp in ignored_module_exps):
        return set()

    # This can raise an exception, which must be handled by the caller.
    mod = import_module(mod_name)
    modules = {mod_name}
    if mod.__spec__ is None:
        return modules

    search_locations = mod.__spec__.submodule_search_locations
    for _, sub_mod_name, sub_mod_ispkg in pkgutil.iter_modules(search_locations):
        if sub_mod_name == '__main__':
            continue

        if sub_mod_ispkg:
            modules |= _load_modules(f'{mod_name}.{sub_mod_name}', ignored_module_exps)
        else:
            if any(exp.match(sub_mod_name) for exp in ignored_module_exps):
                continue
            modules.add(f'{mod_name}.{sub_mod_name}')

    return modules


def _determine_py_coverage_modules(
    coverage_modules: Sequence[str],
    seen_modules: Set[str],
    ignored_module_exps: Iterable[re.Pattern[str]],
    py_undoc: dict[str, dict[str, Any]],
) -> list[str]:
    """Return a sorted list of modules to check for coverage.

    Figure out which of the two operating modes to use:

    - If 'coverage_modules' is not specified, we check coverage for all modules
      seen in the documentation tree. Any objects found in these modules that are
      not documented will be noted. This will therefore only identify missing
      objects, but it requires no additional configuration.

    - If 'coverage_modules' is specified, we check coverage for all modules
      specified in this configuration value. Any objects found in these modules
      that are not documented will be noted. In addition, any objects from other
      modules that are documented will be noted. This will therefore identify both
      missing modules and missing objects, but it requires manual configuration.
    """
    if not coverage_modules:
        return sorted(seen_modules)

    modules: set[str] = set()
    for mod_name in coverage_modules:
        try:
            modules |= _load_modules(mod_name, ignored_module_exps)
        except ImportError as err:
            # TODO(stephenfin): Define a subtype for all logs in this module
            logger.warning(__('module %s could not be imported: %s'), mod_name, err)
            py_undoc[mod_name] = {'error': err}
            continue

    # if there are additional modules then we warn but continue scanning
    if additional_modules := seen_modules - modules:
        logger.warning(
            __(
                'the following modules are documented but were not specified '
                'in coverage_modules: %s'
            ),
            ', '.join(additional_modules),
        )

    # likewise, if there are missing modules we warn but continue scanning
    if missing_modules := modules - seen_modules:
        logger.warning(
            __(
                'the following modules are specified in coverage_modules '
                'but were not documented'
            ),
            ', '.join(missing_modules),
        )

    return sorted(modules)


class CoverageBuilder(Builder):
    """Evaluates coverage of code in the documentation.

    The results are written to ``python.txt`` and ``c.txt`` in the output
    directory, and the collected data is additionally pickled to
    ``undoc.pickle``.
    """

    name = 'coverage'
    epilog = __(
        'Testing of coverage in the sources finished, look at the '
        'results in %(outdir)s{sep}python.txt.'
    ).format(sep=os.path.sep)

    def init(self) -> None:
        """Collect the C source files and compile the configured regexes."""
        # C source files matching the ``coverage_c_path`` glob patterns,
        # which are resolved relative to the source directory.
        self.c_sourcefiles: list[str] = []
        for pattern in self.config.coverage_c_path:
            full_pattern = self.srcdir / pattern
            self.c_sourcefiles.extend(glob.glob(str(full_pattern)))  # NoQA: PTH207

        # (key, pattern) pairs from ``coverage_c_regexes``; invalid patterns
        # are reported and dropped.
        self.c_regexes: list[tuple[str, re.Pattern[str]]] = []
        for name, exp in self.config.coverage_c_regexes.items():
            try:
                self.c_regexes.append((name, re.compile(exp)))
            except Exception:
                logger.warning(__('invalid regex %r in coverage_c_regexes'), exp)

        # Ignore patterns for C items, keyed like ``coverage_c_regexes``.
        self.c_ignorexps: dict[str, list[re.Pattern[str]]] = {
            name: compile_regex_list('coverage_ignore_c_items', exps)
            for name, exps in self.config.coverage_ignore_c_items.items()
        }
        self.mod_ignorexps = compile_regex_list(
            'coverage_ignore_modules', self.config.coverage_ignore_modules
        )
        self.cls_ignorexps = compile_regex_list(
            'coverage_ignore_classes', self.config.coverage_ignore_classes
        )
        self.fun_ignorexps = compile_regex_list(
            'coverage_ignore_functions', self.config.coverage_ignore_functions
        )
        self.py_ignorexps = compile_regex_list(
            'coverage_ignore_pyobjects', self.config.coverage_ignore_pyobjects
        )

    def get_outdated_docs(self) -> str:
        """Return a fixed description of what this builder (re)builds."""
        return 'coverage overview'

    def write_documents(self, _docnames: Set[str]) -> None:
        """Build and write the Python report, then the C report.

        ``_docnames`` is not used.
        """
        # module name -> {'funcs': [...], 'classes': {...}} or {'error': exc}
        self.py_undoc: dict[str, dict[str, Any]] = {}
        # module name -> dotted names of undocumented / documented objects
        self.py_undocumented: dict[str, Set[str]] = {}
        self.py_documented: dict[str, Set[str]] = {}
        self.build_py_coverage()
        self.write_py_coverage()

        # C source file name -> set of (key, name) pairs found undocumented
        self.c_undoc: dict[str, Set[tuple[str, str]]] = {}
        self.build_c_coverage()
        self.write_c_coverage()

    def build_c_coverage(self) -> None:
        """Scan the C source files and record undocumented items in ``c_undoc``.

        Every line of every file is matched against each ``coverage_c_regexes``
        pattern; the first group of a match is taken as the item name.  An
        item is recorded unless its name is documented for that key or it
        matches one of the key's ``coverage_ignore_c_items`` patterns.
        """
        # Group the documented names by object type.  Items [1] and [2] of a
        # ``get_objects()`` entry are used as the name and the type
        # (presumably the display name and object type -- see the
        # ``Domain.get_objects`` documentation).  A set per type is needed:
        # keeping a single string per type would lose all but one name and
        # turn the membership test below into a substring test.
        documented: dict[str, set[str]] = {}
        for obj in self.env.domains.c_domain.get_objects():
            documented.setdefault(obj[2], set()).add(obj[1])

        for filename in self.c_sourcefiles:
            undoc: set[tuple[str, str]] = set()
            with open(filename, encoding='utf-8') as f:
                for line in f:
                    for key, regex in self.c_regexes:
                        match = regex.match(line)
                        if not match:
                            continue
                        name = match.groups()[0]
                        if name in documented.get(key, ()):
                            continue
                        # The ignore list applies whether or not anything
                        # of this type has been documented.
                        ignore_exps = self.c_ignorexps.get(key, ())
                        if any(exp.match(name) for exp in ignore_exps):
                            continue
                        undoc.add((key, name))
            if undoc:
                self.c_undoc[filename] = undoc

    def write_c_coverage(self) -> None:
        """Write ``c.txt`` and, if configured, log every undocumented item."""
        output_file = self.outdir / 'c.txt'
        with open(output_file, 'w', encoding='utf-8') as op:
            if self.config.coverage_write_headline:
                write_header(op, 'Undocumented C API elements', '=')
            op.write('\n')

            for filename, undoc in self.c_undoc.items():
                write_header(op, filename)
                for typ, name in sorted(undoc):
                    op.write(f' * {name:<50} [{typ:>9}]\n')
                    if self.config.coverage_show_missing_items:
                        self._log_missing_item(
                            __('undocumented c api: %s [%s] in file %s'),
                            (name, typ, filename),
                            f'c   api       {f"{name} [{typ:>9}]":<30}',
                            ' - in file ',
                            filename,
                        )
                op.write('\n')

    def _log_missing_item(
        self,
        warning: str,
        warning_args: tuple[str, ...],
        description: str,
        location_label: str,
        location: str,
    ) -> None:
        """Log one undocumented item.

        When the application is quiet, ``warning`` is logged as a warning
        with ``warning_args``; otherwise an info line is logged that is built
        from ``description`` and ``location`` with the fixed parts in red.
        """
        if self.app.quiet:
            logger.warning(warning, *warning_args)
            return

        message = (
            red('undocumented  ') + description + red(location_label) + location
        )
        logger.info(message)

    def ignore_pyobj(self, full_name: str) -> bool:
        """Return whether a ``coverage_ignore_pyobjects`` regex is found in ``full_name``."""
        return any(exp.search(full_name) for exp in self.py_ignorexps)

    def build_py_coverage(self) -> None:
        """Inspect the Python modules and record their documentation coverage.

        For every module that is neither ignored nor failing to import, the
        public functions and classes defined in that module are compared
        against the objects known to the Python domain.  The results are
        stored in ``py_undoc``, ``py_undocumented`` and ``py_documented``.
        """
        seen_objects = frozenset(self.env.domaindata['py']['objects'])
        seen_modules = frozenset(self.env.domaindata['py']['modules'])

        skip_undoc = self.config.coverage_skip_undoc_in_source

        modules = _determine_py_coverage_modules(
            self.config.coverage_modules,
            seen_modules,
            self.mod_ignorexps,
            self.py_undoc,
        )
        for mod_name in modules:
            if any(exp.match(mod_name) for exp in self.mod_ignorexps):
                continue
            if self.ignore_pyobj(mod_name):
                continue

            try:
                mod = import_module(mod_name)
            except ImportError as err:
                logger.warning(__('module %s could not be imported: %s'), mod_name, err)
                self.py_undoc[mod_name] = {'error': err}
                continue

            documented_objects: set[str] = set()
            undocumented_objects: set[str] = set()

            funcs: list[str] = []
            classes: dict[str, list[str]] = {}

            for name, obj in inspect.getmembers(mod):
                # diverse module attributes are ignored:
                if name[0] == '_':
                    # begins in an underscore
                    continue
                if not hasattr(obj, '__module__'):
                    # cannot be attributed to a module
                    continue
                if obj.__module__ != mod_name:
                    # is not defined in this module
                    continue

                full_name = f'{mod_name}.{name}'
                if self.ignore_pyobj(full_name):
                    continue

                if inspect.isfunction(obj):
                    if full_name in seen_objects:
                        documented_objects.add(full_name)
                        continue
                    if any(exp.match(name) for exp in self.fun_ignorexps):
                        continue
                    if skip_undoc and not obj.__doc__:
                        continue
                    funcs.append(name)
                    undocumented_objects.add(full_name)
                elif inspect.isclass(obj):
                    if any(exp.match(name) for exp in self.cls_ignorexps):
                        continue

                    if full_name not in seen_objects:
                        if skip_undoc and not obj.__doc__:
                            continue
                        # not documented at all; an empty method list marks
                        # the whole class as missing
                        classes[name] = []
                        continue

                    # the class is documented: check the methods and
                    # functions defined directly on it
                    attrs: list[str] = []

                    for attr_name in dir(obj):
                        if attr_name not in obj.__dict__:
                            # inherited, not defined on this class
                            continue
                        try:
                            attr = safe_getattr(obj, attr_name)
                        except AttributeError:
                            continue
                        if not (inspect.ismethod(attr) or inspect.isfunction(attr)):
                            continue
                        if attr_name[0] == '_':
                            # starts with an underscore, ignore it
                            continue
                        if skip_undoc and not attr.__doc__:
                            # skip methods without docstring if wished
                            continue
                        full_attr_name = f'{full_name}.{attr_name}'
                        if self.ignore_pyobj(full_attr_name):
                            continue
                        if full_attr_name not in seen_objects:
                            attrs.append(attr_name)
                            undocumented_objects.add(full_attr_name)
                        else:
                            documented_objects.add(full_attr_name)

                    if attrs:
                        # some attributes are undocumented
                        classes[name] = attrs

            self.py_undoc[mod_name] = {'funcs': funcs, 'classes': classes}
            self.py_undocumented[mod_name] = undocumented_objects
            self.py_documented[mod_name] = documented_objects

    def _write_py_statistics(self, op: TextIO) -> None:
        """Write the per-module coverage statistics table to ``op``.

        The table has one row per module plus a final ``TOTAL`` row.  A
        module (or total) without any counted object is reported as fully
        covered.
        """
        all_modules = frozenset(self.py_documented.keys() | self.py_undocumented.keys())
        all_objects: Set[str] = set()
        all_documented_objects: Set[str] = set()
        for module in all_modules:
            all_objects |= self.py_documented[module] | self.py_undocumented[module]
            all_documented_objects |= self.py_documented[module]

        # prepare tabular
        table = [['Module', 'Coverage', 'Undocumented']]
        for module in sorted(all_modules):
            module_objects = self.py_documented[module] | self.py_undocumented[module]
            if module_objects:
                value = 100.0 * len(self.py_documented[module]) / len(module_objects)
            else:
                value = 100.0

            table.append([
                module,
                f'{value:.2f}%',
                str(len(self.py_undocumented[module])),
            ])

        if all_objects:
            table.append([
                'TOTAL',
                f'{100 * len(all_documented_objects) / len(all_objects):.2f}%',
                f'{len(all_objects) - len(all_documented_objects)}',
            ])
        else:
            # same format as the other percentage cells
            table.append(['TOTAL', '100.00%', '0'])

        op.writelines(f'{line}\n' for line in _write_table(table))

    def write_py_coverage(self) -> None:
        """Write ``python.txt`` and, if configured, log every undocumented item.

        The report holds the optional headline and statistics, one section
        per module with undocumented functions or classes, and finally the
        list of modules that failed to import.
        """
        output_file = self.outdir / 'python.txt'
        failed = []
        with open(output_file, 'w', encoding='utf-8') as op:
            if self.config.coverage_write_headline:
                write_header(op, 'Undocumented Python objects', '=')

            if self.config.coverage_statistics_to_stdout:
                self._write_py_statistics(sys.stdout)

            if self.config.coverage_statistics_to_report:
                write_header(op, 'Statistics')
                self._write_py_statistics(op)
                op.write('\n')

            for name in sorted(self.py_undoc):
                undoc = self.py_undoc[name]
                if 'error' in undoc:
                    failed.append((name, undoc['error']))
                    continue
                if not undoc['classes'] and not undoc['funcs']:
                    # nothing to report for this module
                    continue

                write_header(op, name)
                if undoc['funcs']:
                    self._write_undoc_functions(op, name, undoc['funcs'])
                if undoc['classes']:
                    self._write_undoc_classes(op, name, undoc['classes'])

            if failed:
                write_header(op, 'Modules that failed to import')
                op.writelines(f' * {name} -- {err}\n' for name, err in failed)

    def _write_undoc_functions(
        self, op: IO[str], mod_name: str, funcs: list[str]
    ) -> None:
        """Write the ``Functions:`` section of module ``mod_name`` to ``op``."""
        op.write('Functions:\n')
        op.writelines(f' * {x}\n' for x in funcs)
        if self.config.coverage_show_missing_items:
            for func in funcs:
                self._log_missing_item(
                    __('undocumented python function: %s :: %s'),
                    (mod_name, func),
                    f'py  function  {func:<30}',
                    ' - in module ',
                    mod_name,
                )
        op.write('\n')

    def _write_undoc_classes(
        self, op: IO[str], mod_name: str, classes: dict[str, list[str]]
    ) -> None:
        """Write the ``Classes:`` section of module ``mod_name`` to ``op``.

        A class mapped to an empty method list is reported as undocumented
        as a whole; otherwise only the listed methods are reported.
        """
        show_missing = self.config.coverage_show_missing_items

        op.write('Classes:\n')
        for class_name, methods in sorted(classes.items()):
            if not methods:
                op.write(f' * {class_name}\n')
                if show_missing:
                    self._log_missing_item(
                        __('undocumented python class: %s :: %s'),
                        (mod_name, class_name),
                        f'py  class     {class_name:<30}',
                        ' - in module ',
                        mod_name,
                    )
                continue

            op.write(f' * {class_name} -- missing methods:\n\n')
            op.writelines(f'   - {x}\n' for x in methods)
            if show_missing:
                for meth in methods:
                    self._log_missing_item(
                        __('undocumented python method: %s :: %s :: %s'),
                        (mod_name, class_name, meth),
                        f'py  method    {f"{class_name}.{meth}":<30}',
                        ' - in module ',
                        mod_name,
                    )
        op.write('\n')

    def finish(self) -> None:
        """Pickle the collected coverage data to ``undoc.pickle``."""
        # dump the coverage data to a pickle file too
        picklepath = self.outdir / 'undoc.pickle'
        with open(picklepath, 'wb') as dumpfile:
            pickle.dump(
                (self.py_undoc, self.c_undoc, self.py_undocumented, self.py_documented),
                dumpfile,
            )


def setup(app: Sphinx) -> ExtensionMetadata:
    """Register the coverage builder and its configuration values.

    :returns: The extension metadata: the Sphinx display version and the
        declaration that the extension is safe for parallel reading.
    """
    sequence_types = frozenset({list, tuple})
    mapping_types = frozenset({dict})
    flag_types = frozenset({bool})

    app.add_builder(CoverageBuilder)

    # sequence-valued options
    app.add_config_value('coverage_modules', (), '', types=sequence_types)
    for option in (
        'coverage_ignore_modules',
        'coverage_ignore_functions',
        'coverage_ignore_classes',
        'coverage_ignore_pyobjects',
        'coverage_c_path',
    ):
        app.add_config_value(option, [], '', types=sequence_types)

    # mapping-valued options
    for option in ('coverage_c_regexes', 'coverage_ignore_c_items'):
        app.add_config_value(option, {}, '', types=mapping_types)

    # boolean options together with their defaults
    for option, default in (
        ('coverage_write_headline', True),
        ('coverage_statistics_to_report', True),
        ('coverage_statistics_to_stdout', True),
        ('coverage_skip_undoc_in_source', False),
        ('coverage_show_missing_items', False),
    ):
        app.add_config_value(option, default, '', types=flag_types)

    metadata: ExtensionMetadata = {
        'version': sphinx.__display_version__,
        'parallel_read_safe': True,
    }
    return metadata
