# coding: utf8
import json
import os
from datetime import datetime, timedelta

import tempfile
import yt.wrapper as yt
from yt.wrapper.client import Yt

# Path of the JSON file (named "dict.json", located in the system temporary
# directory) that the script entry point loads as the stations -> settlements
# mapping. The mapping is queried by station id to obtain a settlement id.
STATIONS_SETTLEMENTS_FILE_NAME = os.path.join(tempfile.gettempdir(), "dict.json")


def create_client(proxy, token):
    """Build a YT client for ``proxy`` authenticated with ``token``.

    The client is configured with "dsv" as its tabular data format.
    """
    client_config = {"tabular_data_format": "dsv"}
    return Yt(proxy=proxy, token=token, config=client_config)


class StatTablesHelper(object):
    """Selects date-named tables found under a YT search path.

    A table is recognised by the last component of its path, which must be a
    date in ``YYYY-MM-DD`` form (for example ``<search path>/2020-01-05``).
    """

    # Format of the last path component of a per-day table.
    _DATE_FORMAT = '%Y-%m-%d'

    def __init__(self, ytclient, ytsearchpath):
        """
        :param ytclient: object with a ``search(path)`` method returning an
            iterable of node paths.
        :param ytsearchpath: path that is passed to ``ytclient.search``.
        """
        self._ytclient = ytclient
        self._ytsearchpath = ytsearchpath

    def get_dates_tables(self, min_date, max_date=None):
        """Return paths of tables whose date lies in ``[min_date, max_date]``.

        :param min_date: inclusive lower bound (``datetime``).
        :param max_date: inclusive upper bound (``datetime``); the current
            local time is used when it is not given.
        :return: list of matching paths, in the order ``search`` yielded them.
        """
        max_date = datetime.now() if not max_date else max_date

        tables = []
        for table_path in self._ytclient.search(self._ytsearchpath):
            # The search result may contain the search root itself.
            if table_path == self._ytsearchpath:
                continue

            date_str = table_path.split('/')[-1]
            try:
                date = datetime.strptime(date_str, self._DATE_FORMAT)
            except ValueError:
                # Not a date-named node (some other table or directory, or the
                # root when the search path has a trailing slash): skip it
                # instead of aborting the whole listing.
                continue

            if min_date <= date <= max_date:
                tables.append(table_path)

        return tables

    def get_dates_back(self, days_back):
        """Return tables dated within the last ``days_back`` days.

        The lower bound is "now minus ``days_back`` days" including the time
        of day, while table dates are parsed as midnight, so the table dated
        exactly ``days_back`` days ago is normally not included.
        """
        return self.get_dates_tables(datetime.now() - timedelta(days=days_back))


class SearchStatRunner(object):
    """Counts searches per object and per route with YT map/sort/reduce.

    ``run`` maps the source tables into two intermediate tables (per-object
    rows and per-route rows), sorts each of them and reduces each group of
    equal keys into a single row carrying a ``total`` counter.
    """

    def __init__(self, client, source_tables, stations_settlements, work_path):
        """
        :param client: YT client used for all operations.
        :param source_tables: iterable of source table paths.
        :param stations_settlements: mapping queried with ``.get(station_id)``
            to obtain the id of the settlement the station belongs to.
        :param work_path: YT directory where all produced tables are placed.
        """
        self.client = client
        self.work_path = work_path
        self.stations_settlements = stations_settlements

        # Only these columns are requested from every source table.
        needed_columns = (
            'from_id', 'to_id', 'transport_type', 'geoid', 't_type_counts', 'service', 'yandexuid',
        )
        wrapped_tables = []
        for table in source_tables:
            wrapped_tables.append(
                yt.TablePath(table, client=self.client, columns=list(needed_columns))
            )
        self.source_tables = wrapped_tables

    def get_table(self, name, **kwargs):
        """Return a ``yt.TablePath`` for table ``name`` inside the work path."""
        full_path = '{0}/{1}'.format(self.work_path, name)
        return yt.TablePath(full_path, client=self.client, **kwargs)

    def run(self):
        """Execute the whole pipeline.

        :return: tuple ``(stat_result, routes_result)`` of result table paths.
        """
        self.client.mkdir(self.work_path, recursive=True)

        stat_mapped = self.get_table('stat_mapped')
        routes_mapped = self.get_table('routes_mapped')
        stat_result = self.get_table('stat_result')
        routes_result = self.get_table('routes_result')

        # The mapper marks each row with "@table_index"; the output format is
        # told to take the table index from that row field.
        mapper_output_format = yt.YsonFormat(
            control_attributes_mode="row_fields",
            table_index_column="@table_index",
        )
        self.client.run_map(
            self.map_stat,
            source_table=self.source_tables,
            destination_table=[stat_mapped, routes_mapped],
            output_format=mapper_output_format,
        )

        # (mapped table, name of the sorted table, result table, key fields)
        stages = (
            (stat_mapped, 'stat_sorted', stat_result,
             ['obj_id', 'transport_type', 'geoid']),
            (routes_mapped, 'routes_sorted', routes_result,
             ['transport_type', 'from_id', 'to_id', 'geoid']),
        )
        for mapped, sorted_name, result, key_fields in stages:
            sorted_table = self.get_table(sorted_name)
            self.client.run_sort(mapped, sorted_table, sort_by=key_fields)
            self.client.run_reduce(self.reduce_count, sorted_table, result,
                                   reduce_by=key_fields, output_format=yt.YsonFormat())

        return stat_result, routes_result

    def to_table(self, table_ind, data):
        """Mark ``data`` (modified in place) with its output table index and return it."""
        data['@table_index'] = table_ind
        return data

    def map_stat(self, row):
        """Mapper: turn one source row into per-object and per-route rows.

        Rows without a yandexuid, with a service other than 'rasp' or
        'rasp-touch', without both endpoints, or without a non-empty
        ``t_type_counts`` produce nothing.
        """
        stat_table, routes_table = 0, 1

        # an empty yandexuid is almost certainly a bot
        if not row.get('yandexuid'):
            return

        service = row.get('service')
        if service not in ('rasp', 'rasp-touch'):
            return

        from_key = row.get('from_id')
        to_key = row.get('to_id')
        if not (from_key and to_key):
            return

        # A key is a one-character type prefix followed by the object id.
        try:
            from_type = from_key[0]
            from_id = from_key[1:]
            to_type = to_key[0]
            to_id = to_key[1:]
        except Exception:
            return

        geoid = row.get('geoid')
        t_type = row.get('transport_type') or 'all'

        if not self._has_routes(row):
            return

        # search statistics for the individual objects
        t_type_counts = row.get('t_type_counts')
        for obj_id in (from_key, to_key):
            stat_row = {
                'obj_id': obj_id,
                'geoid': geoid,
                'transport_type': t_type,
                'service': service,
                't_type_counts': t_type_counts,
            }
            yield self.to_table(stat_table, stat_row)

        # from-to search statistics
        variants = self.routes_variants(from_key, from_type, from_id, to_key, to_type, to_id)
        for variant_from, variant_to in variants:
            route_row = {
                'transport_type': t_type,
                'from_id': variant_from,
                'to_id': variant_to,
                'geoid': geoid,
            }
            yield self.to_table(routes_table, route_row)

    def _has_routes(self, row):
        """
        Keep only searches with a non-empty result: ``t_type_counts`` must be
        JSON that decodes to a non-empty container. Any failure counts as empty.
        """
        try:
            raw_counts = row.get('t_type_counts')
            return bool(raw_counts) and len(json.loads(raw_counts)) > 0
        except Exception:
            return False

    def _settlement_key(self, obj_type, obj_id):
        """Return the 'c<settlement id>' key for a station, or None if there is none."""
        if obj_type != 's':
            return None
        sett_id = self.stations_settlements.get(obj_id)
        if not sett_id:
            return None
        return 'c{}'.format(sett_id)

    def routes_variants(self, from_key, from_type, from_id, to_key, to_type, to_id):
        """Yield the (from, to) pairs to be counted for the given endpoints."""

        yield from_key, to_key

        # for cities, sum up all the routes of their stations
        city_from = self._settlement_key(from_type, from_id)
        city_to = self._settlement_key(to_type, to_id)

        if city_from:
            yield city_from, to_key
            # a route from a city to itself is not counted
            if city_to and city_to != city_from:
                yield city_from, city_to

        if city_to:
            yield from_key, city_to

    @staticmethod
    def reduce_count(key, rows_iter):
        """Reducer: yield the key fields plus ``total``, the number of rows in the group."""
        counted = dict(key)
        total = 0
        for _ in rows_iter:
            total += 1
        counted['total'] = total
        yield counted


if __name__ == "__main__":
    # Script entry point: load the stations -> settlements mapping from the
    # local JSON file, pick the source tables of the last N days and run the
    # statistics pipeline on them.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--ytworkpath', help='YT directory where the produced tables are placed')
    parser.add_argument('--days', default=20, type=int,
                        help='how many days back to look for source tables')
    parser.add_argument('--cluster', help='YT proxy to connect to')
    parser.add_argument('--token', help='YT token')
    parser.add_argument('--ytsearchpath', help='YT path searched for date-named source tables')
    args = parser.parse_args()

    with open(STATIONS_SETTLEMENTS_FILE_NAME, "r") as f:
        stations_settlements = json.load(f)
    if not stations_settlements:
        # An empty mapping would silently drop every settlement-level route,
        # so fail early and say which file is at fault.
        raise Exception(
            "stations -> settlements mapping loaded from {} is empty".format(
                STATIONS_SETTLEMENTS_FILE_NAME
            )
        )

    client = create_client(proxy=args.cluster, token=args.token)
    source_tables = StatTablesHelper(client, args.ytsearchpath).get_dates_back(args.days)
    suggests_stat = SearchStatRunner(client, source_tables, stations_settlements, args.ytworkpath)
    stat_table, routes_table = suggests_stat.run()
