"""Bool register elimination optimization.

Example input:

  L1:
    r0 = f()
    b = r0
    goto L3
  L2:
    r1 = g()
    b = r1
    goto L3
  L3:
    if b goto L4 else goto L5

The register b is redundant, so we replace each assignment to it with a copy of
the branch in L3 (the original branch in L3 becomes unreachable):

  L1:
    r0 = f()
    if r0 goto L4 else goto L5
  L2:
    r1 = g()
    if r1 goto L4 else goto L5

This helps generate simpler IR for tagged integer comparisons, for example.
"""

from __future__ import annotations

from mypyc.ir.func_ir import FuncIR
from mypyc.ir.ops import Assign, BasicBlock, Branch, Goto, Register, Unreachable
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
from mypyc.options import CompilerOptions
from mypyc.transform.ir_transform import IRTransform


def do_flag_elimination(fn: FuncIR, options: CompilerOptions) -> None:
    """Replace single-use flag registers that only feed a branch.

    A register ``r`` is eliminated when all of the following hold:

    * it is read exactly once, by a ``Branch`` that is the first op of its
      block (the "branch block");
    * it is not a function argument;
    * every ``Assign`` to ``r`` is immediately followed by a ``Goto`` to the
      branch block;
    * every ``Goto`` into the branch block is immediately preceded by an
      ``Assign`` to ``r``, no ``Branch`` targets the branch block, and the
      branch block is not the entry block.

    The last condition ensures control can reach the branch block only through
    one of the rewritten assignments, so turning the original branch into
    ``Unreachable`` (done by FlagEliminationTransform) cannot break a path.

    Args:
        fn: The function to optimize; ``fn.blocks`` is replaced with the
            rewritten blocks.
        options: Compiler options used to construct the LowLevelIRBuilder.
    """
    # Find registers that are used exactly once as source, and in a branch.
    counts: dict[Register, int] = {}
    branches: dict[Register, Branch] = {}
    labels: dict[Register, BasicBlock] = {}
    for block in fn.blocks:
        for i, op in enumerate(block.ops):
            for src in op.sources():
                if isinstance(src, Register):
                    counts[src] = counts.get(src, 0) + 1
            if i == 0 and isinstance(op, Branch) and isinstance(op.value, Register):
                branches[op.value] = op
                labels[op.value] = block

    # Based on these we can find the candidate registers. Argument registers
    # are initialized on entry without an Assign, so they never qualify.
    candidates: set[Register] = {
        r for r in branches if counts.get(r, 0) == 1 and r not in fn.arg_regs
    }

    # Map each candidate's branch block back to its flag register so that
    # jumps into those blocks can be validated below.
    flag_of_block: dict[BasicBlock, Register] = {labels[r]: r for r in candidates}

    # The entry block is reached without any jump, hence without any
    # preceding assignment; its branch must be kept.
    if fn.blocks and fn.blocks[0] in flag_of_block:
        candidates.discard(flag_of_block[fn.blocks[0]])

    # Remove candidates with invalid assignments, or whose branch block can be
    # entered without going through an "assign flag; goto branch block" pair.
    for block in fn.blocks:
        ops = block.ops
        for i, op in enumerate(ops):
            if isinstance(op, Assign) and op.dest in candidates:
                next_op = ops[i + 1] if i + 1 < len(ops) else None
                if not (isinstance(next_op, Goto) and next_op.label is labels[op.dest]):
                    candidates.discard(op.dest)
            elif isinstance(op, Goto) and op.label in flag_of_block:
                flag = flag_of_block[op.label]
                prev_op = ops[i - 1] if i > 0 else None
                if not (isinstance(prev_op, Assign) and prev_op.dest is flag):
                    # A jump that doesn't set the flag would land on Unreachable.
                    candidates.discard(flag)
            elif isinstance(op, Branch):
                for target in (op.true, op.false):
                    if target in flag_of_block:
                        # Conditional jumps into the branch block can't be
                        # rewritten into a copy of the branch.
                        candidates.discard(flag_of_block[target])

    builder = LowLevelIRBuilder(None, options)
    transform = FlagEliminationTransform(
        builder, {reg: branch for reg, branch in branches.items() if reg in candidates}
    )
    transform.transform_blocks(fn.blocks)
    fn.blocks = builder.blocks


class FlagEliminationTransform(IRTransform):
    """Rewrite blocks so that eliminated flag registers are branched on directly.

    Each ``Assign`` to a register in ``branch_map`` is replaced by a copy of
    that register's original ``Branch``, testing the assigned value instead.
    Each original ``Branch`` in ``branch_map`` is replaced by ``Unreachable``.
    Other ops are presumably handled by IRTransform's default visitors.
    """

    def __init__(self, builder: LowLevelIRBuilder, branch_map: dict[Register, Branch]) -> None:
        """Create the transform.

        Args:
            builder: Builder that receives the rewritten ops.
            branch_map: Maps each eliminated flag register to the branch that
                reads it.
        """
        super().__init__(builder)
        self.branch_map = branch_map
        # Set of the original branches, for identity lookups in visit_branch.
        self.branches = set(branch_map.values())

    def visit_assign(self, op: Assign) -> None:
        """Replace an assignment to an eliminated flag with a copy of its branch."""
        old_branch = self.branch_map.get(op.dest)
        if old_branch is None:
            self.add(op)
            return
        # Replace assignment with a copy of the old branch, which is in a
        # separate basic block. The old branch will be deleted in visit_branch.
        new_branch = Branch(
            op.src,
            old_branch.true,
            old_branch.false,
            old_branch.op,
            old_branch.line,
            rare=old_branch.rare,
        )
        new_branch.negated = old_branch.negated
        new_branch.traceback_entry = old_branch.traceback_entry
        self.add(new_branch)

    def visit_goto(self, op: Goto) -> None:
        """Emit the goto; after a copied branch this is dropped by the builder."""
        # This is a no-op if basic block already terminated
        self.builder.goto(op.label)

    def visit_branch(self, op: Branch) -> None:
        """Replace an eliminated flag's original branch with Unreachable."""
        if op in self.branches:
            # This branch is optimized away
            self.add(Unreachable())
        else:
            self.add(op)
