from __future__ import annotations

import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError
from polars.testing import assert_frame_equal


def test_except_intersect() -> None:
    """Exercise EXCEPT and INTERSECT between two frames, including null rows."""
    # The queries below refer to these frames by their variable names, so the
    # locals must stay named `df1` and `df2`.
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})
    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})

    # df1 minus df2: (1, 4, 5) is the only df1 row that also appears in df2
    only_in_df1 = pl.sql(
        "SELECT x, y, z FROM df1 EXCEPT SELECT * FROM df2",
        eager=True,
    )
    assert sorted(only_in_df1.rows()) == [(1, 2, 5), (9, 3, 5)]

    # the shared row is duplicated in df1, but comes back only once
    in_both = pl.sql(
        "SELECT * FROM df1 INTERSECT SELECT x, y, z FROM df2",
        eager=True,
    )
    assert sorted(in_both.rows()) == [(1, 4, 5)]

    # df2 minus df1, with `TABLE df1` as the right-hand side
    only_in_df2 = pl.sql("SELECT * FROM df2 EXCEPT TABLE df1", eager=True)
    assert sorted(only_in_df2.rows()) == [(1, 2, 7), (9, None, 6)]

    # intersect df2 with an inline VALUES table holding the same rows as df1
    in_both_inline = pl.sql(
        """
        SELECT * FROM df2
        INTERSECT
        SELECT x::int8, y::int8, z::int8
          FROM (VALUES (1,2,5),(9,3,5),(1,4,5),(1,4,5)) AS df1(x,y,z)
        """,
        eager=True,
    )
    assert sorted(in_both_inline.rows()) == [(1, 4, 5)]

    # null handling: the (9, NULL) row exists in both tables and is removed by
    # EXCEPT, leaving only the row unique to tbl1
    left = pl.DataFrame({"x": [2, 9, 1], "y": [2, None, 4]})
    right = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4]})
    with pl.SQLContext(tbl1=left, tbl2=right) as ctx:
        res = ctx.execute("SELECT * FROM tbl1 EXCEPT SELECT * FROM tbl2", eager=True)
        assert_frame_equal(pl.DataFrame({"x": [2], "y": [2]}), res)


def test_except_intersect_by_name() -> None:
    """EXCEPT / INTERSECT BY NAME pair up columns by name, not by position."""
    # The queries refer to the frames by variable name; keep `df1` / `df2`.
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})

    # df2 deliberately has a different column order plus an extra "w" column
    df2 = pl.DataFrame(
        {"y": [2, None, 4], "w": ["?", "!", "%"], "z": [7, 6, 5], "x": [1, 9, 1]}
    )

    except_query = "SELECT x, y, z FROM df1 EXCEPT BY NAME SELECT * FROM df2"
    intersect_query = "SELECT * FROM df1 INTERSECT BY NAME SELECT * FROM df2"

    res_e = pl.sql(except_query, eager=True)
    res_i = pl.sql(intersect_query, eager=True)

    assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)]
    assert sorted(res_i.rows()) == [(1, 4, 5)]

    # both results keep the left-hand frame's columns, in its order
    for result in (res_e, res_i):
        assert result.columns == ["x", "y", "z"]


@pytest.mark.parametrize(
    ("op", "op_subtype"),
    [
        (set_op, subtype)
        for set_op in ("EXCEPT", "INTERSECT")
        for subtype in ("ALL", "ALL BY NAME")
    ],
)
def test_except_intersect_all_unsupported(op: str, op_subtype: str) -> None:
    """The ALL variants of EXCEPT / INTERSECT raise an SQLInterfaceError."""
    # Looked up by variable name from the query text; keep `df1` / `df2`.
    df1 = pl.DataFrame({"n": [1, 1, 1, 2, 2, 2, 3]})
    df2 = pl.DataFrame({"n": [1, 1, 2, 2]})

    query = f"SELECT * FROM df1 {op} {op_subtype} SELECT * FROM df2"
    expected_error = f"'{op} {op_subtype}' is not supported"

    with pytest.raises(SQLInterfaceError, match=expected_error):
        pl.sql(query)


def test_update_statement_error() -> None:
    """An UPDATE statement raises an SQLInterfaceError quoting the statement."""
    name_servers = ("NS1", "NS2", "NS3")

    large_data = {"FQDN": ["c.ORG.na", "a.COM.na"]}
    for i, col in enumerate(name_servers, start=1):
        large_data[col] = [f"ns{i}.c.org.na", f"ns{i}.d.net.na"]
    df_large = pl.DataFrame(large_data)

    small_data = {"FQDN": ["c.org.na"]}
    for i, col in enumerate(name_servers, start=1):
        small_data[col] = [f"ns{i}.c.org.na|127.0.0.1"]
    df_small = pl.DataFrame(small_data)

    # register both tables up-front via the context's named-frame arguments
    ctx = pl.SQLContext(large=df_large, small=df_small)

    unsupported_msg = r"'UPDATE large SET FQDN = u\.FQDN, NS1 = u\.NS1, NS2 = u\.NS2, NS3 = u\.NS3 FROM u WHERE large\.FQDN = u\.FQDN' operation is currently unsupported"

    with pytest.raises(SQLInterfaceError, match=unsupported_msg):
        ctx.execute("""
            WITH u AS (
                SELECT
                    small.FQDN,
                    small.NS1,
                    small.NS2,
                    small.NS3
                FROM small
                INNER JOIN large ON small.FQDN = large.FQDN
            )
            UPDATE large
            SET
                FQDN = u.FQDN,
                NS1 = u.NS1,
                NS2 = u.NS2,
                NS3 = u.NS3
            FROM u
            WHERE large.FQDN = u.FQDN
        """)


@pytest.mark.parametrize("op", ["EXCEPT", "INTERSECT", "UNION"])
def test_except_intersect_union_errors(op: str) -> None:
    """Check the errors raised for unsupported or malformed set operations."""
    # Looked up by variable name from the query text; keep `df1` / `df2`.
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})
    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})

    # the "ALL" rejection is only checked for EXCEPT and INTERSECT
    if op != "UNION":
        all_query = f"SELECT * FROM df1 {op} ALL SELECT * FROM df2"
        all_error = f"'{op} ALL' is not supported"
        with pytest.raises(SQLInterfaceError, match=all_error):
            lazy_all = pl.sql(all_query, eager=False)
            lazy_all.collect()

    # every operator must reject inputs with differing column counts
    mismatch_query = f"SELECT x FROM df1 {op} SELECT x, y FROM df2"
    mismatch_error = f"{op} requires equal number of columns in each table"
    with pytest.raises(SQLInterfaceError, match=mismatch_error):
        lazy_mismatch = pl.sql(mismatch_query, eager=False)
        lazy_mismatch.collect()


@pytest.mark.parametrize(
    ("cols1", "cols2", "union_subtype", "expected"),
    [
        # plain UNION: the shared (2, "yy") row appears once
        (
            ["*"],
            ["*"],
            "",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        # UNION ALL: the shared row is kept twice
        (
            ["*"],
            ["frame2.*"],
            "ALL",
            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
        ),
        # UNION DISTINCT with a qualified wildcard on the left
        (
            ["frame1.*"],
            ["c1", "c2"],
            "DISTINCT",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        # ALL BY NAME: right-hand columns are swapped but matched by name
        (
            ["*"],
            ["c2", "c1"],
            "ALL BY NAME",
            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
        ),
        # right-hand columns are aliased to different names
        (
            ["c1", "c2"],
            ["c1 AS x1", "c2 AS x2"],
            "",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        # BY NAME: swapped right-hand columns, shared row appears once
        (
            ["c1", "c2"],
            ["c2", "c1"],
            "BY NAME",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        # DISTINCT BY NAME: same expectation as plain BY NAME
        (
            ["c1", "c2"],
            ["c2", "c1"],
            "DISTINCT BY NAME",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
    ],
)
def test_union(
    cols1: list[str],
    cols2: list[str],
    union_subtype: str,
    expected: list[tuple[int, str]],
) -> None:
    """Run a UNION variant over two small frames and compare the sorted rows."""
    frame1 = pl.DataFrame({"c1": [1, 2], "c2": ["zz", "yy"]})
    frame2 = pl.DataFrame({"c1": [2, 3], "c2": ["yy", "xx"]})

    with pl.SQLContext(frame1=frame1, frame2=frame2, eager=True) as ctx:
        lhs_cols = ", ".join(cols1)
        rhs_cols = ", ".join(cols2)
        query = f"""
            SELECT {lhs_cols} FROM frame1
            UNION {union_subtype}
            SELECT {rhs_cols} FROM frame2
        """
        result = ctx.execute(query)
        # row order is not checked, so sort before comparing
        assert sorted(result.rows()) == expected


def test_union_nonmatching_colnames() -> None:
    """
    UNION works on frames whose column names do not match.

    SQL allows "UNION" (aka: polars `concat`) here: columns are paired up
    positionally and the result takes its column names from the first table.
    """
    int16_frame = pl.DataFrame(
        data={"Value": [100, 200], "Tag": ["hello", "foo"]},
        schema_overrides={"Value": pl.Int16},
    )
    int32_frame = pl.DataFrame(
        data={"Number": [300, 400], "String": ["world", "bar"]},
        schema_overrides={"Number": pl.Int32},
    )

    # names come from df1; the Int16 / Int32 pair resolves to Int32
    expected_schema = {
        "Value": pl.Int32,
        "Tag": pl.String,
    }
    expected_rows = [
        (100, "hello"),
        (200, "foo"),
        (300, "world"),
        (400, "bar"),
    ]

    with pl.SQLContext(df1=int16_frame, df2=int32_frame, eager=True) as ctx:
        res = ctx.execute(
            query="""
            SELECT u.* FROM (
                SELECT * FROM df1
                UNION
                SELECT * FROM df2
            ) u ORDER BY Value
            """
        )
        assert res.schema == expected_schema
        assert res.rows() == expected_rows
