from __future__ import annotations

import sys
import time
from multiprocessing import Queue, get_context
from unittest import TestCase, main

import pytest

from mypy.ipc import IPCClient, IPCServer

# Name requested when creating the test IPC servers. Clients do not use this
# directly: they connect with the name each server reports via
# `connection_name`, which is sent back to the test through a queue.
CONNECTION_NAME = "dmypy-test-ipc"


def server(msg: str, q: Queue[str]) -> None:
    """Send `msg` to each connecting client until one replies with a non-empty string.

    Intended to run in a child process. The server's connection name is put on
    `q` so the parent test knows where to connect. Each pass of the loop enters
    the server context, writes `msg` and reads one reply. An empty reply keeps
    the loop going, so another client can connect (test_connect_twice relies
    on this).

    Args:
        msg: Payload written to every client.
        q: Queue used to hand the server's connection name to the parent.
    """
    # Named ipc_server so the local does not shadow this function's own name.
    ipc_server = IPCServer(CONNECTION_NAME)
    try:
        q.put(ipc_server.connection_name)
        data = ""
        while not data:
            with ipc_server:
                ipc_server.write(msg)
                data = ipc_server.read()
    finally:
        # Always call cleanup(), even when serving raises. Previously it only
        # ran on the success path.
        ipc_server.cleanup()


def server_multi_message_echo(q: Queue[str]) -> None:
    """Echo every message from a single client back to it until "quit" arrives.

    Intended to run in a child process. The server's connection name is put on
    `q` so the parent test knows where to connect. The "quit" message itself is
    echoed before the loop ends.

    Args:
        q: Queue used to hand the server's connection name to the parent.
    """
    ipc_server = IPCServer(CONNECTION_NAME)
    try:
        q.put(ipc_server.connection_name)
        data = ""
        with ipc_server:
            while data != "quit":
                data = ipc_server.read()
                ipc_server.write(data)
    finally:
        # Always call cleanup(), even when the echo loop raises. Previously
        # it only ran on the success path.
        ipc_server.cleanup()


class IPCTests(TestCase):
    """Round-trip tests for IPCClient against an IPCServer in a child process."""

    def setUp(self) -> None:
        """Pick a multiprocessing start method for the server child processes."""
        if sys.platform == "linux":
            # The default "fork" start method is potentially unsafe
            self.ctx = get_context("forkserver")
        else:
            self.ctx = get_context("spawn")

    def test_transaction_large(self) -> None:
        """A message longer than a single read chunk is received intact."""
        queue: Queue[str] = self.ctx.Queue()
        msg = "t" * 200000  # longer than the max read size of 100_000
        p = self.ctx.Process(target=server, args=(msg, queue), daemon=True)
        p.start()
        connection_name = queue.get()
        with IPCClient(connection_name, timeout=1) as client:
            assert client.read() == msg
            client.write("test")
        queue.close()
        queue.join_thread()
        p.join()
        # Like the other tests, require the server process to have exited cleanly.
        # Without this check, a crash after the client finished would go unnoticed.
        assert p.exitcode == 0

    def test_connect_twice(self) -> None:
        """A second client can connect after the first sends an empty reply."""
        queue: Queue[str] = self.ctx.Queue()
        msg = "this is a test message"
        p = self.ctx.Process(target=server, args=(msg, queue), daemon=True)
        p.start()
        connection_name = queue.get()
        with IPCClient(connection_name, timeout=1) as client:
            assert client.read() == msg
            client.write("")  # don't let the server hang up yet, we want to connect again.

        with IPCClient(connection_name, timeout=1) as client:
            assert client.read() == msg
            client.write("test")
        queue.close()
        queue.join_thread()
        p.join()
        assert p.exitcode == 0

    def test_multiple_messages(self) -> None:
        """Several messages on one connection are echoed back in order, including non-ASCII text."""
        queue: Queue[str] = self.ctx.Queue()
        p = self.ctx.Process(target=server_multi_message_echo, args=(queue,), daemon=True)
        p.start()
        connection_name = queue.get()
        with IPCClient(connection_name, timeout=1) as client:
            # "foo bar" with extra accents on letters.
            # In UTF-8 encoding so we don't confuse editors opening this file.
            fancy_text = b"f\xcc\xb6o\xcc\xb2\xf0\x9d\x91\x9c \xd0\xb2\xe2\xb7\xa1a\xcc\xb6r\xcc\x93\xcd\x98\xcd\x8c"
            client.write(fancy_text.decode("utf-8"))
            assert client.read() == fancy_text.decode("utf-8")

            # Write two messages back to back before reading either echo.
            # The reads below check that they come back as two separate
            # messages, in order.
            client.write("Test with spaces")
            client.write("Test write before reading previous")
            # sleep(0) only gives up the client's current time slice. It does
            # not guarantee the server process has read both messages yet.
            time.sleep(0)
            assert client.read() == "Test with spaces"
            assert client.read() == "Test write before reading previous"

            client.write("quit")
            assert client.read() == "quit"
        queue.close()
        queue.join_thread()
        p.join()
        assert p.exitcode == 0

    # Run test_connect_twice a lot, in the hopes of finding issues.
    # This is really slow, so it is skipped, but can be enabled if
    # needed to debug IPC issues.
    @pytest.mark.skip
    def test_connect_alot(self) -> None:
        """Stress test: repeat test_connect_twice 1000 times, printing per-iteration timings."""
        t0 = time.time()
        for i in range(1000):
            try:
                print(i, "start")
                self.test_connect_twice()
            finally:
                # Report the elapsed time even if this iteration failed.
                t1 = time.time()
                print(i, t1 - t0)
                sys.stdout.flush()
                t0 = t1


# Allow running this module directly; unittest's main() discovers and runs IPCTests.
if __name__ == "__main__":
    main()
