import csv
import json
import os
import posixpath
import subprocess
from typing import NamedTuple, Sequence

import vh3
from vh3.lib.services.yt import get_mr_table, get_mr_file

from sklearn.model_selection import train_test_split
import yt.wrapper as yt

from . import ops

# Build-definition flags for a CUDA-enabled TensorFlow build (CUDA 10.1, no debug info).
CUDA_BUILD_DEFINITIONS_FLAGS = "-DTENSORFLOW_WITH_CUDA=1 -DCUDA_VERSION=10.1 -DNO_DEBUGINFO"

# Porto layer id handed to operations through their `job_layer` option.
YT_JOB_PORTO_LAYER = "114d79b2-a314-4e12-9bc2-7bda7ad3a5a5"

# YT cluster proxy and GPU pool name (not referenced in the visible part of this module).
YT_PROXY = "hahn.yt.yandex.net"
GPU_POOL = "gpu_geforce_1080ti"


class ServiceContext(vh3.DefaultContext):
    """Project-specific vh3 context; adds no overrides to ``vh3.DefaultContext``."""


@vh3.decorator.operation(vh3.mr_run_base, deterministic=True, owner='quoter')
@vh3.decorator.autorelease_to_nirvana_on_trunk_commit(
    version='https://nirvana.yandex-team.ru/alias/operation/maps_mrc_upload_pairs_file_to_yt/0.0.1',
    script_method='upload_pairs_file_to_yt',
)
@vh3.decorator.nirvana_names_transformer(vh3.name_transformers.snake_to_dash, options=True, inputs=False, outputs=False)
def upload_pairs_file_to_yt(
    pairs_file: vh3.JSON,
) -> vh3.MRTable[yt.TablePath]:
    """Writes the candidate pairs from a JSON file into the operation's output YT table.

    The JSON must contain a ``pairs`` list of ``[first_id, second_id]`` entries
    and a ``features`` object keyed by the string form of those ids. Each pair
    becomes one row whose ``first`` and ``second`` columns hold the matching
    ``features`` values.

    Args:
        pairs_file (vh3.JSON): JSON file with ``pairs`` and ``features`` keys.

    Returns:
        vh3.MRTable[yt.TablePath]: the table is produced by writing to
        ``vh3.runtime.get_mr_output_path()``; the function body returns ``None``.
    """
    output_table_path = str(vh3.runtime.get_mr_output_path())
    # Use a context manager so the input file handle is closed right away.
    with open(str(pairs_file)) as f:
        pairs = json.load(f)

    def generator():
        for first_id, second_id in pairs['pairs']:
            # JSON object keys are always strings, hence the str() lookups.
            yield {
                'first': pairs['features'][str(first_id)],
                'second': pairs['features'][str(second_id)],
            }

    yt.write_table(output_table_path, generator())


@vh3.decorator.operation(vh3.mr_run_base, deterministic=True, owner='quoter')
@vh3.decorator.autorelease_to_nirvana_on_trunk_commit(
    version='https://nirvana.yandex-team.ru/alias/operation/maps_mrc_filter_match_candidates/0.1.1',
    script_method='filter_match_candidates',
)
@vh3.decorator.nirvana_names_transformer(vh3.name_transformers.snake_to_dash, options=True, inputs=False, outputs=False)
def filter_match_candidates(
    binary: vh3.Executable,
    features_table: vh3.MRTable[yt.TablePath],
    objects_file: vh3.MRFile[yt.FilePath],
    distance: vh3.Integer,
    heading_diff: vh3.Number,
    min_box_size: vh3.Integer,
    ignore_detections: vh3.Boolean,
) -> vh3.JSON:
    """Computes object match candidates by running ``binary``.

    Args:
        binary: executable placed first on the command line.
        features_table: passed as ``--feature-table``.
        objects_file: passed as ``--object-file``.
        distance: passed as ``--distance``.
        heading_diff: passed as ``--heading-diff``.
        min_box_size: passed as ``--min-box-size``.
        ignore_detections: when truthy, ``--ignore-detections`` is appended.

    Returns:
        vh3.JSON: the binary writes the pairs to this operation's output path
        (given as ``--pair-file``); the function body returns ``None``.

    Raises:
        subprocess.CalledProcessError: if the binary exits with a non-zero status.
    """
    pair_file_path = str(vh3.runtime.get_output_path())

    flag_values = [
        ("--feature-table", features_table),
        ("--object-file", objects_file),
        ("--distance", distance),
        ("--heading-diff", heading_diff),
        ("--min-box-size", min_box_size),
        ("--pair-file", pair_file_path),
    ]

    command = [binary]
    for flag, value in flag_values:
        command += [flag, str(value)]

    if ignore_detections:
        command.append("--ignore-detections")

    subprocess.check_call(command)


class GenerateDatasetOutput(NamedTuple):
    """Outputs of ``generate_dataset``: the column-description table and the dataset table."""
    # Table the binary is told to write via ``--output-columns-description-table``.
    columns_descriptions: vh3.MRTable[yt.TablePath]
    # Table the binary is told to write via ``--output-factors-table``.
    dataset: vh3.MRTable[yt.TablePath]


@vh3.decorator.operation(vh3.mr_run_base, deterministic=True, owner='quoter')
@vh3.decorator.autorelease_to_nirvana_on_trunk_commit(
    version='https://nirvana.yandex-team.ru/alias/operation/maps_mrc_generate_dataset/0.1.0',
    script_method='generate_dataset',
)
@vh3.decorator.nirvana_names_transformer(vh3.name_transformers.snake_to_dash, options=True, inputs=False, outputs=False)
def generate_dataset(
    binary: vh3.Executable,
    feature_table: vh3.MRTable[yt.TablePath],
    pairs_table: vh3.MRTable[yt.TablePath],
    cluster_file: vh3.MRFile[yt.FilePath],
    object_file: vh3.MRFile[yt.FilePath],
) -> GenerateDatasetOutput:
    """Builds the factor dataset by running ``binary`` over the given inputs.

    Args:
        binary: executable placed first on the command line.
        feature_table: passed as ``--feature-table``.
        pairs_table: passed as ``--pair-table``.
        cluster_file: passed as ``--cluster-file``.
        object_file: passed as ``--object-file``.

    Returns:
        GenerateDatasetOutput: the MR output paths ``columns_descriptions`` and
        ``dataset`` that the binary was instructed to fill.

    Raises:
        subprocess.CalledProcessError: if the binary exits with a non-zero status.
    """
    outputs = GenerateDatasetOutput(
        columns_descriptions=vh3.runtime.get_mr_output_path("columns_descriptions"),
        dataset=vh3.runtime.get_mr_output_path("dataset"),
    )

    subprocess.check_call([
        binary,
        "--feature-table", str(feature_table),
        "--pair-table", str(pairs_table),
        "--cluster-file", str(cluster_file),
        "--object-file", str(object_file),
        "--output-columns-description-table", str(outputs.columns_descriptions),
        "--output-factors-table", str(outputs.dataset),
    ])

    return outputs


class SplitTrainTestOutput(NamedTuple):
    """Outputs of ``join_and_split_train_test``: the train and test TSV files."""
    # Path of the TSV holding the training rows.
    train: vh3.TSV
    # Path of the TSV holding the held-out rows.
    test: vh3.TSV


@vh3.decorator.operation(vh3.job_run_base, deterministic=True, owner='quoter')
@vh3.decorator.autorelease_to_nirvana_on_trunk_commit(
    version='https://nirvana.yandex-team.ru/alias/operation/maps_mrc_join_and_split_train_test/0.0.1',
    script_method='join_and_split_train_test',
)
@vh3.decorator.nirvana_names_transformer(vh3.name_transformers.snake_to_dash, options=True, inputs=False, outputs=False)
def join_and_split_train_test(
    files: Sequence[vh3.TSV],
    test_size: vh3.Number,
) -> SplitTrainTestOutput:
    """Concatenates TSV inputs and splits the shuffled rows into train and test.

    Args:
        files: TSV files whose rows are concatenated in the given order.
        test_size: forwarded to ``sklearn.model_selection.train_test_split``
            (a fraction of the rows, or an absolute row count).

    Returns:
        SplitTrainTestOutput: paths of the written ``train`` and ``test`` TSVs.
    """
    delimiter = "\t"
    # The operation is declared deterministic=True, so the shuffle must be
    # reproducible across runs: pin the RNG seed.
    random_state = 0

    rows = []
    for path in files:
        # newline='' is what the csv module requires for files it parses.
        with open(str(path), 'r', newline='') as f:
            rows.extend(csv.reader(f, delimiter=delimiter))

    train, test = train_test_split(rows, test_size=test_size, random_state=random_state)

    train_path = vh3.runtime.get_output_path("train")
    test_path = vh3.runtime.get_output_path("test")

    for output_path, output_rows in ((train_path, train), (test_path, test)):
        # newline='' lets csv.writer control line endings itself.
        with open(str(output_path), 'w', newline='') as f:
            csv.writer(f, delimiter=delimiter).writerows(output_rows)

    return SplitTrainTestOutput(train_path, test_path)


def make_dataset(filter_match_candidates_binary: vh3.Executable,
                 calc_visibility_factors_binary: vh3.Executable,
                 dataset_path: str,
                 *,
                 distance: int = 50,
                 heading_diff: int = 90,
                 min_box_size: int = 30,
                 max_ram: int = 1000):
    """Adds the operations that build one dataset to the current graph.

    Reads the ``feature`` table and the ``cluster.json`` / ``object.json`` files
    under ``dataset_path``. Filters match candidates, uploads the resulting
    pairs to YT, and generates the factor dataset from them.

    Args:
        filter_match_candidates_binary: build result whose ``arcadia_project``
            is run by ``filter_match_candidates``.
        calc_visibility_factors_binary: build result whose ``arcadia_project``
            is run by ``generate_dataset``.
        dataset_path: YT directory holding the dataset inputs.
        distance: ``distance`` option of ``filter_match_candidates``.
        heading_diff: ``heading_diff`` option of ``filter_match_candidates``.
        min_box_size: ``min_box_size`` option of ``filter_match_candidates``.
        max_ram: ``max_ram`` option applied to each of the three operations.

    Returns:
        The outputs of ``generate_dataset`` (column descriptions, dataset).
    """
    # Cypress paths always use '/', regardless of the host OS, so build them
    # with posixpath instead of the platform-dependent os.path.
    features_table = get_mr_table(posixpath.join(dataset_path, "feature"))
    cluster_file = get_mr_file(posixpath.join(dataset_path, "cluster.json"))
    objects_file = get_mr_file(posixpath.join(dataset_path, "object.json"))

    pairs_file = filter_match_candidates(
        filter_match_candidates_binary.arcadia_project,
        features_table,
        objects_file,
        distance=distance,
        heading_diff=heading_diff,
        min_box_size=min_box_size,
        ignore_detections=True,
        max_ram=max_ram,
    )

    pairs_table = upload_pairs_file_to_yt(pairs_file, max_ram=max_ram)

    return generate_dataset(
        binary=calc_visibility_factors_binary.arcadia_project,
        feature_table=features_table,
        pairs_table=pairs_table,
        cluster_file=cluster_file,
        object_file=objects_file,
        max_ram=max_ram,
        job_layer=[YT_JOB_PORTO_LAYER],
    )


@vh3.decorator.graph()
def train_catboost_model() -> None:
    """Builds a graph that trains and analyses a CatBoost model on signs_map datasets.

    Steps:
      1. Build the two matcher binaries at pinned Arcadia revisions.
      2. Build one dataset per path with ``make_dataset``.
      3. Concatenate the datasets and split them 0.9 / 0.1 with seed 0.
      4. Train CatBoost (5000 iterations; AUC, Precision, Recall and F1 as
         custom metrics).
      5. Run model analysis on the result.
    """
    dataset_paths = [
        "//home/maps/core/mrc/signs_map/dataset_0",
        "//home/maps/core/mrc/signs_map/dataset_1",
        "//home/maps/core/mrc/signs_map/dataset_2",
    ]

    filter_match_candidates_binary = ops.build_arcadia_project(
        targets="maps/wikimap/mapspro/services/mrc/eye/experiments/signs_map/matcher/filter_match_candidates",
        arts="maps/wikimap/mapspro/services/mrc/eye/experiments/signs_map/matcher/filter_match_candidates/filter_match_candidates",
        arcadia_revision=9154614,
    )

    calc_visibility_factors_binary = ops.build_arcadia_project(
        targets="maps/wikimap/mapspro/services/mrc/eye/experiments/signs_map/matcher/find_missing_object/calc_visibility_factors",
        arts="maps/wikimap/mapspro/services/mrc/eye/experiments/signs_map/matcher/find_missing_object/calc_visibility_factors/calc_visibility_factors",
        arcadia_revision=9645695,
        # Reuse the module-level constant instead of duplicating the literal.
        definition_flags=CUDA_BUILD_DEFINITIONS_FLAGS,
        arcadia_patch='rb:2688786',  # use 'rb:${REVIEW_ID}' to apply patch from review
    )

    datasets = []
    columns_descriptions = None

    for dataset_path in dataset_paths:
        # Only the column description of the last dataset is kept and used for
        # training; presumably identical across datasets since the same binary
        # generates them all — TODO confirm.
        columns_descriptions, dataset = make_dataset(
            filter_match_candidates_binary,
            calc_visibility_factors_binary,
            dataset_path)
        datasets.append(dataset)

    joined_dataset = ops.concatenate_mr_tables(tables=datasets)
    train_dataset, test_dataset = ops.mr_table_split(
        input=joined_dataset,
        ratio=0.9,
        seed=0)

    learn_result = ops.cat_boost_train(
        learn=train_dataset,
        test=test_dataset,
        cd=columns_descriptions,
        iterations=5000,
        fstr_type="PredictionValuesChange",
        prediction_type="Probability",
        args="--custom-loss=AUC,Precision,Recall,F1")

    ops.cat_boost_model_analysis(
        pool=train_dataset,
        test=test_dataset,
        cd=columns_descriptions,
        model_bin=learn_result.model_bin
    )
