import csv
import itertools
from dataclasses import dataclass
import logging
import subprocess
import typing as t
from bisect import bisect_left
import enum
from pathlib import Path

# Logger named after this module, shared by everything defined below.
logger = logging.getLogger(name=__name__)


class CatGram(enum.Enum):
    """Grammatical category of a word, as read from Lexique's ``cgram`` column."""

    NOM = "NOM"
    VERBE = "VER"
    ADJECTIF = "ADJ"
    ADVERBE = "ADV"
    AUXILIAIRE = "AUX"
    ARTICLE = "ART"
    CONJONCTION = "CON"
    LIAISON = "LIA"
    PREPOSITION = "PRE"
    PRONOM = "PRO"
    ONOMATOPEE = "ONO"

    @classmethod
    def parse(cls, val: str) -> "CatGram":
        """Parse a ``cgram`` entry, ignoring anything after the first ``:``.

        For instance ``"ADJ:num"`` parses as :attr:`ADJECTIF`.

        Raises:
            ValueError: if the part before the first ``:`` is not a known code.
        """
        base = val.split(":", maxsplit=1)[0]
        return cls(base)

    def __lt__(self, oth):
        # Order categories by their code so (word, CatGram) tuples can be
        # sorted. For foreign types, return NotImplemented so Python raises the
        # usual TypeError instead of an AttributeError on ``oth.value``.
        if not isinstance(oth, CatGram):
            return NotImplemented
        return self.value < oth.value


def match_enum_or_all(val, enum_cls) -> list:
    """Return ``[enum_cls(val)]`` if *val* is a valid value of *enum_cls*, else all members.

    An unrecognised value (for instance a blank cell) is thus expanded to
    every member of the enum, in definition order.

    The lookup uses ``enum_cls(val)`` rather than ``val in enum_cls``. Before
    Python 3.12, ``in`` on an Enum class raises TypeError when given a plain
    value instead of a member.
    """
    try:
        return [enum_cls(val)]
    except ValueError:
        return list(enum_cls)


class Genre(enum.Enum):
    """Grammatical gender, coded as in the ``genre`` column of Lexique."""

    MASC = "m"  # masculine
    FEM = "f"  # feminine


class Nombre(enum.Enum):
    """Grammatical number, coded as in the ``nombre`` column of Lexique."""

    SING = "s"  # singular
    PLUR = "p"  # plural


class Temps(enum.Enum):
    """Verb tenses tracked for verb variants.

    Values are the ``mood:tense`` prefixes of Lexique ``infover`` entries,
    or just ``inf`` for the infinitive.
    """

    INFINITIF = "inf"  # infinitive
    PRESENT = "ind:pre"  # indicative present
    FUTUR = "ind:fut"  # indicative future
    IMPARFAIT = "ind:imp"  # indicative imperfect


class Personne(enum.Enum):
    """Person and number of a conjugated form.

    Values match the third ``:``-separated field of ``infover`` entries,
    e.g. the ``1s`` in ``ind:pre:1s``.
    """

    # Singular persons.
    S1 = "1s"
    S2 = "2s"
    S3 = "3s"
    # Plural persons.
    P1 = "1p"
    P2 = "2p"
    P3 = "3p"


@dataclass
class _Mot:
    """Canonical form of a word: spelling, grammatical category and frequency."""

    # Spelling of the canonical form.
    mot: str
    # Grammatical category of the word.
    cat_gram: CatGram
    # Occurrences of the canonical form per million words.
    freq: float


class Mot(_Mot):
    """A canonical word together with its inflected spellings.

    Inflected spellings are stored per instance, keyed by a ``Variant`` value.
    Subclasses such as ``Nom`` and ``Verbe`` define their own ``Variant`` key
    type.
    """

    class Variant:
        """Base variant key; Mot subclasses replace it with a concrete type."""

    # Category -> Mot subclass registry consulted by ``for_cat_gram``; it is
    # assigned at module level once the subclasses are defined.
    _for_cat_gram: dict[CatGram, t.Type["Mot"]] = {}
    # Inflected spellings of this word, keyed by variant.
    _variants: dict

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Every word owns its mapping; it starts empty and is populated by
        # Lexique.parse.
        self._variants = dict()

    def accord(self, variant: Variant) -> str:
        """Return the spelling of this word inflected for *variant*.

        Raises:
            KeyError: if no spelling was recorded for *variant*.
        """
        return self._variants[variant]

    @classmethod
    def for_cat_gram(cls, cat_gram: CatGram) -> t.Type["Mot"]:
        """Return the Mot subclass registered for *cat_gram*, or ``cls`` if none is."""
        try:
            return cls._for_cat_gram[cat_gram]
        except KeyError:
            return cls


class Nom(Mot):
    """A noun, whose inflected spellings are keyed by gender and number."""

    class Variant(t.NamedTuple):
        """Inflection key of a noun form."""

        genre: Genre
        nombre: Nombre


class Verbe(Mot):
    """A verb, whose inflected spellings are keyed by tense and person."""

    class Variant(t.NamedTuple):
        """Inflection key of a verb form; ``personne`` is None for the infinitive."""

        temps: Temps
        personne: t.Optional[Personne]


# Register the category-specific word classes used by Mot.for_cat_gram;
# any category missing from this mapping is instantiated as a plain Mot.
Mot._for_cat_gram = {CatGram.NOM: Nom, CatGram.VERBE: Verbe}


class Lexique:
    """In-memory view of the Lexique 3.83 lexical database.

    Holds one :class:`Mot` per canonical form (``lemme``) and grammatical
    category. Noun and verb words carry their inflected spellings as variants.
    """

    LEXIQUE_DIR_PATH = Path(__file__).parent.parent / "data/raw/Lexique383"
    LEXIQUE_PATH = LEXIQUE_DIR_PATH / "Lexique383.tsv"

    # Default number of words returned by ``most_common`` for each category.
    PRESET_THRESHOLD_BY_CAT: dict[CatGram, int] = {
        CatGram.NOM: 10000,
        CatGram.VERBE: 10000,
        CatGram.ADJECTIF: 10000,
        CatGram.ADVERBE: 10000,
    }

    dataset: list[Mot]

    def __init__(self, dataset):
        self.dataset = dataset

    @classmethod
    def _ensure_uncompressed(cls):
        """Make sure the Lexique TSV file is present on disk.

        If ``LEXIQUE_PATH`` is missing, extract the ``.tar.xz`` archive that
        sits next to ``LEXIQUE_DIR_PATH`` using the system ``tar`` binary.

        Raises:
            FileNotFoundError: if the archive is missing, or if the TSV file is
                still absent after extraction.
            subprocess.CalledProcessError: if ``tar`` fails.
        """
        if cls.LEXIQUE_PATH.exists():
            return

        lexique_archive = cls.LEXIQUE_DIR_PATH.with_suffix(".tar.xz")
        if not lexique_archive.exists():
            logger.error("Missing compressed dataset at %s", lexique_archive)
            raise FileNotFoundError(
                f"Missing compressed dataset at {lexique_archive}"
            )

        logger.info("Uncompressing dataset")
        subprocess.check_call(
            [
                "tar",
                "-xJf",
                lexique_archive.as_posix(),
                "-C",
                lexique_archive.parent.as_posix(),
            ]
        )

        if not cls.LEXIQUE_PATH.exists():
            logger.error(
                "Uncompressed dataset still missing at %s after extraction",
                cls.LEXIQUE_PATH,
            )
            raise FileNotFoundError(
                f"Uncompressed dataset still missing at {cls.LEXIQUE_PATH} after extraction"
            )

    @classmethod
    def _read_rows(cls) -> list[dict[str, str]]:
        """Read the Lexique TSV file, skipping rows with an empty ``cgram``."""
        # Explicit encoding so the result does not depend on the platform's
        # locale default; NOTE(review): assumes the distributed TSV is UTF-8.
        # newline="" is what the csv module expects for its file objects.
        with cls.LEXIQUE_PATH.open("r", encoding="utf-8", newline="") as h:
            reader = csv.DictReader(h, dialect="excel-tab")
            return [row for row in reader if row["cgram"]]

    @staticmethod
    def _verb_variant(raw_ver: str) -> t.Optional[Verbe.Variant]:
        """Map one ``infover`` entry (``mood:tense:person``) to a verb variant.

        Returns None for moods, tenses or persons not modelled by
        :class:`Temps` / :class:`Personne`, and for malformed entries.
        """
        ver = raw_ver.split(":")
        if ver[0] == "inf":
            return Verbe.Variant(temps=Temps.INFINITIF, personne=None)
        if ver[0] != "ind" or len(ver) < 3:
            return None
        # Enum lookup by value raises ValueError for unknown codes. Unlike
        # ``value in Temps``, this also works before Python 3.12.
        try:
            temps = Temps(":".join(ver[:2]))
            personne = Personne(ver[2])
        except ValueError:
            return None
        return Verbe.Variant(temps=temps, personne=personne)

    @classmethod
    def parse(cls) -> "Lexique":
        """Load the Lexique TSV file (extracting it first if needed).

        Rows whose ``lemme`` equals their ``ortho`` become canonical words,
        one per ``(spelling, category)`` pair. Every row is then attached as a
        variant of the canonical word with the same lemme and category, for
        nouns and verbs.

        Raises:
            FileNotFoundError: if the dataset is missing and cannot be extracted.
            ValueError: if a ``cgram`` value is not a known :class:`CatGram`.
        """
        cls._ensure_uncompressed()
        rows = cls._read_rows()

        # First pass: canonical forms (lemmes), indexed by (spelling, category).
        # If a pair appears on several rows, only the first creates the word;
        # variants from all of those rows still attach to it below.
        lemmes: dict[tuple[str, CatGram], Mot] = {}
        for row in rows:
            if row["lemme"] != row["ortho"]:
                continue
            cat_gram = CatGram.parse(row["cgram"])
            key = (row["ortho"], cat_gram)
            if key in lemmes:
                continue
            lemmes[key] = Mot.for_cat_gram(cat_gram)(
                mot=row["ortho"],
                cat_gram=cat_gram,
                freq=float(row["freqlemlivres"]),
            )

        # Second pass: attach inflected forms. The lookup matches both the
        # lemme and the category, so a form never lands on a homograph of
        # another category.
        for row in rows:
            lemme = lemmes.get((row["lemme"], CatGram.parse(row["cgram"])))
            if lemme is None:
                continue  # Unknown word

            if lemme.cat_gram == CatGram.NOM:
                genres = match_enum_or_all(row["genre"], Genre)
                nombres = match_enum_or_all(row["nombre"], Nombre)
                for genre, nombre in itertools.product(genres, nombres):
                    variant = Nom.Variant(genre=genre, nombre=nombre)
                    lemme._variants[variant] = row["ortho"]

            elif lemme.cat_gram == CatGram.VERBE:
                for raw_ver in row["infover"].split(";"):
                    variant = cls._verb_variant(raw_ver)
                    if variant is not None:
                        lemme._variants[variant] = row["ortho"]

        # Expose the words sorted by (spelling, category); this fixes the
        # order in which equal-frequency words come out of most_common.
        out = sorted(lemmes.values(), key=lambda x: (x.mot, x.cat_gram))
        return cls(out)

    def most_common(
        self, cat_gram: CatGram, threshold: t.Optional[int] = None
    ) -> list[Mot]:
        """Return the *threshold* most frequent words of category *cat_gram*.

        Words are ordered by decreasing ``freq``; ties keep dataset order.
        When *threshold* is None, the preset from ``PRESET_THRESHOLD_BY_CAT``
        is used.

        Raises:
            ValueError: if *threshold* is None and *cat_gram* has no preset.
        """
        if threshold is None:
            try:
                threshold = self.PRESET_THRESHOLD_BY_CAT[cat_gram]
            except KeyError as exn:
                raise ValueError(
                    f"No threshold preset for grammatical category {cat_gram}, "
                    "please provide a threshold manually"
                ) from exn
        out = [word for word in self.dataset if word.cat_gram == cat_gram]
        out.sort(key=lambda word: word.freq, reverse=True)
        return out[:threshold]