# coding: utf-8

import json
import logging
import os
import tempfile
from datetime import datetime, timedelta
from functools import reduce

import yt.wrapper
from yt.wrapper import YsonFormat

# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Path of the JSON file with currency rates, located in the system temp
# directory. The ``__main__`` section of this module reads it and passes the
# parsed content to YtMinAviaPrice as ``currency_rates``.
CURRENCY_RATES_FILE_NAME = os.path.join(tempfile.gettempdir(), 'currency_rates.json')


class YtMinAviaPrice:
    """Build a table with one best-price row per route/date/endpoints/class.

    The job reads daily tariff tables from ``source_dir`` (one table per day,
    named ``YYYY-MM-DD``), filters and normalizes their rows with
    :meth:`simple_runner` (map), sorts them and collapses every key group to a
    single row with ``UniqueReducer`` (reduce). The result is written to
    ``<destination_dir>/<date of days_ago_from>``.
    """

    def __init__(self, currency_rates, days_ago_from, days_ago_to,
                 source_dir, destination_dir, tmp_dir,
                 token, proxy):
        """Store the job configuration.

        :param currency_rates: mapping ``currency code -> multiplier`` applied
            to prices given in that currency.
        :param days_ago_from: first day offset (inclusive), counted back from today.
        :param days_ago_to: last day offset (exclusive), counted back from today.
        :param source_dir: YT directory with the daily source tables.
        :param destination_dir: YT directory the result table is written to.
        :param tmp_dir: YT directory used for temporary tables.
        :param token: YT token passed to the client configuration.
        :param proxy: YT proxy URL passed to the client configuration.
        """
        self.days_ago_from = days_ago_from
        self.days_ago_to = days_ago_to
        self.date_str = (datetime.today() - timedelta(days=days_ago_from)).strftime("%Y-%m-%d")
        self.currency_rates = currency_rates
        self.res_table = os.path.join(destination_dir, self.date_str)
        self.source_dir = source_dir
        self.destination_dir = destination_dir
        self.tmp_dir = tmp_dir
        self.token = token
        self.proxy = proxy
        self.tariffs_table = []

    def yt_init(self):
        """Apply the token and proxy URL to the global ``yt.wrapper`` config."""
        yt.wrapper.update_config(
            {
                "token": self.token,
                "proxy": {'url': self.proxy}
            }
        )

    def yt_find_tariffs_table(self):
        """Collect the existing daily source tables into ``self.tariffs_table``.

        Every day offset in ``[days_ago_from, days_ago_to)`` is checked; missing
        tables are logged and skipped. The list is rebuilt from scratch on each
        call, so calling the method twice does not duplicate inputs.
        """
        # Take "today" once so that all table names refer to the same day even
        # if the loop runs across midnight.
        today = datetime.today()
        found_tables = []
        for day_ago in range(self.days_ago_from, self.days_ago_to):
            table_path = os.path.join(self.source_dir,
                                      (today - timedelta(days=day_ago)).strftime("%Y-%m-%d"))
            if yt.wrapper.exists(table_path):
                found_tables.append(table_path)
            else:
                log.warning('Path %s does not exist', table_path)
        self.tariffs_table = found_tables

    def start(self):
        """Run the whole pipeline: configure YT, find inputs, map, sort, reduce.

        Returns without running any operation when no source table exists.
        """
        self.yt_init()
        self.yt_find_tariffs_table()
        log.info('tables: %s', str(self.tariffs_table))
        if not self.tariffs_table:
            return
        yt.wrapper.mkdir(self.tmp_dir, recursive=True)
        yt.wrapper.mkdir(self.destination_dir, recursive=True)
        # The reduce key; the sort must use the same columns.
        key_columns = ['route_uid', 'date_forward',
                       'object_from_id', 'object_to_id', 'class']
        with yt.wrapper.TempTable(self.tmp_dir, 'min_prices_filtered') as filtered_table:
            with yt.wrapper.TempTable(self.tmp_dir, 'min_prices_sorted') as sorted_table:
                log.info('map run')
                yt.wrapper.run_map(self.simple_runner,
                                   self.tariffs_table,
                                   filtered_table,
                                   output_format=YsonFormat(
                                       control_attributes_mode="row_fields",
                                       table_index_column="@table_index")
                                   )
                log.info('sort run')
                yt.wrapper.run_sort(filtered_table,
                                    sorted_table,
                                    sort_by=key_columns)
                log.info('reduce run')
                yt.wrapper.run_reduce(UniqueReducer(),
                                      sorted_table,
                                      self.res_table,
                                      reduce_by=key_columns)
                log.info('run done')

    def get_row_price(self, row, price_field_name, num_adults):
        """Return the per-adult price of ``row`` as a string, or ``None``.

        The price field is expected to look like ``"<amount>"`` or
        ``"<amount> <currency>"``. Amounts in a currency listed in
        ``self.currency_rates`` are multiplied by the rate; ``RUR``/``RUB``
        amounts are used as is; any other currency yields ``None``.

        :param row: mapping holding the price field.
        :param price_field_name: name of the field to read the price from.
        :param num_adults: number of adults the price is divided by.
        :returns: price rounded to one decimal place, as ``str``; ``None`` when
            the value is not a string, is not parseable, or uses an unknown
            currency.
        :raises KeyError: if ``price_field_name`` is not in ``row``.
        """
        raw_price = row[price_field_name]
        try:
            price_parts = raw_price.split(' ')
            price = float(price_parts[0])
        except (AttributeError, TypeError, ValueError):
            # Not a string (no usable ``split``) or not a number: treat the
            # price as unparseable instead of crashing the whole job.
            return None

        if len(price_parts) > 1:
            currency = price_parts[1]
            if currency in self.currency_rates:
                price = price * self.currency_rates[currency]
            elif currency not in ['RUR', 'RUB']:
                return None
        # The number of children is ignored; overestimating the price is acceptable.
        return str(round(price / num_adults, 1))

    @staticmethod
    def get_row_class(row, price_name):
        """Return the class name for a row.

        The second ``_``-separated part of ``price_name`` is used when there is
        one (``class_economy_price`` -> ``economy``); otherwise the row's own
        ``class`` field, or an empty string when the row has none.
        """
        name_parts = price_name.split('_')
        if len(name_parts) > 1:
            return name_parts[1]
        if 'class' in row:
            return row['class']
        return ''

    @staticmethod
    def get_row_seats(row, class_name):
        """Return the seats value of a row, or ``None`` when it has none.

        The ``seats`` field takes precedence over ``class_<class_name>_seats``.
        """
        if 'seats' in row:
            return row['seats']
        seats_field = 'class_{}_seats'.format(class_name)
        if seats_field in row:
            return row[seats_field]
        return None

    def simple_runner(self, row):
        """Mapper: yield the normalized row if it passes all filters.

        A row is kept only when it
        * has a 10-character ``date_forward``, no 10-character
          ``date_backward`` and a ``timestamp`` field;
        * has a ``route_uid`` without ``;`` that is not ``bla-bla-car``;
        * has a positive number of adults (defaults to 1) and a parseable price;
        * contains every field listed in ``UniqueReducer().fields`` and has
          integer-convertible ``object_from_id`` / ``object_to_id``.

        The yielded row gets ``adults``, ``class``, ``price`` and, when
        missing, ``seats``, ``key``, ``object_from_type`` and
        ``object_to_type`` filled in. Rows failing validation are dropped
        silently.
        """
        def has_field(raw, field_name):
            return field_name in raw and raw[field_name]

        if '@table_index' in row:
            del row['@table_index']

        if not (
            (has_field(row, 'date_forward') and len(row['date_forward']) == 10) and
            (not has_field(row, 'date_backward') or len(row['date_backward']) != 10) and
            'timestamp' in row
        ):
            return
        if not (has_field(row, 'route_uid') and ';' not in row['route_uid'] and 'bla-bla-car' != row['route_uid']):
            return

        try:
            num_adults = 1
            if has_field(row, 'adults'):
                num_adults = int(row['adults'])
            if num_adults <= 0:
                return
            row['adults'] = str(num_adults)

            price_field_names = [
                key for key in row.keys() if key.endswith('price') and row[key]
            ]
            if not price_field_names:
                return
            # Only the first non-empty "*price" field is used.
            price_field_name = price_field_names[0]

            price = self.get_row_price(row, price_field_name, num_adults)
            row['class'] = self.get_row_class(row, price_field_name)
            seats = self.get_row_seats(row, row['class'])
            if seats:
                row['seats'] = seats
            if 'key' not in row:
                row['key'] = ''
            if 'object_from_type' not in row:
                row['object_from_type'] = 'Station'
            if 'object_to_type' not in row:
                row['object_to_type'] = 'Station'
            if not price:
                return
            row['price'] = price

            if not all(field in row for field in UniqueReducer().fields):
                return
            # Conversion results are discarded: the calls only validate that
            # both ids are integers.
            int(row['object_from_id'])
            int(row['object_to_id'])
        except (TypeError, ValueError):
            # ValueError also covers UnicodeDecodeError (its subclass);
            # TypeError covers values such as None that int() cannot convert.
            return

        yield row

    def save_dump(self, file_name, dump_table=None):
        """Dump a YT table to a local file in ``dsv`` format.

        :param file_name: path of the local file to (over)write.
        :param dump_table: table to read; defaults to the result table.
        """
        if dump_table is None:
            dump_table = self.res_table
        # The table is read with raw=True, i.e. as an unparsed stream, so the
        # file is opened in binary mode to accept the chunks on Python 2 and 3.
        with open(file_name, 'wb') as dump:
            for row in yt.wrapper.read_table(dump_table, format='dsv', raw=True):
                dump.write(row)


class UniqueReducer(object):
    """Reducer that emits exactly one row for every group of records.

    The winning record is the one with the greatest ``timestamp``; among
    records sharing that timestamp the one with the lowest ``price`` wins, and
    if several records are still equal the last of them is taken. Only the
    columns listed in ``self.fields`` are copied to the output row.
    """

    def __init__(self):
        # Columns of the output row, in output order.
        self.fields = [
            'route_uid', 'date_forward',
            'object_from_id', 'object_from_type', 'object_to_id', 'object_to_type',
            'type', 'timestamp', 'price', 'class', 'seats', 'key',
        ]

    @staticmethod
    def _rank(rec):
        """Ordering key of a record: newer timestamp first, then cheaper price."""
        return rec['timestamp'], -float(rec['price'])

    @classmethod
    def _pick(cls, current, candidate):
        """Keep ``current`` only if it ranks strictly above ``candidate``."""
        if cls._rank(current) > cls._rank(candidate):
            return current
        return candidate

    def __call__(self, key, recs):
        """Yield the single best record of the group, restricted to ``self.fields``."""
        winner = reduce(self._pick, recs)
        yield {name: winner[name] for name in self.fields}


if __name__ == "__main__":
    import argparse

    # Command-line entry point: read the job parameters, load the currency
    # rates from CURRENCY_RATES_FILE_NAME and run the pipeline once.
    parser = argparse.ArgumentParser()
    # The options below are required: without them the job fails later with an
    # unrelated TypeError (date arithmetic / path joining on None), so report a
    # usage error up front instead.
    parser.add_argument('--days_ago_from', type=int, required=True,
                        help='first day offset to process (inclusive), counted back from today')
    parser.add_argument('--days_ago_to', type=int, required=True,
                        help='last day offset to process (exclusive), counted back from today')
    parser.add_argument('--source_dir', required=True,
                        help='YT directory with the daily source tables named YYYY-MM-DD')
    parser.add_argument('--destination_dir', required=True,
                        help='YT directory the result table is written to')
    parser.add_argument('--tmp_dir',
                        help='YT directory used for temporary tables')
    parser.add_argument('--token',
                        help='YT token')
    parser.add_argument('--proxy',
                        help='YT proxy URL')
    args = parser.parse_args()

    with open(CURRENCY_RATES_FILE_NAME, "r") as f:
        currency_rates = json.load(f)

    runner = YtMinAviaPrice(currency_rates, args.days_ago_from, args.days_ago_to, args.source_dir, args.destination_dir,
                            args.tmp_dir, args.token, args.proxy)
    runner.start()
