from enum import Enum, auto
from datetime import datetime
import os
import sys
import logging

from json import dumps

import yt.wrapper as yt

import requests
import ticket_parser2

import ticket_parser2.api.v1 as tp2

from yql.api.v1.client import YqlClient
from yql.client.parameter_value_builder import YqlParameterValueBuilder as ValueBuilder


class Config:
    """Runtime settings for the task, resolved from environment variables.

    When ``DEV_ENV`` equals the string ``"True"`` the testing TVM/Datasync
    identifiers, the testing Datasync URL, the ``hume`` cluster and the
    ``tmp/...`` tables are selected; any other value (or no value) selects
    the production counterparts on ``hahn``.

    ``TVM_SECRET`` and ``YQL_TOKEN`` are mandatory in both modes.

    Raises:
        ValueError: if ``TVM_SECRET`` or ``YQL_TOKEN`` is not set.
    """

    TVM_CLIENT_ID_TESTING = 2021255
    DATASYNC_CLIENT_ID_TESTING = 2000060
    TVM_CLIENT_ID_PROD = 2021253
    DATASYNC_CLIENT_ID_PROD = 132

    DATASYNC_URL_PROD = "https://intapi.disk.yandex.net:8443/v1/personality/profile"
    DATASYNC_URL_TEST = "https://api-stable.dst.yandex.net:8443/v1/personality/profile"

    def __init__(self):
        is_dev = os.getenv("DEV_ENV") == "True"

        # Service identities and endpoint.
        self.tvm_client_id = (
            Config.TVM_CLIENT_ID_TESTING if is_dev else Config.TVM_CLIENT_ID_PROD)
        self.datasync_client_id = (
            Config.DATASYNC_CLIENT_ID_TESTING if is_dev else Config.DATASYNC_CLIENT_ID_PROD)
        self.datasync_url = (
            Config.DATASYNC_URL_TEST if is_dev else Config.DATASYNC_URL_PROD)

        # Cluster and table locations.
        if is_dev:
            self.yt_cluster = "hume"
            origin_path = "tmp/stepan_k_vehicles_origin"
            self.processed_table_path = "tmp/stepan_k_vehicles_processed"
            self.task_stats_table_path = "tmp/stepan_k_task_stats"
        else:
            self.yt_cluster = "hahn"
            origin_path = "home/verticals/__private_export/navigator-carfax-project/navigator_vehicles"
            self.processed_table_path = "home/maps/analytics/legacy/nirvana/navi/team-x-vin-transfer/processed_records"
            self.task_stats_table_path = "home/maps/analytics/legacy/nirvana/navi/team-x-vin-transfer/task_stats"

        # YQL-style names: "<cluster>.`<path>`".
        self.origin_table_name = f"{self.yt_cluster}.`{origin_path}`"
        self.processed_table_name = f"{self.yt_cluster}.`{self.processed_table_path}`"

        # Secrets are checked last, TVM secret first.
        self.tvm_secret = self._env_var(
            "TVM_SECRET", "TVM_SECRET env var must be set")
        self.yql_token = self._env_var(
            "YQL_TOKEN", "YQL_TOKEN env var must be set")

    def _env_var(self, var_name, error_message):
        """Return the value of ``var_name`` or raise ``ValueError(error_message)``."""
        value = os.getenv(var_name)
        if value is None:
            raise ValueError(error_message)
        return value


class ServiceTicketProvider:
    """Issues TVM service tickets for the ``datasync`` destination.

    Builds a ``tp2.TvmClient`` from the client id and secret held by
    ``config`` and registers ``config.datasync_client_id`` under the
    ``"datasync"`` alias.
    """

    def __init__(self, config):
        destinations = {"datasync": config.datasync_client_id}
        settings = tp2.TvmApiClientSettings(
            self_client_id=config.tvm_client_id,
            self_secret=config.tvm_secret,
            dsts=destinations,
        )
        self._tvm_client = tp2.TvmClient(settings)

    def service_ticket_for_datasync(self):
        """Return the service ticket the TVM client gives for ``"datasync"``."""
        tvm_client = self._tvm_client
        return tvm_client.get_service_ticket_for("datasync")


class DatasyncApi:
    """Small HTTP client for the ``navi/car_info`` collection in Datasync.

    Every request carries the TVM service ticket obtained at construction
    time and the target user's uid in the ``X-Uid`` header. All methods raise
    ``requests.HTTPError`` (via ``raise_for_status``) on a non-success
    response and a ``requests`` timeout error if the server does not answer
    within ``timeout`` seconds.

    Args:
        config: object exposing ``datasync_url``.
        ticket_provider: object exposing ``service_ticket_for_datasync()``.
        timeout: seconds passed as ``timeout=`` to every ``requests`` call.
    """

    # requests applies no timeout unless one is given, so a stalled
    # connection would block the whole task; always pass an explicit value.
    DEFAULT_TIMEOUT_SECONDS = 30

    def __init__(self, config, ticket_provider, timeout=DEFAULT_TIMEOUT_SECONDS):
        self.url = f"{config.datasync_url}/navi/car_info"
        # NOTE(review): the ticket is fetched once and reused for the whole
        # run; confirm its lifetime covers the longest expected run.
        self.service_ticket = ticket_provider.service_ticket_for_datasync()
        self.timeout = timeout

    def _headers(self, uid):
        """Return the auth/user headers for a request on behalf of ``uid``."""
        return {
            "X-Ya-Service-Ticket": self.service_ticket,
            "X-Uid": str(uid)
        }

    def user_cars(self, uid):
        """Return the ``"items"`` list of the user's car_info collection."""
        response = requests.get(
            self.url, headers=self._headers(uid), timeout=self.timeout)
        response.raise_for_status()
        return response.json()["items"]

    def update_user_car(self, uid, car):
        """PUT ``car`` (a dict with an ``"id"`` key) to ``<url>/<id>``."""
        car_id = car["id"]
        response = requests.put(
            f"{self.url}/{car_id}", headers=self._headers(uid), json=car,
            timeout=self.timeout)
        response.raise_for_status()

    def create_user_car(self, uid, car):
        """POST ``car`` to the collection URL."""
        response = requests.post(
            self.url, headers=self._headers(uid), json=car,
            timeout=self.timeout)
        response.raise_for_status()

    def delete_user_car(self, uid, car):
        """DELETE the record at ``<url>/<id>`` for ``car["id"]``."""
        car_id = car["id"]
        response = requests.delete(
            f"{self.url}/{car_id}", headers=self._headers(uid),
            timeout=self.timeout)
        response.raise_for_status()


def unix_time(date_time):
    """Convert ``date_time`` to whole milliseconds since the Unix epoch.

    The fractional millisecond part is truncated by ``int``.
    """
    epoch_seconds = datetime.timestamp(date_time)
    epoch_millis = epoch_seconds * 1000
    return int(epoch_millis)


def from_unix_time(timestamp):
    """Convert a millisecond Unix timestamp to a naive UTC ``datetime``.

    Args:
        timestamp: milliseconds since the Unix epoch.

    Returns:
        The corresponding naive ``datetime`` expressed in UTC.
    """
    # The result was previously computed but never returned.
    return datetime.utcfromtimestamp(timestamp / 1000)


def from_yt_timestamp(timestamp):
    """Convert a microsecond timestamp to a naive UTC ``datetime``.

    Args:
        timestamp: microseconds since the Unix epoch.

    Returns:
        The corresponding naive ``datetime`` expressed in UTC.
    """
    # The result was previously computed but never returned.
    return datetime.utcfromtimestamp(timestamp / 1000000)


def init_logger():
    """Attach an INFO-level stdout handler to the root logger.

    The root logger level is set to INFO as well; records are rendered as
    ``<asctime> <levelname> <message>``.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(
        logging.Formatter('%(asctime)s %(levelname)s %(message)s'))

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(stdout_handler)


def vin_matches_masked_vin(vin, masked_vin):
    """Tell whether ``vin`` fits the pattern ``masked_vin``.

    Both strings must have the same length. A ``'*'`` in ``masked_vin``
    matches any character at that position; every other character must equal
    the corresponding ``vin`` character, ignoring case.
    """
    if len(vin) != len(masked_vin):
        return False

    return all(
        mask_char == '*' or mask_char.lower() == vin_char.lower()
        for vin_char, mask_char in zip(vin, masked_vin)
    )


# Outcome labels for a processed record. One of these is stored in the
# 'status' column of the processed-records table and drives the Stats
# counters: OK -> transferred, NOT_FOUND -> not found,
# VIN_ALREADY_PRESENT -> already has vin, anything else -> error.
STATUS_OK = "OK"
STATUS_NOT_FOUND = "not found in datasync"
STATUS_VIN_ALREADY_PRESENT = "vin already present"
STATUS_DATASYNC_READ_ERROR = "datasync read error"
STATUS_DATASYNC_WRITE_ERROR = "datasync write error"


class Stats:
    """Timing and per-outcome counters for one run of the task.

    ``task_start_time`` is captured when the object is created;
    ``task_finish_time`` stays ``0`` until ``mark_task_finished`` is called.
    """

    def __init__(self):
        self.task_start_time = datetime.now()
        self.task_finish_time = 0

        # Size of the backlog found at start (set by the caller).
        self.total_unprocessed_records_count = 0

        # Counters updated through mark_processed().
        self.processed_records_count = 0
        self.transferred_records_count = 0
        self.not_found_records_count = 0
        self.already_has_vin_records_count = 0
        self.errors_count = 0

    def mark_processed(self, status):
        """Count one processed record under the bucket matching ``status``.

        Any status other than OK / not found / vin already present is
        counted as an error.
        """
        self.processed_records_count += 1

        outcome_counters = (
            (STATUS_OK, "transferred_records_count"),
            (STATUS_NOT_FOUND, "not_found_records_count"),
            (STATUS_VIN_ALREADY_PRESENT, "already_has_vin_records_count"),
        )

        counter_name = "errors_count"
        for known_status, name in outcome_counters:
            if status == known_status:
                counter_name = name
                break

        setattr(self, counter_name, getattr(self, counter_name) + 1)

    def mark_task_finished(self):
        """Record the current time as the task finish time."""
        self.task_finish_time = datetime.now()


class UnprocessedRecord:
    """One origin-table row that still has to be transferred.

    Holds the ``user_id``, ``vin_code`` and ``vin_code_id`` values of the row.
    """

    def __init__(self, user_id, vin_code, vin_code_id):
        self.user_id, self.vin_code, self.vin_code_id = (
            user_id, vin_code, vin_code_id)


class YtWrapper:
    """Reads and writes the task's YT tables through YQL and the YT client.

    Construction creates a ``YqlClient`` for ``config.yt_cluster`` and, as a
    process-wide side effect, points ``yt.config["proxy"]["url"]`` at the
    same cluster.
    """

    # (column name, YQL type) pairs of the processed-records table, in the
    # order the row values are supplied by the caller.
    _PROCESSED_COLUMNS = (
        ('user_id', 'Uint64'),
        ('vin_code', 'Utf8'),
        ('vin_code_id', 'Uint64'),
        ('status', 'Utf8'),
        ('timestamp_processed', 'Uint64'),
    )

    def __init__(self, config):
        self.yql_client = YqlClient(
            token=config.yql_token, db=config.yt_cluster)

        self.processed_table_name = config.processed_table_name
        self.processed_table_path = config.processed_table_path
        self.origin_table_name = config.origin_table_name
        self.task_stats_table_path = config.task_stats_table_path

        yt.config["proxy"]["url"] = config.yt_cluster

    def get_unprocessed_records(self, limit=None):
        """Return origin rows that have no counterpart in the processed table.

        Rows are matched on ``user_id`` and ``vin_code``.

        Args:
            limit: optional maximum number of rows, appended to the query as
                a ``LIMIT`` clause.

        Returns:
            A list of ``UnprocessedRecord`` built from the first result
            table, or an empty list when the query yields no result table.

        Raises:
            RuntimeError: if the YQL request does not succeed.
        """
        query = f"""
SELECT origin.`user_id`, origin.`vin_code`, origin.`vin_code_id`
FROM {self.origin_table_name} AS origin LEFT JOIN {self.processed_table_name} AS processed
ON origin.`user_id` == processed.`user_id` AND origin.`vin_code` == processed.`vin_code`
WHERE processed.`user_id` IS NULL
"""

        if limit is not None:
            query += f" LIMIT {limit}"

        request = self.yql_client.query(query, syntax_version=1)
        request.run()
        results = request.get_results()

        if not results.is_success:
            logging.error(
                f"Failed to get unprocessed records from YT: status: {results.status}")
            if results.errors:
                for error in results.errors:
                    logging.error(str(error))

            raise RuntimeError("Failed to get records from YT")

        # The query has a single SELECT, so only the first table is used.
        for table in results:
            table.fetch_full_data()
            return [
                UnprocessedRecord(user_id, vin, vin_code_id)
                for user_id, vin, vin_code_id in table.rows
            ]

        # Successful request without any result table: previously this fell
        # through and returned None, which callers cannot take len() of.
        return []

    def save_as_processed(self, processed_records):
        """Append ``processed_records`` to the processed-records table.

        Args:
            processed_records: list of rows, each
                ``[user_id, vin_code, vin_code_id, status, timestamp_processed]``.
        """
        column_names = [name for name, _ in self._PROCESSED_COLUMNS]
        column_types = [yql_type for _, yql_type in self._PROCESSED_COLUMNS]

        self.yql_client.write_table(
            self.processed_table_path, processed_records, column_names, column_types, append=True)

    def save_task_stats(self, stats):
        """Append one row with the run's timings and counters to the stats table.

        Args:
            stats: a ``Stats`` object whose start and finish times are both
                datetimes (i.e. ``mark_task_finished`` has been called).
        """
        columns = [
            ('task_start_time', 'Uint64', unix_time(stats.task_start_time)),
            ('task_finish_time', 'Uint64', unix_time(stats.task_finish_time)),
            ('unprocessed_records_count', 'Uint64',
                stats.total_unprocessed_records_count),
            ('processed_records_count', 'Uint64', stats.processed_records_count),
            ('transferred_records_count', 'Uint64',
                stats.transferred_records_count),
            ('not_found_records_count', 'Uint64', stats.not_found_records_count),
            ('already_has_vin_records_count', 'Uint64',
                stats.already_has_vin_records_count),
            ('errors_count', 'Uint64', stats.errors_count),
        ]

        column_names = [name for name, _, _ in columns]
        column_types = [yql_type for _, yql_type, _ in columns]
        values = [value for _, _, value in columns]

        self.yql_client.write_table(
            self.task_stats_table_path, [values], column_names, column_types, append=True)

    def _full_path(self, table_name):
        """Return ``table_name`` prefixed with ``//``."""
        return f"//{table_name}"

    def merge_tables_chunks(self):
        """Merge the stats and processed tables in place with ``combine_chunks`` set."""
        for table_path in [self.task_stats_table_path, self.processed_table_path]:
            full_table_path = self._full_path(table_path)
            yt.run_merge(full_table_path, full_table_path,
                         spec={"combine_chunks": "true"})


def write_car_to_datasync(datasync_api, user_car, record):
    """Set the record's VIN on ``user_car`` and push the car to Datasync.

    ``user_car`` is modified in place (its ``"vin"`` key is overwritten)
    before the update request is sent.

    Args:
        datasync_api: client exposing ``update_user_car(uid, car)``.
        user_car: the car dict to update.
        record: object exposing ``user_id`` and ``vin_code``.

    Returns:
        ``STATUS_OK`` on success, ``STATUS_DATASYNC_WRITE_ERROR`` if the
        update raised.
    """
    user_car["vin"] = record.vin_code
    try:
        datasync_api.update_user_car(record.user_id, user_car)
    except Exception as e:
        # Deliberately broad: a failed write must not abort the whole batch.
        # The handler used to reference an undefined `status` name, which
        # turned every write failure into a NameError.
        logging.error(f"{STATUS_DATASYNC_WRITE_ERROR}: {str(e)}")
        return STATUS_DATASYNC_WRITE_ERROR
    return STATUS_OK


def process_records(datasync_api, yt_wrapper, stats, unprocessed_records):
    """Transfer VINs for ``unprocessed_records`` and persist the outcomes.

    For every record the user's cars are read from Datasync and the car whose
    ``"vinId"`` equals ``str(record.vin_code_id)`` is looked up:

    * no such car (or the read failed) -> not found / read error;
    * car without a VIN -> the VIN is written to Datasync;
    * car with the same VIN -> OK;
    * car with a different VIN -> vin already present.

    Each outcome is counted in ``stats`` and, once all records are handled,
    the rows are appended to the processed table in a single call.
    """
    outcome_rows = []

    for record in unprocessed_records:
        logging.info(
            f"processing record: user_id: {record.user_id}, vin: {record.vin_code}")

        status = STATUS_NOT_FOUND
        try:
            user_cars = datasync_api.user_cars(record.user_id)
        except Exception as e:
            # A failed read is recorded as an outcome rather than aborting.
            status = STATUS_DATASYNC_READ_ERROR
            logging.error(f"{status}: {str(e)}")
            user_cars = []

        # First car linked to this record's vin_code_id, if any.
        matching_car = None
        for car in user_cars:
            if car.get("vinId") == str(record.vin_code_id):
                matching_car = car
                break

        if matching_car is not None:
            current_vin = matching_car.get("vin")
            if not current_vin:
                status = write_car_to_datasync(
                    datasync_api, matching_car, record)
            elif current_vin == record.vin_code:
                status = STATUS_OK
            else:
                status = STATUS_VIN_ALREADY_PRESENT

        processed_at = unix_time(datetime.now())
        outcome_rows.append([
            record.user_id,
            record.vin_code,
            record.vin_code_id,
            status,
            processed_at,
        ])

        stats.mark_processed(status)

    yt_wrapper.save_as_processed(outcome_rows)


def main():
    """Run the task end to end.

    Sets up logging and the clients, loads the unprocessed records, optionally
    truncates them to ``RECORDS_LIMIT`` (the recorded backlog size is taken
    before truncation), processes them, then saves the run statistics and
    merges the chunks of the output tables.
    """
    init_logger()
    stats = Stats()

    config = Config()
    ticket_provider = ServiceTicketProvider(config)
    datasync_api = DatasyncApi(config, ticket_provider)
    yt_wrapper = YtWrapper(config)

    logging.info("getting unprocessed records..")
    pending_records = yt_wrapper.get_unprocessed_records()
    stats.total_unprocessed_records_count = len(pending_records)

    raw_limit = os.getenv("RECORDS_LIMIT")
    if raw_limit is not None:
        max_records = int(raw_limit)
        pending_records = pending_records[:max_records]

    process_records(datasync_api, yt_wrapper, stats, pending_records)
    stats.mark_task_finished()

    logging.info("saving task stats..")
    yt_wrapper.save_task_stats(stats)

    logging.info("merging table chunks..")
    yt_wrapper.merge_tables_chunks()


# Run the task only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
