"""Parallel building utilities."""

from __future__ import annotations

import os
import time
import traceback
from math import sqrt
from typing import TYPE_CHECKING

try:
    import multiprocessing

    HAS_MULTIPROCESSING = True
except ImportError:
    HAS_MULTIPROCESSING = False

from sphinx.errors import SphinxParallelError
from sphinx.util import logging

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence
    from typing import Any

# module-level logger, obtained through sphinx.util.logging
logger = logging.getLogger(__name__)

# Our parallel functionality only works for the forking Process
# (ParallelTasks.add_task requests the 'fork' multiprocessing context), so it
# is enabled only when multiprocessing could be imported and os.name is 'posix'.
parallel_available = HAS_MULTIPROCESSING and os.name == 'posix'


class SerialTasks:
    """Run tasks immediately in the calling process.

    Offers the same ``add_task`` / ``join`` methods as ParallelTasks so the two
    can be swapped, but nothing is deferred: every task has finished by the
    time ``add_task`` returns.
    """

    def __init__(self, nproc: int = 1) -> None:
        """Accept *nproc* for signature compatibility; the value is unused."""

    def add_task(
        self,
        task_func: Callable[[Any], Any] | Callable[[], Any],
        arg: Any = None,
        result_func: Callable[[Any], Any] | None = None,
    ) -> None:
        """Call *task_func* right away and hand its result to *result_func*.

        :param task_func: the work to perform; called without arguments when
            *arg* is ``None``, otherwise called with *arg*.
        :param arg: optional single argument for *task_func*.
        :param result_func: optional callback, invoked with the return value
            of *task_func* as its only argument.
        """
        # ``None`` (not mere falsiness) selects the zero-argument call, so
        # values such as ``0`` or ``''`` are still passed through.
        outcome = task_func() if arg is None else task_func(arg)  # type: ignore[call-arg]
        if result_func:
            result_func(outcome)

    def join(self) -> None:
        """Do nothing: there is never any outstanding work to wait for."""


class ParallelTasks:
    """Executes *nproc* tasks in parallel after forking.

    Every task added with :meth:`add_task` gets its own process created from
    the ``'fork'`` multiprocessing context.  At most *nproc* processes run at
    the same time; further tasks wait until a slot becomes free.  A child
    reports ``(failed, logs, result)`` through a one-way pipe; the parent
    replays the collected logs and calls the task's result function.
    """

    def __init__(self, nproc: int) -> None:
        self.nproc = nproc
        # (optional) function performed by each task on the result of main task
        self._result_funcs: dict[int, Callable[[Any, Any], Any]] = {}
        # task arguments
        self._args: dict[int, list[Any] | None] = {}
        # subprocesses that are waiting or running (dropped once joined)
        self._procs: dict[int, Any] = {}
        # receiving pipe connections of running subprocesses
        self._precvs: dict[int, Any] = {}
        # receiving pipe connections of waiting subprocesses
        self._precvs_waiting: dict[int, Any] = {}
        # number of working subprocesses
        self._pworking = 0
        # task id that will be assigned to the next added task
        self._taskid = 0

    def _process(
        self, pipe: Any, func: Callable[[Any], Any] | Callable[[], Any], arg: Any
    ) -> None:
        """Run *func* (the process target) and send the outcome through *pipe*.

        The message is ``(failed, logs, ret)``.  On success *ret* is the return
        value of *func*; on failure it is ``(error message, traceback text)``.
        """
        try:
            collector = logging.LogCollector()
            with collector.collect():
                if arg is None:
                    ret = func()  # type: ignore[call-arg]
                else:
                    ret = func(arg)  # type: ignore[call-arg]
            failed = False
        except BaseException as err:
            # BaseException: report every kind of failure to the parent
            failed = True
            errmsg = traceback.format_exception_only(err.__class__, err)[0].strip()
            ret = (errmsg, traceback.format_exc())
        logging.convert_serializable(collector.logs)
        pipe.send((failed, collector.logs, ret))

    def add_task(
        self,
        task_func: Callable[[Any], Any] | Callable[[], Any],
        arg: Any = None,
        result_func: Callable[[Any, Any], Any] | None = None,
    ) -> None:
        """Queue *task_func* for execution in a forked subprocess.

        :param task_func: the work to perform in the child; called without
            arguments when *arg* is ``None``, otherwise called with *arg*.
        :param arg: optional single argument for *task_func*.
        :param result_func: optional callback run in the parent as
            ``result_func(arg, result)`` once the task has finished.
        :raises Exception: whatever :meth:`_join_one` raised while collecting a
            finished task or starting a new one; running children are
            terminated first.
        """
        tid = self._taskid
        self._taskid += 1
        self._result_funcs[tid] = result_func or (lambda arg, result: None)
        self._args[tid] = arg
        precv, psend = multiprocessing.Pipe(False)
        context: Any = multiprocessing.get_context('fork')
        proc = context.Process(target=self._process, args=(psend, task_func, arg))
        self._procs[tid] = proc
        self._precvs_waiting[tid] = precv
        try:
            self._join_one()
        except Exception:
            # shutdown other child processes on failure
            # (e.g. OSError: Failed to allocate memory)
            self.terminate()
            # Propagate like join() does; swallowing the error here would
            # silently discard the failed task and every terminated one.
            raise

    def join(self) -> None:
        """Block until no subprocess is working any more.

        Any exception raised while collecting results propagates to the
        caller after the remaining running children have been terminated.
        """
        try:
            while self._pworking:
                if not self._join_one():
                    time.sleep(0.02)
        finally:
            # shutdown other child processes on failure
            self.terminate()

    def terminate(self) -> None:
        """Terminate every running subprocess and forget its bookkeeping."""
        for tid in list(self._precvs):
            self._procs[tid].terminate()
            # The result function and argument may already have been consumed
            # by _join_one() (e.g. when the result function itself raised), so
            # their absence must not raise and mask the original error.
            self._result_funcs.pop(tid, None)
            self._args.pop(tid, None)
            self._procs.pop(tid)
            self._precvs.pop(tid)
            self._pworking -= 1

    def _join_one(self) -> bool:
        """Collect at most one finished task, then start waiting tasks.

        :returns: ``True`` if the result of a finished task was processed.
        :raises SphinxParallelError: if the collected task reported a failure.
        """
        joined_any = False
        for tid, pipe in self._precvs.items():
            if pipe.poll():
                exc, logs, result = pipe.recv()
                if exc:
                    raise SphinxParallelError(*result)
                for log in logs:
                    logger.handle(log)
                self._result_funcs.pop(tid)(self._args.pop(tid), result)
                # the process is done: reap it and release the reference
                self._procs.pop(tid).join()
                self._precvs.pop(tid)
                self._pworking -= 1
                joined_any = True
                # the dict was modified, so iteration must stop here
                break

        while self._precvs_waiting and self._pworking < self.nproc:
            newtid, newprecv = self._precvs_waiting.popitem()
            # Start before registering the task as running: if start() fails,
            # terminate() must not see a process that was never started, and
            # the working counter must stay accurate.
            self._procs[newtid].start()
            self._precvs[newtid] = newprecv
            self._pworking += 1

        return joined_any


def make_chunks(arguments: Sequence[str], nproc: int, maxbatch: int = 10) -> list[Any]:
    """Split *arguments* into consecutive slices, one per task.

    :param arguments: the items to distribute, kept in their original order.
    :param nproc: number of processes the work is spread over.
    :param maxbatch: once an even split would give each process at least this
        many items, a smaller slice size is used instead.
    :returns: a list of slices of *arguments*; every slice has the same length
        except possibly the last one, which holds the remainder.
    """
    total = len(arguments)

    # start from an even split across the processes
    size = total // nproc
    if size >= maxbatch:
        # try to improve batch size vs. number of batches
        size = int(sqrt(total / nproc * maxbatch))
    # never use an empty slice size
    size = size or 1

    # stepping by *size* yields each slice start; the final slice is simply
    # cut short by the end of the sequence
    return [arguments[start : start + size] for start in range(0, total, size)]
