"""Helpers for AST (Abstract Syntax Tree)."""

from __future__ import annotations

import ast
from typing import TYPE_CHECKING, overload

if TYPE_CHECKING:
    from typing import NoReturn

# Source spelling of each operator node, keyed by its ``ast`` node class.
# ``_UnparseVisitor`` registers a ``visit_<ClassName>`` method for every key
# in this mapping; an operator class that is missing here falls through to
# ``generic_visit`` and therefore cannot be unparsed.
# Note that unary and binary forms share a spelling (``UAdd``/``Add`` -> '+',
# ``USub``/``Sub`` -> '-'); the spacing differences are handled by the visitor.
OPERATORS: dict[type[ast.AST], str] = {
    ast.Add: '+',
    ast.And: 'and',
    ast.BitAnd: '&',
    ast.BitOr: '|',
    ast.BitXor: '^',
    ast.Div: '/',
    ast.FloorDiv: '//',
    ast.Invert: '~',
    ast.LShift: '<<',
    ast.MatMult: '@',
    ast.Mult: '*',
    ast.Mod: '%',
    ast.Not: 'not',
    ast.Pow: '**',
    ast.Or: 'or',
    ast.RShift: '>>',
    ast.Sub: '-',
    ast.UAdd: '+',
    ast.USub: '-',
}


@overload
def unparse(node: None, code: str = '') -> None: ...


@overload
def unparse(node: ast.AST, code: str = '') -> str: ...


def unparse(node: ast.AST | None, code: str = '') -> str | None:
    """Unparse an AST to string."""
    if node is None:
        return None
    elif isinstance(node, str):
        return node
    return _UnparseVisitor(code).visit(node)


# a greatly cut-down version of `ast._Unparser`
class _UnparseVisitor(ast.NodeVisitor):
    """Render a subset of Python expression nodes back to source text.

    Every ``visit_*`` method returns a string. Node types without a matching
    method fall through to :meth:`generic_visit`, which raises
    :exc:`NotImplementedError`.

    NOTE: parentheses are never inserted around nested operations, so the
    operator precedence of the original source is not preserved (for example
    ``(a + b) * c`` is rendered as ``a + b * c``).
    """

    def __init__(self, code: str = '') -> None:
        # Source text the nodes were parsed from. When non-empty, it is used
        # to recover the original spelling of numeric literals.
        self.code = code

    def _visit_op(self, node: ast.AST) -> str:
        """Return the source spelling of an operator node."""
        return OPERATORS[node.__class__]

    # Register ``_visit_op`` as ``visit_Add``, ``visit_Not``, ... for every
    # operator class. This relies on ``locals()`` returning the class
    # namespace while the class body executes.
    for _op in OPERATORS:
        locals()[f'visit_{_op.__name__}'] = _visit_op
    del _op  # don't leave the loop variable behind as a class attribute

    def visit_arg(self, node: ast.arg) -> str:
        if node.annotation:
            return f'{node.arg}: {self.visit(node.annotation)}'
        return node.arg

    def _visit_arg_with_default(self, arg: ast.arg, default: ast.AST | None) -> str:
        """Unparse a single parameter, appending its default if it has one.

        PEP 8 spacing is used: ``x=1`` without an annotation and
        ``x: int = 1`` with one.
        """
        name = self.visit(arg)
        if default is None:
            return name
        sep = ' = ' if arg.annotation else '='
        return f'{name}{sep}{self.visit(default)}'

    def visit_arguments(self, node: ast.arguments) -> str:
        """Unparse a parameter list, without the surrounding parentheses."""
        positional = [*node.posonlyargs, *node.args]

        # ``node.defaults`` only covers the *last* positional parameters, so
        # left-pad it with ``None`` to get exactly one entry per parameter.
        defaults: list[ast.expr | None] = [None] * (len(positional) - len(node.defaults))
        defaults += node.defaults

        # ``node.kw_defaults`` normally already has one entry per keyword-only
        # parameter (``None`` meaning "required"); pad defensively anyway.
        kw_defaults: list[ast.expr | None] = [None] * (
            len(node.kwonlyargs) - len(node.kw_defaults)
        )
        kw_defaults += node.kw_defaults

        params = [
            self._visit_arg_with_default(arg, default)
            for arg, default in zip(positional, defaults)
        ]
        if node.posonlyargs:
            # Positional-only parameters are terminated by a ``/`` marker.
            params.insert(len(node.posonlyargs), '/')

        if node.vararg:
            params.append('*' + self.visit(node.vararg))
        elif node.kwonlyargs:
            # Without ``*args``, a bare ``*`` introduces keyword-only parameters.
            params.append('*')
        params.extend(
            self._visit_arg_with_default(arg, default)
            for arg, default in zip(node.kwonlyargs, kw_defaults)
        )

        if node.kwarg:
            params.append('**' + self.visit(node.kwarg))

        return ', '.join(params)

    def visit_Attribute(self, node: ast.Attribute) -> str:
        return f'{self.visit(node.value)}.{node.attr}'

    def visit_BinOp(self, node: ast.BinOp) -> str:
        left = self.visit(node.left)
        op = self.visit(node.op)
        right = self.visit(node.right)
        # Special case ``**`` to not have surrounding spaces.
        if isinstance(node.op, ast.Pow):
            return f'{left}{op}{right}'
        return f'{left} {op} {right}'

    def visit_BoolOp(self, node: ast.BoolOp) -> str:
        op = ' %s ' % self.visit(node.op)
        return op.join(self.visit(e) for e in node.values)

    def visit_Call(self, node: ast.Call) -> str:
        args = [self.visit(e) for e in node.args]
        for kw in node.keywords:
            value = self.visit(kw.value)
            # ``kw.arg`` is None for ``**mapping`` unpacking.
            args.append(f'**{value}' if kw.arg is None else f'{kw.arg}={value}')
        return f'{self.visit(node.func)}({", ".join(args)})'

    def visit_Constant(self, node: ast.Constant) -> str:
        if node.value is Ellipsis:
            return '...'
        if self.code and isinstance(node.value, int | float | complex):
            # Prefer the literal as written in the source (e.g. ``0x1F`` or
            # ``1_000``), falling back to its repr if the segment is unknown.
            return ast.get_source_segment(self.code, node) or repr(node.value)
        return repr(node.value)

    def visit_Dict(self, node: ast.Dict) -> str:
        items = []
        for key, value in zip(node.keys, node.values, strict=True):
            if key is None:
                # ``{**mapping}`` unpacking is stored with a ``None`` key.
                items.append('**' + self.visit(value))
            else:
                items.append(f'{self.visit(key)}: {self.visit(value)}')
        return '{' + ', '.join(items) + '}'

    def visit_Lambda(self, node: ast.Lambda) -> str:
        # The body is deliberately elided; only the signature is shown.
        params = self.visit(node.args)
        if not params:
            return 'lambda: ...'
        return f'lambda {params}: ...'

    def visit_List(self, node: ast.List) -> str:
        return '[' + ', '.join(self.visit(e) for e in node.elts) + ']'

    def visit_Name(self, node: ast.Name) -> str:
        return node.id

    def visit_Set(self, node: ast.Set) -> str:
        return '{' + ', '.join(self.visit(e) for e in node.elts) + '}'

    def visit_Slice(self, node: ast.Slice) -> str:
        start = self.visit(node.lower) if node.lower else ''
        stop = self.visit(node.upper) if node.upper else ''
        if not node.step:
            # Default step -> ``start:stop`` (just ``:`` when both are omitted).
            return f'{start}:{stop}'
        return f'{start}:{stop}:{self.visit(node.step)}'

    def visit_Starred(self, node: ast.Starred) -> str:
        # ``*iterable`` in calls/displays, or ``*Ts`` in annotations (PEP 646).
        return f'*{self.visit(node.value)}'

    def visit_Subscript(self, node: ast.Subscript) -> str:
        def is_simple_tuple(value: ast.expr) -> bool:
            return (
                isinstance(value, ast.Tuple)
                and bool(value.elts)
                and not any(isinstance(elt, ast.Starred) for elt in value.elts)
            )

        if is_simple_tuple(node.slice):
            elts = ', '.join(self.visit(e) for e in node.slice.elts)  # type: ignore[attr-defined]
            return f'{self.visit(node.value)}[{elts}]'
        return f'{self.visit(node.value)}[{self.visit(node.slice)}]'

    def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
        # UnaryOp is one of {UAdd, USub, Invert, Not}, which refer to ``+x``,
        # ``-x``, ``~x``, and ``not x``. Only Not needs a space.
        if isinstance(node.op, ast.Not):
            return f'{self.visit(node.op)} {self.visit(node.operand)}'
        return f'{self.visit(node.op)}{self.visit(node.operand)}'

    def visit_Tuple(self, node: ast.Tuple) -> str:
        if len(node.elts) == 0:
            return '()'
        elif len(node.elts) == 1:
            return '(%s,)' % self.visit(node.elts[0])
        else:
            return '(' + ', '.join(self.visit(e) for e in node.elts) + ')'

    def generic_visit(self, node: ast.AST) -> NoReturn:
        raise NotImplementedError('Unable to parse %s object' % type(node).__name__)
