from __future__ import annotations

import gc
import time
from collections.abc import Mapping


class GcLogger:
    """Context manager to log GC stats and overall time.

    On entry the counters are reset and ``gc_callback`` is registered in
    ``gc.callbacks``; on exit it is unregistered again.  ``get_stats`` may be
    called at any point after entry to read the totals gathered so far.

    All durations are measured with ``time.perf_counter()``, a monotonic
    clock, so adjustments of the system wall clock cannot skew (or make
    negative) the reported times.  As a consequence ``start_time`` and
    ``gc_start_time`` are perf-counter readings, not epoch timestamps.
    """

    def __enter__(self) -> GcLogger:
        """Reset all counters, start the overall timer and hook into the GC."""
        # Perf-counter reading taken at the "start" phase of the collection
        # currently in progress; None while no collection is running.
        self.gc_start_time: float | None = None
        self.gc_time = 0.0
        self.gc_calls = 0
        self.gc_collected = 0
        self.gc_uncollectable = 0
        gc.callbacks.append(self.gc_callback)
        self.start_time = time.perf_counter()
        return self

    def gc_callback(self, phase: str, info: Mapping[str, int]) -> None:
        """Record one GC phase notification.

        Args:
            phase: ``"start"`` or ``"stop"``.
            info: Collection details; the ``"collected"`` and
                ``"uncollectable"`` entries are read on ``"stop"``.

        Raises:
            AssertionError: If phases arrive out of sequence or *phase* is
                not one of the two recognized values.
        """
        if phase == "start":
            assert self.gc_start_time is None, "Start phase out of sequence"
            self.gc_start_time = time.perf_counter()
        elif phase == "stop":
            assert self.gc_start_time is not None, "Stop phase out of sequence"
            self.gc_calls += 1
            self.gc_time += time.perf_counter() - self.gc_start_time
            self.gc_start_time = None
            self.gc_collected += info["collected"]
            self.gc_uncollectable += info["uncollectable"]
        else:
            # Raised explicitly (rather than ``assert False``) so the check
            # is not stripped when Python runs with -O.
            raise AssertionError(f"Unrecognized gc phase ({phase!r})")

    def __exit__(self, *args: object) -> None:
        """Unregister the callback; exceptions from the body propagate."""
        # __enter__ appends on every call, so remove every registration
        # rather than just the first one.
        while self.gc_callback in gc.callbacks:
            gc.callbacks.remove(self.gc_callback)

    def get_stats(self) -> Mapping[str, float]:
        """Return the statistics gathered since ``__enter__``.

        Returns:
            A mapping with the keys ``gc_time`` (seconds spent in completed
            collections), ``gc_calls`` (number of completed collections),
            ``gc_collected`` and ``gc_uncollectable`` (sums of the values
            reported by the GC) and ``build_time`` (seconds elapsed since
            ``__enter__``).
        """
        end_time = time.perf_counter()
        result = {
            "gc_time": self.gc_time,
            "gc_calls": self.gc_calls,
            "gc_collected": self.gc_collected,
            "gc_uncollectable": self.gc_uncollectable,
            "build_time": end_time - self.start_time,
        }
        return result
