# -*- coding: utf-8 -*-

import csv
import importlib
import io
import logging
from argparse import ArgumentParser
from contextlib import contextmanager

from library.python import resource
from yql.api.v1.client import YqlClient

from travel.hotels.lib.python3.yql import yqllib
from travel.hotels.lib.python3.yt import ytlib
from travel.hotels.lib.python3.yt.temp_path import TmpPath
from travel.hotels.lib.python3.yt.versioned_path import VersionedPath
from travel.hotels.lib.python3.yt.ytlib import transfer_results
from travel.library.python.tools import replace_args_from_env


class Runner(object):
    """Run a named YQL query file, optionally wrapped by a processor module.

    The runner builds a YQL client (always) and a YT client (only when a YT
    token or token path is given), substitutes ``%VERSIONED_PATH%`` and
    ``%TEMP_PATH%`` placeholders in arguments, uploads CSV resources to YT,
    runs ``<query_name>.yql`` and finally may transfer the results to another
    cluster.
    """

    def __init__(self, args):
        """Create the clients described by the parsed command-line ``args``.

        Raises:
            ValueError: if neither ``yql_token`` nor ``yql_token_path`` is set.
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if not (args.yql_token or args.yql_token_path):
            raise ValueError("Either yql_token or yql_token_path must be specified")
        yql_args = {
            'db': args.yt_proxy,
            'token': args.yql_token,
            'token_path': args.yql_token_path,
        }
        self.yql_client = YqlClient(**yql_args)
        if args.yt_token or args.yt_token_path:
            yt_config = {
                'token': args.yt_token,
                'token_path': args.yt_token_path,
            }
            self.yt_client = ytlib.create_client(proxy=args.yt_proxy, config=yt_config)
        else:
            # Features that need YT (versioned path, temp path, csv resources)
            # check for this and raise when they are requested.
            self.yt_client = None
        self.args = args

    def replace(self, s, replaces):
        """Return ``s`` with every key of ``replaces`` substituted by its value."""
        for k, v in replaces.items():
            s = s.replace(k, v)
        return s

    def prepare_args(self, args, replaces, prefix):
        """Parse ``key[~yt_type]=value`` strings into a dict of typed values.

        Only the first ``=`` separates key from value, so values may contain
        ``=`` themselves. Placeholders from ``replaces`` are substituted in the
        value before it is converted to the declared YT type (``string`` when
        no ``~type`` suffix is given). When ``prefix`` is truthy it is prepended
        to names that do not already start with it. A name that occurs several
        times yields a list of its values in order of appearance.

        Raises:
            ValueError: if an item has no ``=`` separator.
            Exception: if the declared YT type has no converter.
        """
        result = dict()
        for raw_arg in args:
            key, sep, raw_value = raw_arg.partition('=')
            if not sep:
                # The argument text is intentionally not echoed: it may contain secrets.
                raise ValueError("Invalid argument without '=', expected format is key=value")
            arg_name, arg_yt_type = Runner.get_name_and_yt_type(key)
            value = self.replace(raw_value, replaces)
            value = self.get_type_converter(arg_name, arg_yt_type)(value)
            if prefix and not arg_name.startswith(prefix):
                arg_name = prefix + arg_name
            if arg_name not in result:
                result[arg_name] = value
            elif isinstance(result[arg_name], list):
                result[arg_name].append(value)
            else:
                # Second occurrence of the name: promote the scalar to a list.
                result[arg_name] = [result[arg_name], value]
        return result

    def get_result_path(self, query_args, prefix):
        """Return the query argument that holds the result path, or ``None``.

        The argument name comes from ``args.yt_result_path_query_arg_name`` and
        is looked up in ``query_args`` with ``prefix`` prepended. ``None`` is
        returned when no such name is configured.

        Raises:
            KeyError: if the configured name is absent from ``query_args``.
        """
        yt_result_path_query_arg_name = self.args.yt_result_path_query_arg_name
        if not yt_result_path_query_arg_name:
            return None
        return query_args[prefix + yt_result_path_query_arg_name]

    def _load_processor_module(self):
        """Import the processor module for this run, or return ``None`` if it does not exist.

        The module name is ``args.processor_name`` or, when that is not set,
        ``args.query_name``. Only the absence of the processor module itself
        (or of one of its parent packages) counts as "no processor"; an import
        failure raised from inside an existing processor module propagates
        instead of silently disabling pre/post-processing.

        Raises:
            ValueError: if no processor exists but ``processor_args`` were given.
        """
        processor_name = self.args.processor_name or self.args.query_name
        module_name = f'travel.hotels.tools.run_yql.queries.{processor_name}'
        try:
            processor_module = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            missing = e.name
            is_processor_missing = missing is not None and (
                module_name == missing or module_name.startswith(missing + '.')
            )
            if not is_processor_missing:
                raise
            logging.info("No processor found")
            if self.args.processor_args:
                raise ValueError("processor_args are specified, but no processor was found")
            return None
        logging.info(f"Processor {processor_name} loaded")
        return processor_module

    def run(self):
        """Execute the whole pipeline.

        Steps: load the optional processor, open the optional YT transaction,
        versioned path and temp path, upload CSV resources, run the processor's
        ``pre_process``, the YQL file and ``post_process``, then (outside the
        transaction) transfer results when ``args.transfer_to_cluster`` is set.
        """
        processor_module = self._load_processor_module()
        with self.get_yt_transaction() as transaction_id:
            versioned_path_manager = self.get_versioned_path_manager()
            with self.get_versioned_path(versioned_path_manager) as versioned_path, \
                    self.get_temp_path() as temp_path:
                replaces = dict()
                if versioned_path:
                    replaces['%VERSIONED_PATH%'] = str(versioned_path)
                if temp_path:
                    replaces['%TEMP_PATH%'] = str(temp_path)
                self.upload_csv_resources(replaces)
                query_args = self.prepare_args(self.args.query_args, replaces, '$')
                if processor_module:
                    processor_args = self.prepare_args(self.args.processor_args, replaces, None)
                    processor = processor_module.Processor(yt_client=self.yt_client, yql_client=self.yql_client,
                                                           query_args=query_args, **processor_args)
                else:
                    processor = None
                if processor:
                    logging.info("Running pre-process")  # Args not logged, may contain tokens and secrets
                    processor.pre_process()
                logging.info(f"Running query {self.args.query_name}, Args: {query_args}")
                yqllib.run_yql_file(self.yql_client, self.args.query_name + '.yql', "RunYql",
                                    parameters=query_args, transaction_id=transaction_id)
                logging.info("Query finished")

                if processor:
                    logging.info("Running post-process")  # Args not logged, may contain tokens and secrets
                    processor.post_process()

        if self.args.transfer_to_cluster:
            # Copy results to another cluster; done after the transaction block has exited.
            if versioned_path:
                versioned_path_manager.transfer_results(self.args.transfer_to_cluster, self.args.yt_token,
                                                        self.args.yt_proxy)
            else:
                result_path = self.get_result_path(query_args, '$')

                if result_path:
                    transfer_results(result_path, self.args.yt_proxy, self.args.transfer_to_cluster, self.args.yt_token)

    def get_versioned_path_manager(self):
        """Return a ``VersionedPath`` for ``args.versioned_path``, or ``None`` if not requested.

        Raises:
            Exception: if a versioned path is requested but there is no YT client.
        """
        if not self.args.versioned_path:
            return None
        if self.yt_client is None:
            raise Exception("versioned_path is specified, but no yt client (no token)")
        return VersionedPath(self.args.versioned_path, yt_client=self.yt_client)

    @contextmanager
    def get_versioned_path(self, versioned_path_manager):
        """Context manager that enters ``versioned_path_manager`` and yields its value, or yields ``None``."""
        if versioned_path_manager:
            with versioned_path_manager as p:
                yield p
        else:
            yield None

    @contextmanager
    def get_temp_path(self):
        """Context manager yielding a ``TmpPath`` value for ``args.temp_path``, or ``None`` if not requested.

        Raises:
            Exception: if a temp path is requested but there is no YT client.
        """
        if self.args.temp_path:
            if self.yt_client is None:
                raise Exception("temp_path is specified, but no yt client (no token)")
            with TmpPath(self.args.temp_path, yt_client=self.yt_client) as p:
                yield p
        else:
            yield None

    @contextmanager
    def get_yt_transaction(self):
        """Context manager yielding the id of a new YT transaction, or ``None`` without a YT client."""
        if self.yt_client is not None:
            with self.yt_client.Transaction() as t:
                yield t.transaction_id
        else:
            yield None

    @staticmethod
    def str_to_bool(x):
        """Convert ``'true'``/``'false'`` (any letter case) to ``bool``.

        Raises:
            Exception: for any other string.
        """
        upper = x.upper()
        if upper == 'TRUE':
            return True
        if upper == 'FALSE':
            return False
        raise Exception(f"Invalid bool value '{x}'")

    @staticmethod
    def get_type_converter(field_name, yt_type):
        """Return the callable converting a string to ``yt_type`` (``bool``, ``int32`` or ``string``).

        ``field_name`` is used only for the error message.

        Raises:
            Exception: if ``yt_type`` is not supported.
        """
        converters = {
            'bool': Runner.str_to_bool,
            'int32': int,
            'string': str,
        }
        c = converters.get(yt_type)
        if c is None:
            raise Exception(f"Cannot find convertor for field {field_name}, yt type is {yt_type}")
        return c

    @staticmethod
    def get_name_and_yt_type(field_name):
        """Split ``name~yt_type`` into ``(name, yt_type)``; the type defaults to ``'string'``.

        Only the first ``~`` is a separator, so a malformed suffix such as
        ``a~int32~x`` is reported as type ``int32~x`` rather than silently
        treated as ``string``.
        """
        name, sep, yt_type = field_name.partition('~')
        return name, (yt_type if sep else 'string')

    def upload_csv_resources(self, replaces):
        """Upload every ``yt_path=resource_name`` CSV resource from ``args.csv_resources`` to YT.

        Placeholders from ``replaces`` are substituted in ``yt_path``. The CSV
        header defines the table schema: each column is ``name[~yt_type]`` and
        every cell is converted to the column's type. The table is created with
        ``force=True`` and then written.

        Raises:
            ValueError: if a resource spec has no ``=`` separator.
            Exception: if there is no YT client, the resource is not found,
                the resource has no header row, or a column type is unsupported.
        """
        if not self.args.csv_resources:
            return
        if self.yt_client is None:
            raise Exception("resources is specified, but no yt client (no token)")
        for res_info in self.args.csv_resources:
            yt_path, sep, resource_name = res_info.partition('=')
            if not sep:
                raise ValueError(f"Invalid csv resource '{res_info}', expected format is yt_path=resource_name")
            yt_path = self.replace(yt_path, replaces)  # Use %VERSIONED_PATH%, %TEMP_PATH% and so on
            csv_data = resource.find(resource_name)
            if csv_data is None:
                raise Exception(f"Cannot find resource {resource_name}")
            reader = csv.DictReader(io.StringIO(csv_data.decode('utf-8')), delimiter=',', quotechar='"')
            if reader.fieldnames is None:
                # DictReader reports no field names when the input has no header row at all.
                raise Exception(f"Resource {resource_name} is empty, cannot build a schema")
            yt_schema = list()
            col_infos = dict()  # csv column header -> (yt column name, value converter)
            for csv_name in reader.fieldnames:
                yt_name, yt_type = Runner.get_name_and_yt_type(csv_name)
                yt_schema.append({'name': yt_name, 'type': yt_type})
                converter = Runner.get_type_converter(csv_name, yt_type)
                col_infos[csv_name] = yt_name, converter
            yt_rows = list()
            for csv_row in reader:
                yt_row = dict()
                for csv_name, csv_val in csv_row.items():
                    yt_name, converter = col_infos[csv_name]
                    yt_row[yt_name] = converter(csv_val)
                yt_rows.append(yt_row)
            self.yt_client.create("table", yt_path, attributes={"schema": yt_schema}, force=True)
            self.yt_client.write_table(yt_path, yt_rows)

def main():
    """Command-line entry point: set up logging, parse the options and run the query."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)-15s | %(levelname)-4.4s | %(name)-12.12s | %(message)s',
    )

    # (flag, add_argument keyword arguments), registered in exactly this order.
    option_specs = (
        ('--yt-proxy', {'default': 'hahn'}),
        ('--yql-token', {}),
        ('--yql-token-path', {}),
        # A YT token (or token path) is what makes Runner create a YT client.
        ('--yt-token', {}),
        ('--yt-token-path', {}),
        ('--versioned-path', {}),  # Substituted for '%VERSIONED_PATH%' in arguments
        ('--temp-path', {}),  # Substituted for '%TEMP_PATH%' in arguments
        ('--query-name', {'required': True}),
        ('--query-args', {'nargs': '+', 'default': []}),  # Each item: key=value
        ('--csv-resources', {'nargs': '+', 'default': []}),  # Each item: yt_path=resource_name
        ('--processor-name', {}),  # When omitted, the processor is looked up by query name
        ('--processor-args', {'nargs': '+', 'default': []}),
        ('--transfer-to-cluster', {}),
        ('--yt-result-path-query-arg-name', {}),
    )

    parser = ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    parsed = parser.parse_args(args=replace_args_from_env())
    runner = Runner(parsed)
    runner.run()


# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
