from typing import Tuple, List, Dict, Optional
import sys
import logging
from os import cpu_count, getpid, kill
from signal import SIGKILL
from argparse import ArgumentParser
from queue import Queue, Empty
from threading import Lock, Thread, local
from time import monotonic, time
from datetime import date, datetime
from functools import reduce, wraps

from ratelimit import limits, sleep_and_retry

from yql.api.v1.client import YqlClient
from yt.wrapper import YtClient, YPath, create, read_table, write_table, row_count, copy

from travel.library.python.tools import replace_args_from_env
from travel.hotels.lib.python3.cli.cli import create_progress_bar
from travel.hotels.lib.python3.requests_utils.requests_utils import requests_retry_session
from travel.hotels.lib.python3.tvm.tvm import Tvm

from travel.hotels.lib.python3.yql.yqllib import wait_results, run_yql_file as _run_yql_file
from travel.hotels.lib.python3.yt.versioned_path import VersionedPath, DEFAULT_CLEANUP_STRATEGY, parse_cleanup_strategy
from travel.hotels.lib.python3.yt.ytlib import join, ensure_table_exists, merge_chunks, link, get_default_user_path


# Root logger: INFO and above, written to stdout.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

# Logger used by the code in this module.
LOG = logging.getLogger('main')

# Raise the thresholds of the 'TVM' and urllib3 connection-pool loggers
# above the root INFO level.
tvm_logger = logging.getLogger('TVM')
urllib3_connection_logger = logging.getLogger('urllib3.connectionpool')
tvm_logger.setLevel(logging.WARNING)
urllib3_connection_logger.setLevel(logging.ERROR)


def on_error_log_and_exit(func):
    """
    Decorator: run ``func`` and terminate the whole process if it raises.

    Any ``Exception`` raised by ``func`` is logged with its traceback through
    ``LOG.exception`` and the current process is then sent ``SIGKILL``.
    In this file it wraps thread entry points.
    NOTE(review): SIGKILL cannot be handled, so no cleanup code
    (``finally`` blocks, context managers, ``atexit``) runs afterwards.

    :param func: callable to protect
    :return: wrapper with the same name/docstring that forwards the
             arguments and the return value of ``func``
    """
    @wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        try:
            # Propagate the result so the decorator is transparent on success.
            return func(*args, **kwargs)
        except Exception as e:
            LOG.exception(e)
            kill(getpid(), SIGKILL)
    return wrapper


def default_worker_threads():
    """
    :return: twice the value of ``cpu_count()``; 4 CPUs are assumed when
             ``cpu_count()`` returns nothing usable (``None`` or 0)
    """
    return 2 * (cpu_count() or 4)


class SuggestClient:
    """
    Client for the suggest HTTP endpoint.

    Every request carries a service ticket obtained from ``Tvm`` and is sent
    through an HTTP session stored in thread-local storage, so each thread
    that calls ``request`` works with its own session object.
    """

    # Value sent as the 'limit' query parameter of every request.
    SUGGEST_LIMIT = 3

    def __init__(self, url, tvm_service_id, tvm_client_id, tvm_client_secret):
        """
        :param url: suggest endpoint URL
        :param tvm_service_id: passed to ``Tvm`` as the first argument
        :param tvm_client_id: passed to ``Tvm`` as the second argument
        :param tvm_client_secret: passed to ``Tvm`` as the third argument
        """
        self.url = url

        self.tvm = Tvm(tvm_service_id, tvm_client_id, tvm_client_secret)

        self.thread_local = local()
        # Session for the constructing thread; other threads get theirs lazily
        # in _get_session().
        self.thread_local.session = requests_retry_session()

    def _get_session(self):
        """
        Return the calling thread's session, creating it on first use.

        Attributes of a ``threading.local`` object are visible only in the
        thread that set them, so a thread other than the one that ran
        ``__init__`` would otherwise fail with ``AttributeError``.
        """
        session = getattr(self.thread_local, 'session', None)
        if session is None:
            session = requests_retry_session()
            self.thread_local.session = session
        return session

    # NOTE(review): the decorators are applied once, at class-definition time,
    # so the budget is presumably shared by all instances and threads —
    # confirm against the ratelimit documentation.
    @sleep_and_retry
    @limits(calls=100, period=1)  # calls/period RPS
    def request(self, query: str) -> List:
        """
        Query the suggest endpoint.

        :param query: text sent as the 'query' parameter
        :return: the 'items' of every entry of the response's 'groups',
                 concatenated in order into one flat list
        :raises Exception: when the HTTP status code is not 200
        """
        rsp = self._get_session().get(
            self.url,
            params={
                'query': query,
                'limit': self.SUGGEST_LIMIT,
                'language': 'ru',
                'domain': 'ru',
                'sessionId': 'metrics-builder-session',
                'requestIndex': 0,
            },
            headers={
                'X-Ya-YandexUid': '123',
                'X-Ya-Session-Key': '00000000-0000-0000-0000-000000000000',
                'X-Ya-Service-Ticket': self.tvm.get_ticket()
            }
        )
        if rsp.status_code != 200:
            raise Exception(f'Status code must be 200 but was {rsp.status_code}')
        # Flatten groups -> items in a single pass (no repeated list copying).
        return [item for group in rsp.json()['groups'] for item in group['items']]


class SuggestMetricTask:
    """
    Base class of a task that computes suggest metrics for one query.

    Descendants set ``QUERY_TYPE`` and ``self.text`` and implement
    ``item_match``. ``suggest_client`` must be assigned before ``run_task``
    is called.
    """

    # Label written to the 'queryType' field of the result; must be overridden.
    QUERY_TYPE = None

    def __init__(self, row: Dict, suggest_type: str):
        """
        :param row: source row; the keys 'id', 'hash' and 'count' are read here
        :param suggest_type: value copied to the 'suggestType' result field
        """
        self.id = row['id']
        self.hash = row['hash']
        self.count = row['count']
        self.suggest_type = suggest_type

        self.text = None  # Redefined in descendants
        self.suggest_client = None

    def get_query_type(self):
        """
        :return: ``QUERY_TYPE`` of the concrete class
        :raises Exception: when ``QUERY_TYPE`` was not overridden
        """
        if self.QUERY_TYPE is None:
            raise Exception('QUERY_TYPE must be redefined in descendant class')

        return self.QUERY_TYPE

    def run_task(self):
        """
        Compute all metrics of this task.

        The per-prefix "saved" requests are issued only when the full query
        text already finds the expected item; otherwise both saved values
        are 0.

        :return: dict with the keys id, hash, query, count, queryType,
                 suggestType, recall, recall_1, show, savedLength,
                 savedLength_1 and queryLength
        :raises Exception: when ``suggest_client`` has not been assigned
        """
        if self.suggest_client is None:
            raise Exception('Suggest client must be initialized before task run')

        recall, recall_1, show = self.get_recall_and_show_metric()
        if recall:
            saved, saved_1, text_len = self.get_saved_metric()
        else:
            saved, saved_1, text_len = 0, 0, len(self.text)
        return {
            'id': self.id,
            'hash': self.hash,
            'query': self.text,
            'count': self.count,
            'queryType': self.get_query_type(),
            'suggestType': self.suggest_type,
            'recall': 1 if recall else 0,
            'recall_1': 1 if recall_1 else 0,
            'show': 1 if show else 0,
            'savedLength': saved,
            'savedLength_1': saved_1,
            'queryLength': text_len,
        }

    def get_recall_and_show_metric(self) -> Tuple[bool, bool, bool]:
        """
        Request suggestions for the full query text.

        :return: Tuple of recall (expected item is present), recall_1
                 (expected item is the first one) and show (at least one
                 item was returned) metrics
        """
        items = self.request(self.text)
        item_position_in_suggest = self.item_position_in_suggest(items)

        return item_position_in_suggest is not None, item_position_in_suggest == 0, len(items) != 0

    def get_saved_metric(self) -> Tuple[int, int, int]:
        """
        Request suggestions for growing prefixes of the query text.

        ``saved`` is the number of characters left untyped at the shortest
        prefix that returns the expected item at any position; ``saved_1``
        is the same for the shortest prefix that returns it at the first
        position. A value of 0 means no tried prefix qualified.

        :return: Tuple of saved, saved_1 and text length values
        """
        saved = 0
        saved_1 = 0
        text_len = len(self.text)

        # NOTE(review): prefixes of length 1 .. text_len - 2 are tried, so the
        # prefix that is one character short of the full text is never
        # requested — confirm that this bound is intended.
        for i in range(1, text_len - 1):
            items = self.request(self.text[:i])
            item_position = self.item_position_in_suggest(items)

            if item_position is None:
                continue
            if saved == 0:
                saved = text_len - i
            if item_position == 0:
                # Characters saved at this prefix length (was the constant
                # text_len - 1, which ignored the prefix length).
                saved_1 = text_len - i
                break

        return saved, saved_1, text_len

    def item_position_in_suggest(self, items: List) -> Optional[int]:
        """
        :param items: suggest items in response order
        :return: index of the first item accepted by ``item_match`` or None
        """
        for index, item in enumerate(items):
            if self.item_match(item):
                return index
        return None

    def item_match(self, item: Dict) -> bool:
        """
        Tell whether ``item`` is the answer expected by this task.

        :raises NotImplementedError: always; descendants must override
        """
        raise NotImplementedError()

    def request(self, query: str):
        """
        :return: result of ``suggest_client.request(query)``
        """
        return self.suggest_client.request(query)

    def __repr__(self):
        return f'{self.__class__.__name__}({vars(self)})'


class RegionSuggestMetricTask(SuggestMetricTask):
    """Metric task whose expected suggest item is a region with a given geo id."""

    QUERY_TYPE = 'region'

    def __init__(self, row: Dict, suggest_type: str):
        """
        :param row: source row; 'name' becomes the query text and 'geoId'
                    the expected geo id
        :param suggest_type: forwarded to the base class
        """
        super().__init__(row, suggest_type)

        self.text, self.geo_id = row['name'], row['geoId']

    def item_match(self, item: Dict) -> bool:
        """
        :return: True when the item's redirect parameters have type 'region'
                 and a 'geo_id' equal to the expected geo id
        """
        redirect = item['redirect_params']
        # 'geo_id' is looked up only for items of type 'region'.
        if redirect['type'] != 'region':
            return False
        return redirect['geo_id'] == self.geo_id


class HotelSuggestMetricTask(SuggestMetricTask):
    """Metric task whose expected suggest item is a hotel with a given permalink."""

    QUERY_TYPE = 'hotel'

    def __init__(self, row: Dict, suggest_type: str):
        """
        :param row: source row; 'name' becomes the query text and 'permalink'
                    the expected permalink
        :param suggest_type: forwarded to the base class
        """
        super().__init__(row, suggest_type)

        self.text, self.permalink = row['name'], row['permalink']

    def item_match(self, item: Dict) -> bool:
        """
        :return: True when the item's redirect parameters have type 'hotel'
                 and a 'permalink' equal to the expected permalink as a string
        """
        redirect = item['redirect_params']
        # 'permalink' is looked up only for items of type 'hotel'.
        if redirect['type'] != 'hotel':
            return False
        return redirect['permalink'] == str(self.permalink)


class SuggestMetricsCalculator:
    """
    Command-line job that computes suggest metrics and stores them in YT.

    ``run`` prepares (or copies) the request tables and then starts one reader
    thread that fills ``task_queue``, ``worker_thread_count`` worker threads
    that run the tasks and fill ``result_queue``, and one writer thread that
    streams the results into a table and launches the aggregation YQL queries.
    """

    TVM_CLIENT_ID = 2002574
    TVM_SERVICE_ID = 2002546

    # Longest time the writer waits without receiving any result while workers
    # are still running; after that queue.Empty is raised.
    QUEUE_TIMEOUT_SECS = 1000
    # How often a blocked queue consumer wakes up to re-check whether the
    # producers have finished.
    QUEUE_POLL_SECS = 1

    DATE_FORMAT = '%Y-%m-%d'

    DECIMAL_WHOLE_PART_PRECISION = 35
    DECIMAL_SCALE = 5

    REQUEST_LIMIT = 2000
    DAYS_LIMIT = 30

    SUGGEST_TYPE = 'hotels'

    # The parser is built once, when the class body is executed; in particular
    # the default of --date is the date at that moment.
    parser = ArgumentParser()
    parser.add_argument('--yql-token', required=True)
    parser.add_argument('--yt-token', required=True)
    parser.add_argument('--yt-cluster', default='hahn')
    parser.add_argument("--tvm-service-id", default=TVM_SERVICE_ID, type=int)
    parser.add_argument("--tvm-client-id", default=TVM_CLIENT_ID, type=int)
    parser.add_argument("--tvm-client-secret", required=True)
    parser.add_argument('--suggest-url', default='http://api.travel-balancer-test.yandex.net/api/hotels_portal/v1/suggest')
    parser.add_argument('--region-requests-table', type=YPath, default=None)
    parser.add_argument('--hotel-requests-table', type=YPath, default=None)
    parser.add_argument('--search-requests-table', type=YPath, default=None)
    parser.add_argument('--prepare', action='store_true')
    parser.add_argument('--limit', type=int, default=REQUEST_LIMIT)
    parser.add_argument('--days', type=int, default=DAYS_LIMIT)
    parser.add_argument('--date', type=lambda x: datetime.strptime(x, SuggestMetricsCalculator.DATE_FORMAT).date(), default=date.today())
    parser.add_argument('--target-folder', default=get_default_user_path('suggest/metrics'), type=YPath)
    parser.add_argument('--suggest-aggregated-metrics-table', default=get_default_user_path('metrics/suggest-metrics-aggregated'), type=YPath)
    parser.add_argument('--worker-thread-count', default=default_worker_threads(), type=int)
    parser.add_argument('--cleanup-strategy', type=parse_cleanup_strategy, default=DEFAULT_CLEANUP_STRATEGY)

    def __init__(self, args):
        """
        :param args: command-line arguments (without the program name); they
                     are passed through ``replace_args_from_env`` before
                     being parsed
        """
        args = self.parser.parse_args(replace_args_from_env(args))

        self.yql_client = YqlClient(
            token=args.yql_token,
            db=args.yt_cluster,
        )

        # Separate YT clients for the reader thread and for everything else.
        self.reading_client = YtClient(
            proxy=args.yt_cluster,
            token=args.yt_token,
        )

        self.writing_client = YtClient(
            proxy=args.yt_cluster,
            token=args.yt_token,
        )

        self.url = args.suggest_url
        self.tvm_service_id = args.tvm_service_id
        self.tvm_client_id = args.tvm_client_id
        self.tvm_client_secret = args.tvm_client_secret

        # NOTE(review): --search-requests-table is parsed but never read in
        # this class, although the error message in run() mentions it.
        self.region_requests_user_table = args.region_requests_table
        self.hotel_requests_user_table = args.hotel_requests_table
        self.prepare = args.prepare

        self.date = args.date
        self.days = args.days
        self.limit = args.limit

        self.target_folder = args.target_folder
        self.suggest_aggregated_metrics_table = args.suggest_aggregated_metrics_table
        self.cleanup_strategy = args.cleanup_strategy

        self.worker_thread_count = args.worker_thread_count
        # Both counters are updated from worker threads; "+=" on an attribute
        # is a read-modify-write, so updates are serialized with the lock.
        self._counters_lock = Lock()
        self.finished_workers = 0
        self.failed_workers = 0

        # Bounded queues: a producer blocks in put() while its queue is full.
        self.task_queue = Queue(maxsize=10*self.worker_thread_count)
        self.result_queue = Queue(maxsize=10*self.worker_thread_count)

        self.fully_read = False
        self.fully_processed = False

        self.start_timestamp = int(time())

    def run(self) -> None:
        """
        Prepare the request tables, then run the reader, the workers and the
        writer and wait until all of these threads have finished.

        With ``--prepare`` both request tables are produced by YQL queries;
        otherwise each user-supplied table is copied into the working folder
        and a table that was not supplied is skipped.

        :raises Exception: when ``--prepare`` is combined with a user-supplied
                           hotel or region requests table
        """
        # TODO: Add exception throw on thread fail
        with VersionedPath(self.target_folder, yt_client=self.writing_client, cleanup_strategy=self.cleanup_strategy) as vp:
            self.working_folder = vp
            self.region_requests_table = join(self.working_folder, 'region-requests')
            self.hotel_requests_table = join(self.working_folder, 'hotel-requests')

            if self.prepare:
                if self.hotel_requests_user_table is None and self.region_requests_user_table is None:
                    # Both queries are started asynchronously and awaited together.
                    wait_results(
                        self.run_yql_file(
                            '/prepare_regions.yql',
                            'Prepare regions bucket',
                            parameters={
                                '$region_target_table': str(self.region_requests_table),
                                '$limit': self.limit,
                            },
                            sync=False,
                        ),
                        self.run_yql_file(
                            '/prepare_hotels.yql',
                            'Prepare hotels bucket',
                            parameters={
                                '$hotel_target_table': str(self.hotel_requests_table),
                                '$limit': self.limit,
                            },
                            sync=False,
                        ),
                    )
                else:
                    LOG.error('Parameters --prepare and --region-requests-table, --hotel-requests-table or --search-requests-table are mutually exclusive')
                    raise Exception('Incorrect parameters')
            else:
                if self.hotel_requests_user_table is not None:
                    copy(self.hotel_requests_user_table, self.hotel_requests_table, client=self.writing_client)
                else:
                    # None tells the reader and the writer to skip this table.
                    self.hotel_requests_table = None

                if self.region_requests_user_table is not None:
                    copy(self.region_requests_user_table, self.region_requests_table, client=self.writing_client)
                else:
                    self.region_requests_table = None

            threads = []

            reader = Thread(target=self.read_tasks, name='reader')
            reader.start()
            threads.append(reader)

            for i in range(self.worker_thread_count):
                t = Thread(target=self.work, name=f'worker-thread-{i}')
                t.start()
                threads.append(t)

            writer = Thread(target=self.write_result, name='writer')
            writer.start()
            threads.append(writer)

            for t in threads:
                t.join()

    @on_error_log_and_exit
    def work(self):
        """
        Worker thread body: take tasks from ``task_queue``, run them and put
        the results into ``result_queue`` until the reader has finished and
        the task queue is empty.

        A failure of a single task is logged and the task is dropped; the
        worker continues with the next one.
        """
        suggest_client = SuggestClient(self.url, self.tvm_service_id, self.tvm_client_id, self.tvm_client_secret)

        try:
            while not self.fully_read or not self.task_queue.empty():
                try:
                    # Short timeout: this is only a poll. A worker that finds
                    # the queue drained re-checks the loop condition promptly
                    # instead of idling for QUEUE_TIMEOUT_SECS.
                    task = self.task_queue.get(timeout=self.QUEUE_POLL_SECS)
                except Empty:
                    continue

                try:
                    task.suggest_client = suggest_client
                    result = task.run_task()
                    self.result_queue.put(result)
                except Exception as e:
                    LOG.exception(f'Task failed: {task!r}', exc_info=e)
        except Exception:
            with self._counters_lock:
                self.failed_workers += 1
            raise
        finally:
            with self._counters_lock:
                self.finished_workers += 1

    @on_error_log_and_exit
    def read_tasks(self):
        """
        Reader thread body: turn every row of the hotel and region request
        tables into a task, put it into ``task_queue`` and finally set
        ``fully_read``.
        """
        if self.hotel_requests_table is not None:
            for row in read_table(self.hotel_requests_table, client=self.reading_client):
                self.task_queue.put(HotelSuggestMetricTask(row, self.SUGGEST_TYPE))

        if self.region_requests_table is not None:
            for row in read_table(self.region_requests_table, client=self.reading_client):
                self.task_queue.put(RegionSuggestMetricTask(row, self.SUGGEST_TYPE))

        self.fully_read = True

    @on_error_log_and_exit
    def write_result(self):
        """
        Writer thread body: stream the results of the workers into the
        'result' table of the working folder, then run the aggregation YQL
        queries against the aggregated metrics table.

        :raises Exception: when at least one worker has failed
        :raises queue.Empty: when no result arrives for ``QUEUE_TIMEOUT_SECS``
                             seconds while some workers are still running
        """
        source_row_count = 0
        if self.hotel_requests_table is not None:
            source_row_count += row_count(self.hotel_requests_table, client=self.writing_client)
        if self.region_requests_table is not None:
            source_row_count += row_count(self.region_requests_table, client=self.writing_client)

        def data():
            # Yields result rows until every worker has finished and the
            # result queue is drained.
            processed_row_count = 0
            with create_progress_bar('Rows processed', source_row_count) as bar:
                last_progress = monotonic()
                while self.finished_workers != self.worker_thread_count or not self.result_queue.empty():
                    try:
                        # Poll with a short timeout: the last worker may be
                        # marked finished only after the loop condition was
                        # checked; a single long blocking get() would then
                        # end in Empty although the run has succeeded.
                        row = self.result_queue.get(timeout=self.QUEUE_POLL_SECS)
                    except Empty:
                        # Still give up after QUEUE_TIMEOUT_SECS without any result.
                        if monotonic() - last_progress >= self.QUEUE_TIMEOUT_SECS:
                            raise
                        continue
                    yield row
                    processed_row_count += 1
                    bar.update(processed_row_count)
                    last_progress = monotonic()

            LOG.info('Data fully processed {}/{} rows\r'.format(processed_row_count, source_row_count))

        result_table = join(self.working_folder, 'result')
        schema = [
            {"name": "id", "type": "string"},
            {"name": "hash", "type": "string"},
            {"name": "query", "type": "string"},
            {"name": "count", "type": "int64"},
            {"name": "queryType", "type": "string"},
            {"name": "suggestType", "type": "string"},
            {"name": "recall", "type": "int64"},
            {"name": "recall_1", "type": "int64"},
            {"name": "show", "type": "int64"},
            {"name": "savedLength", "type": "int64"},
            {"name": "savedLength_1", "type": "int64"},
            {"name": "queryLength", "type": "int64"},
        ]
        create('table', result_table, attributes={'schema': schema}, client=self.writing_client)
        write_table(result_table, data(), client=self.writing_client)

        if self.failed_workers != 0:
            raise Exception('Some workers failed')

        LOG.info(f'Data fully written into destination table {link(result_table)}')

        schema = [
            {"name": "timestamp", "type": "uint64"},
            {"name": "date", "type": "string"},
            {"name": "queryType", "type": "string"},
            {"name": "suggestType", "type": "string"},
            {"name": "weighted", "type": "boolean"},
            {"name": "recall", "type": "double"},
            {"name": "recall_1", "type": "double"},
            {"name": "show", "type": "double"},
            {"name": "saved", "type": "double"},
            {"name": "saved_1", "type": "double"},
        ]
        ensure_table_exists(self.suggest_aggregated_metrics_table, yt_client=self.writing_client, schema=schema)

        # NOTE(review): the 'aggreagted' spelling presumably has to match the
        # names of the bundled resource files — verify before renaming.
        aggregation_queries = [
            ('/aggreagted.yql', 'Calculate aggregated metrics'),
            ('/aggreagted_all.yql', 'Calculate aggregated metrics all'),
            ('/weighted.yql', 'Calculate aggregated metrics weighted'),
            ('/weighted_all.yql', 'Calculate aggregated metrics weighted all'),
        ]
        # Start all queries asynchronously, then wait for them together.
        queries = [
            self.run_yql_file(
                resource_name=resource_name,
                title=title,
                parameters={
                    '$metric_table': str(result_table),
                    '$target_table': str(self.suggest_aggregated_metrics_table),
                    '$dec_whole': self.DECIMAL_WHOLE_PART_PRECISION,
                    '$dec_scale': self.DECIMAL_SCALE,
                    '$timestamp_sec': self.start_timestamp,
                },
                sync=False,
            )
            for resource_name, title in aggregation_queries
        ]

        wait_results(*queries, print_result=False)
        merge_chunks(self.suggest_aggregated_metrics_table, self.writing_client)

    def run_yql_file(self, resource_name: str, title: str, parameters: Optional[Dict] = None, sync: bool = True):
        """
        Run a YQL resource file with this job's YQL client under the
        'SuggestMetrics' name.

        :param resource_name: name of the YQL resource file
        :param title: title passed on to the query
        :param parameters: query parameters or None
        :param sync: passed on unchanged to the underlying helper
        :return: whatever the underlying ``run_yql_file`` helper returns
        """
        return _run_yql_file(self.yql_client, resource_name, 'SuggestMetrics', title=title, parameters=parameters, sync=sync)


if __name__ == '__main__':
    # Script entry point: build the calculator from the command-line
    # arguments (program name excluded) and run it.
    SuggestMetricsCalculator(sys.argv[1:]).run()
