import logging
import os
from datetime import datetime, timedelta

import yt.wrapper
from yql.api.v1.client import YqlClient


# Module-level logger, named after this module.
log = logging.getLogger(__name__)


class YtProxyLogs:
    """Find train searches answered without prices although tariffs were received.

    The pipeline runs on YT. It reads the date-named "result" and "tariffs"
    log tables for the previous day, filters them, joins them by
    ``request_id`` and writes intermediate and final tables under
    ``dest_dir``.

    NOTE(review): the mapper/reducer methods are handed to ``yt.wrapper`` as
    bound methods and read ``self.uniq_ids`` / ``self.results``. This
    presumably relies on yt.wrapper shipping the pickled instance to the
    jobs — confirm before restructuring the state.
    """

    def __init__(self, sources, dest_dir, token, proxy):
        """Store the configuration and initialise the in-memory state.

        Args:
            sources: mapping from a source name (``'result'``, ``'tariffs'``,
                ...) to the YT directory that holds its date-named tables.
            dest_dir: YT directory where all output tables are written.
            token: YT token.
            proxy: YT proxy URL (cluster address).
        """
        # Yesterday's date; log tables are named "%Y-%m-%d".
        self.date_str = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
        self.res_table = os.path.join(dest_dir, self.date_str)
        self.sources = sources
        self.dest_dir = dest_dir
        self.token = token
        self.proxy = proxy
        # request_ids of searches that ended with no priced variants.
        self.uniq_ids = set()
        # request_id -> row of the "bad" table, filled by load_results().
        self.results = {}

    def yt_init(self):
        """Apply the token and the proxy URL to the global yt.wrapper config."""
        yt.wrapper.update_config(
            {
                "token": self.token,
                "proxy": {'url': self.proxy}
            }
        )

    def yt_find_tariffs_table(self, src_name, days=1):
        """Return the existing date-named tables of a source for the last days.

        Args:
            src_name: key in ``self.sources``.
            days: how many days back to look, starting from yesterday.
                The default (1) checks yesterday's table only.

        Returns:
            List of table paths that exist on YT, newest first. A warning is
            logged for every missing table.

        Raises:
            KeyError: if ``src_name`` is not a known source.
        """
        source_dir = self.sources[src_name]
        today = datetime.today()
        tables = []
        for day_ago in range(1, days + 1):
            day = (today - timedelta(days=day_ago)).strftime("%Y-%m-%d")
            table_path = os.path.join(source_dir, day)
            if yt.wrapper.exists(table_path):
                tables.append(table_path)
            else:
                log.warning('Path %s does not exist', table_path)

        return tables

    def start(self):
        """Run the whole pipeline.

        Steps:
            1. Keep train searches and compute the share of priced variants.
            2. Keep finished searches where that share is zero.
            3. Collect their request_ids and load them into memory.
            4. Filter the tariffs log by those request_ids.
            5. Reduce the filtered tariffs by request_id, attaching them to
               the corresponding "bad" search result.
            6. Save rows where a tariff matches a segment of the unpriced
               result ("strange" rows).
            7. Save tariff rows that contain a segment without a thread.
        """
        self.yt_init()
        pr_result_tables = self.yt_find_tariffs_table('result')
        pr_tariffs_tables = self.yt_find_tariffs_table('tariffs')

        log.info('tables: %s', pr_result_tables)

        # Train search logs annotated with the share of priced variants.
        p = os.path.join(self.dest_dir, 'filtered_path_logs')
        # Logs where 0% of the variants have prices.
        p_bad = os.path.join(self.dest_dir, 'filtered_path_logs_bad')
        # Table holding the collected request_ids.
        p_req = os.path.join(self.dest_dir, 'filtered_req_id')
        # Tariff log rows filtered by request_id.
        p_req_tar = os.path.join(self.dest_dir, 'filtered_req_tar')
        # Original results joined with the received tariffs by request_id.
        p_req_merge = os.path.join(self.dest_dir, 'filtered_req_merge')
        # Joined rows in which a tariff matches a segment of the result.
        p_merge_strange = os.path.join(self.dest_dir, 'strange')
        # train-api answers that have segments without a thread.
        p_without_thread = os.path.join(self.dest_dir, 'without_thread')

        log.info('map run')
        yt.wrapper.run_map(self.train_runner, pr_result_tables, p)

        log.info('map run bad')
        yt.wrapper.run_map(self.only_bad_variants, [p], p_bad)

        log.info('map run request_id')
        # A single job is requested so that one row with all ids is produced
        # (presumably; load_ids() merges all rows in case there are several).
        yt.wrapper.run_map(self.req_id_runner, [p_bad], p_req, job_count=1)

        # The ids must be in memory before req_id_tariffs_runner is launched.
        log.info('load ids')
        self.load_ids(p_req)

        yt.wrapper.run_map(self.req_id_tariffs_runner, pr_tariffs_tables, p_req_tar)

        # The results must be in memory before req_merge_reduce is launched.
        log.info('load results')
        self.load_results(p_bad)

        # Reduce requires its input to be sorted by the reduce key.
        yt.wrapper.run_sort(p_req_tar, sort_by='request_id')

        log.info('reduce merge')
        yt.wrapper.run_reduce(self.req_merge_reduce,
                              p_req_tar,
                              p_req_merge,
                              reduce_by='request_id')

        log.info('load merge')
        res = self.load_merge(p_req_merge)

        log.info('save strange')
        self.write_found(p_merge_strange, res)

        log.info('save tariffs without thread')
        yt.wrapper.run_map(self.without_thread, pr_tariffs_tables, p_without_thread)

        log.info('run done')

    def train_runner(self, row):
        """Mapper: keep train searches and compute the share of priced variants.

        Only rows whose ``transport_types`` contain ``'train'`` and whose
        ``when`` lies strictly between ``'2020-02-27'`` and ``'2021'``
        (string comparison) are considered. A variant counts as priced only
        when every one of its segments has tariffs.

        Yields:
            The row extended with ``tariff_percent`` (fraction of priced
            variants, 0..1) and ``month`` (month of ``when`` as a float).
            Nothing is yielded when no variant is left after filtering.
        """
        if 'train' not in row['transport_types'] or not ('2020-02-27' < row['when'] < '2021'):
            return

        all_vars = 0
        var_price_count = 0
        for var in row['transfer_variants']:
            # Skip variants touching Ukraine (country id 187).
            if any(segment['station_from']['country'].get('id') == 187
                   or segment['station_to']['country'].get('id') == 187
                   for segment in var['segments']):
                continue
            all_vars += 1
            seg_with_price = sum(1 for segment in var['segments'] if segment.get('tariffs'))
            # Prices must be present in every segment of the variant.
            if seg_with_price == var['segments_count']:
                var_price_count += 1

        if all_vars:
            row['tariff_percent'] = var_price_count / all_vars
            row['month'] = datetime.strptime(row['when'], '%Y-%m-%d').month * 1.0
            yield row

    def only_bad_variants(self, row):
        """Mapper: keep finished searches where no variant has prices.

        Yields:
            A reduced copy of the row (selected columns only) when
            ``tariff_percent`` is 0 and ``status`` is ``'done'``.
        """
        if row['tariff_percent'] == 0 and row['status'] == 'done':
            columns = ('point_from', 'point_to', 'when', 'transfer_variants', 'request_id', 'iso_eventtime')
            yield {column: row[column] for column in columns}

    @yt.wrapper.aggregator
    def req_id_runner(self, rows):
        """Aggregating mapper: collect the distinct request_ids of all rows.

        Yields:
            A single row ``{'uniq_ids': [...]}`` with the distinct ids seen.
        """
        yield {'uniq_ids': list({row['request_id'] for row in rows})}

    def req_id_tariffs_runner(self, row):
        """Mapper: keep tariff log rows whose request_id is in ``self.uniq_ids``.

        Yields:
            A reduced copy of the row (selected columns only).
        """
        if row['request_id'] in self.uniq_ids:
            columns = ('point_from', 'point_to', 'when', 'transfer_variants_tariffs',
                       'request_id', 'status', 'poll', 'iso_eventtime')
            yield {column: row[column] for column in columns}

    def without_thread(self, row):
        """Mapper: keep train-api tariff rows that have a segment with an empty thread.

        The row is yielded at most once, even if several of its segments
        lack a thread, so the output table has no duplicated rows.
        """
        segments = row['transfer_variants_tariffs'].get('segments', [])
        if any(not segment['thread'] for segment in segments):
            yield row

    def load_ids(self, table):
        """Load the request_ids stored in ``table`` into ``self.uniq_ids``.

        The ids of all rows are merged, so the result is complete even if the
        table was written by more than one job.
        """
        ids = set()
        for row in yt.wrapper.read_table(table, format='json', raw=False):
            ids.update(row['uniq_ids'])
        self.uniq_ids = ids

    def load_results(self, table):
        """Load the rows of ``table`` into ``self.results`` keyed by request_id.

        Every loaded row gets an empty ``tariff_requests`` list. Later rows
        with the same request_id replace earlier ones.
        """
        for row in yt.wrapper.read_table(table, format='yson', raw=False):
            row['tariff_requests'] = []
            self.results[row['request_id']] = row

    def req_merge_reduce(self, key, rows):
        """Reducer: attach the tariff requests of one request_id to its result row.

        Yields:
            A copy of ``self.results[request_id]`` whose ``tariff_requests``
            holds all rows of the group.

        Raises:
            KeyError: if the request_id is missing from ``self.results``.
        """
        merged = dict(self.results[key['request_id']])
        merged['tariff_requests'] = list(rows)
        yield merged

    @staticmethod
    def _segment_key(segment):
        """Build the string key (from-station id, to-station id, departure) of a segment."""
        # The key is a string because it is later used as a map key in the
        # rows written back to YT.
        return str((segment['station_from']['id'], segment['station_to']['id'], segment['departure']))

    def load_merge(self, table):
        """Read the merged table and keep rows where a tariff matches a result segment.

        For every row the tariffs of all its ``tariff_requests`` are indexed
        by segment key; then each segment of the row's ``transfer_variants``
        is looked up in that index.

        Returns:
            List of rows that have at least one match. Each such row gets a
            ``discovered`` list of ``{key: tariff}`` dicts, and its
            ``tariff_requests`` are sorted by ``iso_eventtime``.
        """
        rows = []

        for row in yt.wrapper.read_table(table, format='yson', raw=False):
            tariff_by_key = {}
            for req in row['tariff_requests']:
                for tariff in req['transfer_variants_tariffs']['segments']:
                    tariff_by_key[self._segment_key(tariff)] = tariff

            discovered = []
            if tariff_by_key:
                for var in row['transfer_variants']:
                    for segment in var['segments']:
                        try:
                            key = self._segment_key(segment)
                        except (KeyError, TypeError):
                            # Segment lacks the fields needed for a key; skip it.
                            continue
                        tariff = tariff_by_key.get(key)
                        if tariff:
                            discovered.append({key: tariff})

            if discovered:
                row['tariff_requests'].sort(key=lambda x: x['iso_eventtime'])
                row['discovered'] = discovered
                rows.append(row)

        return rows

    def write_found(self, table, data):
        """Write the resulting rows ``data`` to the YT table ``table``."""
        yt.wrapper.write_table(table, data, raw=False)


# Run a YQL script from Python.
def yql_play():
    """Run a sample YQL query on the hahn cluster and print its results.

    The YQL token is taken from the ``my_yql_token`` environment variable.
    For every result table the schema (column name and print type) and then
    all rows are printed to stdout, tab-separated.
    """
    client = YqlClient(db='hahn', token=os.environ.get('my_yql_token'))

    query_text = """USE hahn;
        PRAGMA yt.InferSchema = '1';
        SELECT uniq_ids FROM concat('home/rasp/anded/filtered_req_id');
        """
    request = client.query(query_text, syntax_version=1)
    request.run()

    for result_table in request.get_results():
        # Pull the complete data of the table before printing it.
        result_table.fetch_full_data()

        print('=== Table ===')
        print('--- Schema ---')
        schema = zip(result_table.column_names, result_table.column_print_types)
        for name, print_type in schema:
            print('\t'.join((name, print_type)))

        print('\n--- Data ---')
        for cells in result_table.rows:
            print('\t'.join(map(str, cells)))


def main():
    """Build the pipeline for the hahn cluster and run it.

    The YT token is read from the ``robot_rasp_token`` environment variable;
    all output tables go to ``//home/rasp/anded``.
    """
    runner = YtProxyLogs(
        sources=dict(
            result='//home/logfeller/logs/rasp-pathfinder-proxy-result-log/1d/',
            tariffs='//home/logfeller/logs/rasp-pathfinder-proxy-tariffs-log/1d/',
            search='//home/logfeller/logs/rasp-pathfinder-proxy-search-log/1d/',
        ),
        dest_dir='//home/rasp/anded',
        token=os.environ.get('robot_rasp_token'),
        proxy='hahn.yt.yandex.net',
    )
    runner.start()


# Run the pipeline only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

