"""Helpers for implementing generic IR to IR transforms."""

from __future__ import annotations

from typing import Final, Optional

from mypyc.ir.ops import (
    Assign,
    AssignMulti,
    BasicBlock,
    Box,
    Branch,
    Call,
    CallC,
    Cast,
    ComparisonOp,
    DecRef,
    Extend,
    FloatComparisonOp,
    FloatNeg,
    FloatOp,
    GetAttr,
    GetElementPtr,
    Goto,
    IncRef,
    InitStatic,
    IntOp,
    KeepAlive,
    LoadAddress,
    LoadErrorValue,
    LoadGlobal,
    LoadLiteral,
    LoadMem,
    LoadStatic,
    MethodCall,
    Op,
    OpVisitor,
    PrimitiveOp,
    RaiseStandardError,
    Return,
    SetAttr,
    SetMem,
    Truncate,
    TupleGet,
    TupleSet,
    Unborrow,
    Unbox,
    Unreachable,
    Value,
)
from mypyc.irbuild.ll_builder import LowLevelIRBuilder


class IRTransform(OpVisitor[Optional[Value]]):
    """Identity transform.

    Subclass and override to perform changes to IR.

    Subclass IRTransform and override any OpVisitor visit_* methods
    that perform any IR changes. The default implementations implement
    an identity transform.

    A visit method can return None to remove ops. In this case the
    transform must ensure that no op uses the original removed op
    as a source after the transform.

    A visit method can also return a value other than the op it was
    given. transform_blocks() records that substitution in op_map and
    redirects uses of the old op to the returned value.

    You can retain old BasicBlock and op references in ops. The transform
    will automatically patch these for you as needed.
    """

    def __init__(self, builder: LowLevelIRBuilder) -> None:
        """Create a transform that emits its output through builder."""
        self.builder = builder
        # Subclasses add additional op mappings here. A None value indicates
        # that the op/register is deleted.
        self.op_map: dict[Value, Value | None] = {}

    def transform_blocks(self, blocks: list[BasicBlock]) -> None:
        """Transform basic blocks that represent a single function.

        The result of the transform will be collected at self.builder.blocks.

        This runs in two phases. First, every input block gets a fresh
        BasicBlock which is activated on the builder, and each op of the
        input block is dispatched to the matching visit_* method. Whenever
        a visit method returns something other than the op itself, the
        substitution (possibly None) is recorded in self.op_map. Second,
        a PatchVisitor rewrites the references held by the emitted ops,
        and the error handlers of the emitted blocks, using the collected
        op and block mappings.

        Args:
            blocks: The basic blocks to transform.
        """
        block_map: dict[BasicBlock, BasicBlock] = {}
        # Alias only: entries recorded below land in self.op_map, alongside
        # any mappings a subclass has added.
        op_map = self.op_map
        empties = set()
        for block in blocks:
            new_block = BasicBlock()
            block_map[block] = new_block
            # Activate before visiting, so that ops added by the visit
            # methods go through the builder with this block active.
            self.builder.activate_block(new_block)
            # Still refers to an old block here; remapped in the patch phase.
            new_block.error_handler = block.error_handler
            for op in block.ops:
                new_op = op.accept(self)
                # Identity comparison: only record ops that were replaced
                # or removed (None). This also records None for ops whose
                # visit method returns nothing, such as Goto or Return.
                if new_op is not op:
                    op_map[op] = new_op
            # A transform can produce empty blocks which can be removed.
            # Blocks that were already empty in the input are kept.
            if is_empty_block(new_block) and not is_empty_block(block):
                empties.add(new_block)
        # NOTE(review): block_map still maps the original block to a block
        # dropped here, so any remaining jump to it would be patched to a
        # block that is not in self.builder.blocks. Presumably nothing can
        # reach such blocks -- confirm.
        self.builder.blocks = [block for block in self.builder.blocks if block not in empties]
        # Update all op/block references to point to the transformed ones.
        patcher = PatchVisitor(op_map, block_map)
        for block in self.builder.blocks:
            for op in block.ops:
                op.accept(patcher)
            if block.error_handler is not None:
                block.error_handler = block_map.get(block.error_handler, block.error_handler)

    def add(self, op: Op) -> Value:
        """Add op through self.builder and return the builder's result."""
        return self.builder.add(op)

    # Default visit methods for control-flow ops. These re-add the op
    # unchanged and return None, so transform_blocks() records them as
    # None in op_map.

    def visit_goto(self, op: Goto) -> None:
        self.add(op)

    def visit_branch(self, op: Branch) -> None:
        self.add(op)

    def visit_return(self, op: Return) -> None:
        self.add(op)

    def visit_unreachable(self, op: Unreachable) -> None:
        self.add(op)

    def visit_assign(self, op: Assign) -> Value | None:
        """Re-add the assignment, unless its source is marked as removed.

        Returns None (removing the assignment) when op.src has an explicit
        None entry in self.op_map. Otherwise returns the result of adding
        op unchanged.
        """
        if op.src in self.op_map and self.op_map[op.src] is None:
            # Special case: allow removing register initialization assignments
            return None
        return self.add(op)

    # The remaining default visit methods all re-add the op unchanged and
    # return the result of add().

    def visit_assign_multi(self, op: AssignMulti) -> Value | None:
        return self.add(op)

    def visit_load_error_value(self, op: LoadErrorValue) -> Value | None:
        return self.add(op)

    def visit_load_literal(self, op: LoadLiteral) -> Value | None:
        return self.add(op)

    def visit_get_attr(self, op: GetAttr) -> Value | None:
        return self.add(op)

    def visit_set_attr(self, op: SetAttr) -> Value | None:
        return self.add(op)

    def visit_load_static(self, op: LoadStatic) -> Value | None:
        return self.add(op)

    def visit_init_static(self, op: InitStatic) -> Value | None:
        return self.add(op)

    def visit_tuple_get(self, op: TupleGet) -> Value | None:
        return self.add(op)

    def visit_tuple_set(self, op: TupleSet) -> Value | None:
        return self.add(op)

    def visit_inc_ref(self, op: IncRef) -> Value | None:
        return self.add(op)

    def visit_dec_ref(self, op: DecRef) -> Value | None:
        return self.add(op)

    def visit_call(self, op: Call) -> Value | None:
        return self.add(op)

    def visit_method_call(self, op: MethodCall) -> Value | None:
        return self.add(op)

    def visit_cast(self, op: Cast) -> Value | None:
        return self.add(op)

    def visit_box(self, op: Box) -> Value | None:
        return self.add(op)

    def visit_unbox(self, op: Unbox) -> Value | None:
        return self.add(op)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> Value | None:
        return self.add(op)

    def visit_call_c(self, op: CallC) -> Value | None:
        return self.add(op)

    def visit_primitive_op(self, op: PrimitiveOp) -> Value | None:
        return self.add(op)

    def visit_truncate(self, op: Truncate) -> Value | None:
        return self.add(op)

    def visit_extend(self, op: Extend) -> Value | None:
        return self.add(op)

    def visit_load_global(self, op: LoadGlobal) -> Value | None:
        return self.add(op)

    def visit_int_op(self, op: IntOp) -> Value | None:
        return self.add(op)

    def visit_comparison_op(self, op: ComparisonOp) -> Value | None:
        return self.add(op)

    def visit_float_op(self, op: FloatOp) -> Value | None:
        return self.add(op)

    def visit_float_neg(self, op: FloatNeg) -> Value | None:
        return self.add(op)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> Value | None:
        return self.add(op)

    def visit_load_mem(self, op: LoadMem) -> Value | None:
        return self.add(op)

    def visit_set_mem(self, op: SetMem) -> Value | None:
        return self.add(op)

    def visit_get_element_ptr(self, op: GetElementPtr) -> Value | None:
        return self.add(op)

    def visit_load_address(self, op: LoadAddress) -> Value | None:
        return self.add(op)

    def visit_keep_alive(self, op: KeepAlive) -> Value | None:
        return self.add(op)

    def visit_unborrow(self, op: Unborrow) -> Value | None:
        return self.add(op)


class PatchVisitor(OpVisitor[None]):
    """Rewrite stale references held by ops, in place.

    Each visit method replaces the values (and, for control-flow ops, the
    target blocks) referenced by the visited op with their counterparts
    from op_map / block_map. A reference that has no entry in the relevant
    map is left as it is.
    """

    def __init__(
        self, op_map: dict[Value, Value | None], block_map: dict[BasicBlock, BasicBlock]
    ) -> None:
        # Old value -> replacement value; None marks a removed value.
        self.op_map: Final = op_map
        # Old basic block -> replacement basic block.
        self.block_map: Final = block_map

    def fix_op(self, op: Value) -> Value:
        """Return the replacement for op, or op itself if it has none.

        Asserts that op has not been marked as removed in op_map.
        """
        replacement = self.op_map.get(op, op)
        assert replacement is not None, "use of removed op"
        return replacement

    def _fix_all(self, values: list[Value]) -> list[Value]:
        """Return a new list with fix_op applied to each value, in order."""
        return [self.fix_op(value) for value in values]

    def fix_block(self, block: BasicBlock) -> BasicBlock:
        """Return the replacement for block, or block itself if it has none."""
        return self.block_map.get(block, block)

    # Control flow

    def visit_goto(self, op: Goto) -> None:
        op.label = self.fix_block(op.label)

    def visit_branch(self, op: Branch) -> None:
        op.value = self.fix_op(op.value)
        op.true = self.fix_block(op.true)
        op.false = self.fix_block(op.false)

    def visit_return(self, op: Return) -> None:
        op.value = self.fix_op(op.value)

    def visit_unreachable(self, op: Unreachable) -> None:
        """Nothing is patched for this op."""

    # Assignments, loads and attribute access

    def visit_assign(self, op: Assign) -> None:
        op.src = self.fix_op(op.src)

    def visit_assign_multi(self, op: AssignMulti) -> None:
        op.src = self._fix_all(op.src)

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        """Nothing is patched for this op."""

    def visit_load_literal(self, op: LoadLiteral) -> None:
        """Nothing is patched for this op."""

    def visit_get_attr(self, op: GetAttr) -> None:
        op.obj = self.fix_op(op.obj)

    def visit_set_attr(self, op: SetAttr) -> None:
        op.obj = self.fix_op(op.obj)
        op.src = self.fix_op(op.src)

    def visit_load_static(self, op: LoadStatic) -> None:
        """Nothing is patched for this op."""

    def visit_init_static(self, op: InitStatic) -> None:
        op.value = self.fix_op(op.value)

    # Tuples and reference counting

    def visit_tuple_get(self, op: TupleGet) -> None:
        op.src = self.fix_op(op.src)

    def visit_tuple_set(self, op: TupleSet) -> None:
        op.items = self._fix_all(op.items)

    def visit_inc_ref(self, op: IncRef) -> None:
        op.src = self.fix_op(op.src)

    def visit_dec_ref(self, op: DecRef) -> None:
        op.src = self.fix_op(op.src)

    # Calls and conversions

    def visit_call(self, op: Call) -> None:
        op.args = self._fix_all(op.args)

    def visit_method_call(self, op: MethodCall) -> None:
        op.obj = self.fix_op(op.obj)
        op.args = self._fix_all(op.args)

    def visit_cast(self, op: Cast) -> None:
        op.src = self.fix_op(op.src)

    def visit_box(self, op: Box) -> None:
        op.src = self.fix_op(op.src)

    def visit_unbox(self, op: Unbox) -> None:
        op.src = self.fix_op(op.src)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        # Only patch when the stored value is itself a Value.
        current = op.value
        if not isinstance(current, Value):
            return
        op.value = self.fix_op(current)

    def visit_call_c(self, op: CallC) -> None:
        op.args = self._fix_all(op.args)

    def visit_primitive_op(self, op: PrimitiveOp) -> None:
        op.args = self._fix_all(op.args)

    def visit_truncate(self, op: Truncate) -> None:
        op.src = self.fix_op(op.src)

    def visit_extend(self, op: Extend) -> None:
        op.src = self.fix_op(op.src)

    def visit_load_global(self, op: LoadGlobal) -> None:
        """Nothing is patched for this op."""

    # Arithmetic and comparisons

    def visit_int_op(self, op: IntOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_float_op(self, op: FloatOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_float_neg(self, op: FloatNeg) -> None:
        op.src = self.fix_op(op.src)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    # Memory access

    def visit_load_mem(self, op: LoadMem) -> None:
        op.src = self.fix_op(op.src)

    def visit_set_mem(self, op: SetMem) -> None:
        op.dest = self.fix_op(op.dest)
        op.src = self.fix_op(op.src)

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        op.src = self.fix_op(op.src)

    def visit_load_address(self, op: LoadAddress) -> None:
        # The source is only patched when it is a LoadStatic op; any other
        # kind of source is left untouched.
        source = op.src
        if not isinstance(source, LoadStatic):
            return
        patched = self.fix_op(source)
        assert isinstance(patched, LoadStatic)
        op.src = patched

    def visit_keep_alive(self, op: KeepAlive) -> None:
        op.src = self._fix_all(op.src)

    def visit_unborrow(self, op: Unborrow) -> None:
        op.src = self.fix_op(op.src)


def is_empty_block(block: BasicBlock) -> bool:
    """Return True if block consists of exactly one op and it is an Unreachable."""
    ops = block.ops
    if len(ops) != 1:
        return False
    return isinstance(ops[0], Unreachable)
