from __future__ import annotations

import csv
import io
from collections import Counter
from typing import Any, Dict, Optional, Sequence, Set, Tuple, Type

import chardet

from tractor.csv.exceptions import (
    EncodingDetectError,
    BadCsvFile,
    BadHeaderRowError,
    NonUniqueColumnsError,
    UnknownColumnsError,
    MissingColumnsError,
    UnexpectedRowValueError,
    EmptyRowValueError,
)
from tractor.csv.sniffer import sniff
from tractor.util.dataclasses import construct_from_dict, get_field_list, get_optional_field_list


def _detect_encoding(data: bytes) -> Tuple[Optional[str], float]:
    """Guess the character encoding of ``data`` using chardet.

    Args:
        data: Raw bytes to inspect.

    Returns:
        ``(encoding, confidence)`` exactly as reported by ``chardet.detect``.
        ``encoding`` may be ``None`` when chardet cannot make any guess, so
        callers must handle that case before decoding.
    """
    result = chardet.detect(data)
    return result["encoding"], result["confidence"]


def decode_bytes_to_str(data: bytes, encoding_detect_threshold: float = 0.9) -> str:
    """Decode raw bytes to text using an automatically detected encoding.

    The encoding guess (from ``_detect_encoding``) is only trusted when its
    confidence is at least ``encoding_detect_threshold``.

    Args:
        data: Raw content to decode.
        encoding_detect_threshold: Minimum detection confidence required
            before attempting to decode.

    Returns:
        The decoded text.

    Raises:
        EncodingDetectError: If no encoding could be detected, the detection
            confidence is below the threshold, or decoding with the detected
            encoding fails.
    """
    encoding, confidence = _detect_encoding(data)
    # The detector may report no encoding at all; treat that as a failed detection.
    if encoding is None or confidence < encoding_detect_threshold:
        raise EncodingDetectError(encoding, confidence)
    try:
        return data.decode(encoding)
    except (LookupError, ValueError) as e:
        # LookupError: codec name unknown to Python.
        # ValueError (incl. UnicodeDecodeError): bytes not valid for the codec.
        # Narrow on purpose so KeyboardInterrupt/SystemExit/MemoryError propagate.
        raise EncodingDetectError(encoding, confidence) from e


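# Empty cells are accepted only in optional columns: they are dropped from the row before
# it is converted, so the optional field is left unset (intended to become None).
# Empty cells in required columns are rejected.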
class CsvParser:
    def __init__(self, data: str, result_class: Type) -> None:
        self._result_class = result_class
        self._csv_reader = self._create_csv_reader(data)
        self._validate_columns()

    def __iter__(self) -> CsvParser:
        return self

    def __next__(self) -> Any:
        data = next(self._csv_reader)
        self._validate_row(data)
        self._delete_empty_optional_fields(data)
        try:
            return construct_from_dict(self._result_class, data)
        except Exception as e:
            raise BadCsvFile() from e

    def _create_csv_reader(self, data: str) -> csv.DictReader:
        dialect = sniff(data)
        return csv.DictReader(data.splitlines(keepends=True), dialect=dialect, restval="")

    def _validate_columns(self) -> None:
        supported_columns = get_field_list(self._result_class)
        expected_columns = supported_columns - get_optional_field_list(self._result_class)
        actual_columns = self._csv_reader.fieldnames
        if actual_columns is None:
            raise BadHeaderRowError()

        unique_actual_columns = set(actual_columns)

        if len(actual_columns) != len(unique_actual_columns):
            raise NonUniqueColumnsError()

        unknown_columns = unique_actual_columns - supported_columns
        if len(unknown_columns) > 0:
            raise UnknownColumnsError()

        missing_columns = expected_columns - unique_actual_columns
        if len(missing_columns) > 0:
            raise MissingColumnsError()

    def _validate_row(self, row: Dict[str, str]) -> None:
        optional_fields = get_optional_field_list(self._result_class)
        for key, value in row.items():
            if key is None:
                raise UnexpectedRowValueError()
            if value == "" and key not in optional_fields:
                raise EmptyRowValueError()

    def _delete_empty_optional_fields(self, row: Dict[str, str]) -> None:
        optional_fields = get_optional_field_list(self._result_class)
        keys_to_delete = []
        for key, value in row.items():
            if value == "" and key in optional_fields:
                keys_to_delete.append(key)

        for key in keys_to_delete:
            del row[key]

    def _get_non_unique_items(self, lst: Sequence) -> Set:
        return {x for x in lst if lst.count(x) > 1}
