# -*- encoding: utf-8 -*-
import logging
from collections import namedtuple

from cached_property import cached_property
from contextlib2 import closing
from datetime import datetime, timedelta

from pathlib2 import Path  # noqa

from travel.avia.dump_data.lib.model_classes.base_model import BaseModel
from travel.avia.dump_data.lib.mysql_connector import MysqlConnector  # noqa
from travel.library.python.dicts import file_util

# Window of forward dates, in days counted from today, whose prices are dumped.
FIRST_DAY_DELTA = 2
LAST_DAY_DELTA = 94
# YT table with minimal prices; consulted after MySQL to fill missing keys.
YT_MIN_PRICE_TABLE = '//home/avia/data/wizard/min-price-365d'


# Identity of a single price: route, national version, currency, forward date.
# NOTE: the ``currency_code`` slot is filled with the numeric currency id
# (MinPrice.get_pk passes ``currency_id`` there), despite its name.
PriceKey = namedtuple(
    'PriceKey',
    'from_city_id to_city_id national_version currency_code date_forward',
)


def _parse_date(dt):
    return datetime.strptime(dt, '%Y-%m-%d').date()


class MinPrice(namedtuple('MinPrice', [
    'departure_settlement_id',
    'arrival_settlement_id',
    'national_version',
    'currency_id',
    'date_forward',
    'date_backward',
    'price',
    'passengers',
])):
    """A single minimal-price record between two settlements.

    Instances are built from a MySQL row (``from_db_row``) or a YT row
    (``from_yt_row``) and serialized into a protobuf message with ``to_proto``.
    """

    # Without an empty __slots__ a namedtuple subclass gets a per-instance
    # __dict__, which costs memory when many prices are held at once.
    __slots__ = ()

    def get_pk(self):
        # type: () -> PriceKey
        """Return the key under which prices are deduplicated.

        ``date_backward``, ``price`` and ``passengers`` are not part of it.
        """
        return PriceKey(
            self.departure_settlement_id,
            self.arrival_settlement_id,
            self.national_version,
            self.currency_id,
            self.date_forward,
        )

    @classmethod
    def from_db_row(cls, row):
        """Build an instance from a mapping-like DB row, copying fields as-is.

        :param row: object indexable by the field names (e.g. a dict cursor row).
        """
        return cls(
            row['departure_settlement_id'],
            row['arrival_settlement_id'],
            row['national_version'],
            row['currency_id'],
            row['date_forward'],
            row['date_backward'],
            row['price'],
            row['passengers'],
        )

    @classmethod
    def from_yt_row(cls, row, currency_mapper):
        """Build an instance from a YT table row.

        :param row: mapping with string-ish settlement ids, ``currency__code``,
            an ISO ``date_forward`` string, ``price`` and ``passengers``.
        :param currency_mapper: mapping from currency code to currency id.
        :raises KeyError: unknown currency code or missing column.
        :raises ValueError: non-numeric id/price or malformed date.

        YT rows carry no backward date, so ``date_backward`` is always None.
        """
        return cls(
            int(row['departure_settlement_id']),
            int(row['arrival_settlement_id']),
            row['national_version'],
            currency_mapper[row['currency__code']],
            _parse_date(row['date_forward']),
            None,
            float(row['price']),
            row['passengers'],
        )

    def to_proto(self, proto_model):
        """Create ``proto_model()`` and fill it from this record.

        Dates are rendered as ``YYYY-MM-DD``; a missing backward date becomes
        an empty string.
        """
        proto = proto_model()

        proto.DepartureSettlementID = self.departure_settlement_id
        proto.ArrivalSettlementID = self.arrival_settlement_id
        proto.NationalVersion = self.national_version
        proto.CurrencyID = self.currency_id
        proto.DateForward = self.date_forward.strftime('%Y-%m-%d')
        if self.date_backward is not None:
            proto.DateBackward = self.date_backward.strftime('%Y-%m-%d')
        else:
            proto.DateBackward = ''
        proto.Price = self.price
        proto.Passengers = self.passengers

        return proto


class MinPriceModel(BaseModel):
    """Dump minimal prices from MySQL and YT into a file of serialized protos.

    Prices are collected in ``self._prices`` keyed by ``MinPrice.get_pk()``.
    MySQL rows are loaded first; YT rows are then added only for keys that
    are not present yet. Only forward dates within
    ``[today + FIRST_DAY_DELTA, today + LAST_DAY_DELTA]`` are kept, where
    ``today`` is fixed when the model is constructed.
    """

    def __init__(self, name, connector, yt_client, proto_model):
        self._name = name
        self.connector = connector  # type: MysqlConnector
        self.yt_client = yt_client
        self.proto_model = proto_model
        # Cheapest MinPrice per PriceKey; rebuilt by every dump_into_directory.
        self._prices = {}

        # The date window is computed once, at construction time.
        today = datetime.today().date()
        self.date_from = today + timedelta(days=FIRST_DAY_DELTA)
        self.date_to = today + timedelta(days=LAST_DAY_DELTA)

    @property
    def name(self):
        return self._name

    @cached_property
    def currency_by_code(self):
        """Map currency code -> integer currency id, read from ``avia_currency``."""
        with closing(self.connector.get_connection()) as connection:
            with closing(connection.cursor()) as cursor:
                cursor.execute('select id, code from avia_currency')
                return {row['code']: int(row['id']) for row in cursor}

    def dump_into_directory(self, directory):
        # type: (Path) -> None
        """Collect prices (MySQL, then YT) and write them into *directory*.

        The output file name comes from ``self.get_output_file_name()``.
        """
        file_name = directory / self.get_output_file_name()

        # Start from an empty collection so a repeated call does not re-emit
        # prices gathered by an earlier call (and does not let them shadow
        # fresh YT data in _fetch_prices_from_yt).
        self._prices = {}

        with closing(self.connector.get_connection()) as connection:
            with closing(connection.cursor()) as cursor:
                self._fetch_prices_from_mysql(cursor)

        self._fetch_prices_from_yt()

        with open(str(file_name), 'wb') as file:
            self._dump_into_file(file)

        logging.info('Write %s', str(file_name))

    def _fetch_prices_from_mysql(self, cursor):
        """Load ``www_minprice`` rows in the date window, keeping the cheapest per key.

        NOTE: unlike the YT path, no ``passengers`` filter is applied here.
        """
        query = self._get_query()
        logging.debug('Execute query: %s', query)

        cursor.execute(query)
        for row in iter_counter(cursor, 'Process mysql rows: %s'):
            self._store_price(MinPrice.from_db_row(row))

    def _fetch_prices_from_yt(self):
        """Add YT prices for keys not collected yet (best-effort).

        Rows that fail to parse are logged and skipped; any other failure is
        logged once and the dump continues with what was collected so far.
        """
        try:
            # Resolve the MySQL-backed currency map once, outside the row
            # loop. Inside the per-row try a failing lookup would be retried
            # (new DB connection + logged traceback) for every YT row, since a
            # failed cached_property is not cached.
            currency_by_code = self.currency_by_code
            table = self.yt_client.TablePath(YT_MIN_PRICE_TABLE)
            for row in iter_counter(self.yt_client.read_table(table), 'Process yt rows: %s'):
                try:
                    price = MinPrice.from_yt_row(row, currency_by_code)
                except Exception:
                    logging.exception('Parsing YT min price exception for value: %s', row)
                    continue

                # YT never overrides a key that is already present.
                # NOTE(review): this also makes the first of several duplicate
                # YT rows win rather than the cheapest — confirm intended.
                if price.get_pk() in self._prices:
                    continue

                if price.passengers == '1_0_0' and self.date_from <= price.date_forward <= self.date_to:
                    self._store_price(price)
        except Exception:
            logging.exception('Exception on getting yt min prices')

    def _get_query(self):
        """Build the ``www_minprice`` query for the configured date window.

        The interpolated values are the internally computed ``date`` objects
        (rendered as ``YYYY-MM-DD``), not external input.
        """
        query = """
            select
                departure_settlement_id,
                arrival_settlement_id,
                national_version,
                currency_id,
                date_forward,
                date_backward,
                price,
                passengers
            from www_minprice
            where date_forward between '{date_from}' and '{date_to}'
        """

        return query.format(date_from=self.date_from, date_to=self.date_to)

    def _store_price(self, price):
        # type: (MinPrice) -> None
        """Keep *price* if its key is new or it is strictly cheaper than the stored one."""
        pk = price.get_pk()
        if (
                pk not in self._prices
                or price.price < self._prices[pk].price
        ):
            self._prices[pk] = price

    def _dump_into_file(self, file):
        """Serialize every collected price into *file* via ``file_util.write_binary_string``."""
        count_row = 0
        # .values() rather than the Python 2-only .itervalues(), so this also
        # runs on Python 3.
        for price in iter_counter(self._prices.values(), 'Process proto rows: %s'):
            count_row += 1
            proto = price.to_proto(self.proto_model)
            file_util.write_binary_string(file, proto.SerializeToString())

        logging.info('Fetch %s rows for %s reference', count_row, self.name)


def iter_counter(iterable, msg, pack_size=100000):
    """Re-yield every item of *iterable*, logging progress at DEBUG level.

    After each ``pack_size``-th item, ``msg`` is logged with the running
    item count as its only %-argument.
    """
    for seen, item in enumerate(iterable, 1):
        if not seen % pack_size:
            logging.debug(msg, seen)
        yield item
