# encoding: utf8
import abc
import logging
import itertools
import random
import urllib
from collections import Counter
from datetime import datetime, timedelta

from sandbox import sdk2
from sandbox.sandboxsdk import environments

from sandbox.projects.avia.base import AviaBaseTask
from sandbox.projects.avia.lib import logs, yt_helpers as yth, destination_cache
from sandbox.projects.avia.lib.point_to_iata import PointToIATACache


# Country ids in the form the country filters below compare against
# (presumably the `country_id` values of the reference tables - confirm).
RUSSIA_ID = 'l225'
CIS_COUNTRIES = {  # Russia is not included
    'l149',  # Belarus
    'l159',  # Kazakhstan
    'l167',  # Azerbaijan
    'l168',  # Armenia
    'l169',  # Georgia
    'l170',  # Turkmenistan
    'l171',  # Uzbekistan
    'l187',  # Ukraine
    'l207',  # Kyrgyzstan
    'l208',  # Moldova
    'l209',  # Tajikistan
}
# Same set plus Russia itself; used to test "outside of CIS and Russia".
CIS_COUNTRIES_WITH_RUSSIA = CIS_COUNTRIES | {RUSSIA_ID}


def safe_random_sample(lst, size):
    """Return up to ``size`` randomly chosen items of ``lst`` as a new list.

    When ``lst`` holds ``size`` items or fewer, all of them are returned in
    iteration order; otherwise ``random.sample`` picks ``size`` of them.

    The result is always a fresh list, so mutating it never affects ``lst``.
    """
    if len(lst) <= size:
        # Copy rather than hand back the caller's own object: the sampling
        # branch below returns a new list, so both branches now agree.
        return list(lst)

    return random.sample(lst, size)


class PointByCountryFilter(object):
    """ Base predicate over a pair of points that is decided by the points' countries.

    Instances are called with two point ids; subclasses provide ``check_by_countries``.
    """
    def __init__(self, point_to_country):
        # Mapping: point id -> country id
        self._point_to_country = point_to_country

    def __call__(self, from_id, to_id):
        """ Resolve both ids to countries and delegate to ``check_by_countries``.

        Raises KeyError when an id is missing from the mapping.
        """
        countries = self._point_to_country
        return self.check_by_countries(countries[from_id], countries[to_id])

    def check_by_countries(self, from_country, to_country):
        """ Decide by the two country ids; must be overridden. """
        raise NotImplementedError()


class VVLFilter(PointByCountryFilter):
    """ Accept a direction only when both of its points are in Russia """
    def check_by_countries(self, from_country, to_country):
        return all(country == RUSSIA_ID for country in (from_country, to_country))


class CISFilter(PointByCountryFilter):
    """ Accept a direction between Russia and a CIS country other than Russia (either way) """
    def check_by_countries(self, from_country, to_country):
        # Russia -> CIS
        if from_country == RUSSIA_ID and to_country in CIS_COUNTRIES:
            return True

        # CIS -> Russia
        return to_country == RUSSIA_ID and from_country in CIS_COUNTRIES


class MVLFilter(PointByCountryFilter):
    """ Accept a direction with one point in Russia and the other one outside of CIS """
    def check_by_countries(self, from_country, to_country):
        # Russia -> outside of CIS
        if from_country == RUSSIA_ID and to_country not in CIS_COUNTRIES_WITH_RUSSIA:
            return True

        # outside of CIS -> Russia
        return to_country == RUSSIA_ID and from_country not in CIS_COUNTRIES_WITH_RUSSIA


class AbroadFilter(PointByCountryFilter):
    """ Accept a direction only when both points are outside of CIS and Russia """
    def check_by_countries(self, from_country, to_country):
        touches_cis = from_country in CIS_COUNTRIES_WITH_RUSSIA or to_country in CIS_COUNTRIES_WITH_RUSSIA
        return not touches_cis


class BucketDescription(object):
    """ Plain record describing one bucket: its name, direction predicate and size """
    __slots__ = ['name', 'filter', 'size']

    def __init__(self, name, filter, size):
        # `filter` is called as filter(from_id, to_id); `size` is how many directions to pick.
        self.name, self.filter, self.size = name, filter, size


class IRivalRecordGenerator(object):
    """ Base class that turns (dates x directions) into output records for one rival.

    Subclasses supply ``create_url``; everything else of a record is filled in here.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, name, destination_title_by_id):
        self._destination_title_by_id = destination_title_by_id
        self.name = name

    def do(self, redirect_type, dates, directions):
        """ Yield one record per combination of a date pair and a direction.

        ``dates`` holds (forward_date, backward_date) pairs, ``directions`` holds
        (from_id, to_id) pairs; all directions are emitted for the first date pair,
        then for the next one, and so on.
        """
        for date_pair, direction in itertools.product(dates, directions):
            forward_date, backward_date = date_pair
            from_id, to_id = direction

            record = {'rival': self.name, 'fromId': from_id, 'toId': to_id}
            record['forward_date'] = self.format_date_to_output(forward_date)
            record['backward_date'] = self.format_date_to_output(backward_date)
            record['url'] = self.create_url(from_id, to_id, forward_date, backward_date)
            record['type'] = redirect_type

            # Titles are optional: ids without a known title give None.
            record['from_name'] = self._destination_title_by_id.get(from_id)
            record['to_name'] = self._destination_title_by_id.get(to_id)

            # Passenger configuration is the same for every record.
            record['class'] = 'economy'
            record['adult_seats'] = 1
            record['children_seats'] = 0
            record['infant_seats'] = 0

            yield record

    @abc.abstractmethod
    def create_url(self, from_id, to_id, forward_date, backward_date):
        """ Return the rival's search URL for the direction and dates. """
        raise NotImplementedError()

    def format_date_to_output(self, date):
        """ Return ``date`` as YYYY-MM-DD, or None for a missing (falsy) date. """
        if not date:
            return None

        return date.strftime('%Y-%m-%d')


class IRivalWithIATARecordGenerator(IRivalRecordGenerator):
    """ Record generator for rivals whose URLs are built from IATA codes.

    Point ids are translated through ``point_reference`` before the URL is built;
    subclasses implement ``create_url_from_iata``.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, name, destination_title_by_id, point_reference):
        super(IRivalWithIATARecordGenerator, self).__init__(name, destination_title_by_id)
        # Object with a dict-like ``get(point_id)`` that returns the IATA code.
        self.point_reference = point_reference

    def create_url(self, from_id, to_id, forward_date, backward_date):
        """ Translate both point ids to IATA codes and build the URL from them. """
        to_iata_code = self.point_reference.get
        return self.create_url_from_iata(
            to_iata_code(from_id),
            to_iata_code(to_id),
            forward_date,
            backward_date,
        )

    @abc.abstractmethod
    def create_url_from_iata(self, from_iata, to_iata, forward_date, backward_date):
        """ Return the rival's search URL for the given IATA codes and dates. """
        raise NotImplementedError()


class AviasalesGenerator(IRivalWithIATARecordGenerator):
    """ Build aviasales.ru search links from IATA codes """
    def create_url_from_iata(self, from_iata, to_iata, forward_date, backward_date):
        # Dates go into the path as DDMM; a one-way search has no return part at all.
        url_parts = {
            'from_iata': from_iata.lower(),
            'to_iata': to_iata.lower(),
            'forward_date': forward_date.strftime('%d%m'),
            'backward_date': backward_date.strftime('%d%m') if backward_date else '',
        }
        template = 'https://www.aviasales.ru/search/{from_iata}{forward_date}{to_iata}{backward_date}100'
        return template.format(**url_parts)


class SkyscannerGenerator(IRivalWithIATARecordGenerator):
    """ Build skyscanner.ru search links from IATA codes """
    def create_url_from_iata(self, from_iata, to_iata, forward_date, backward_date):
        """ Return the search URL; the return date segment is added only for round trips.

        The query asks for 1 adult, 0 children and 0 infants. The third parameter
        used to be a second ``adults=0``, which repeated the ``adults`` key with a
        conflicting value instead of setting the infants count.
        """
        dates = self._format_date(forward_date)
        if backward_date:
            dates = '{}/{}'.format(dates, self._format_date(backward_date))

        return 'https://www.skyscanner.ru/transport/flights/{from_iata}/{to_iata}/{dates}/?adults=1&children=0&infants=0'.format(
            from_iata=from_iata.lower(),
            to_iata=to_iata.lower(),
            dates=dates,
        )

    def _format_date(self, date):
        """ Format a date as YYYYMMDD for the URL path. """
        return date.strftime('%Y%m%d')


class KayakGenerator(IRivalWithIATARecordGenerator):
    """ Build kayak.ru search links from IATA codes """
    def create_url_from_iata(self, from_iata, to_iata, forward_date, backward_date):
        origin = from_iata.lower()
        destination = to_iata.lower()
        departure = self._format_date(forward_date)

        # One-way search: the path has a single date segment.
        if not backward_date:
            return 'https://www.kayak.ru/flights/{from_iata}-{to_iata}/{forward_date}/1adults?sort=bestflight_a'.format(
                from_iata=origin,
                to_iata=destination,
                forward_date=departure,
            )

        return 'https://www.kayak.ru/flights/{from_iata}-{to_iata}/{forward_date}/{backward_date}/1adults?sort=bestflight_a'.format(
            from_iata=origin,
            to_iata=destination,
            forward_date=departure,
            backward_date=self._format_date(backward_date),
        )

    def _format_date(self, date):
        """ Format a date as YYYY-MM-DD for the URL path. """
        return date.strftime('%Y-%m-%d')


class YandexGenerator(IRivalRecordGenerator):
    """ Build avia.yandex.ru search links directly from point ids """
    def create_url(self, from_id, to_id, forward_date, backward_date):
        when = self._format_date(forward_date)
        # A one-way search keeps the parameter but leaves its value empty.
        return_date = self._format_date(backward_date) if backward_date else ''

        template = 'https://avia.yandex.ru/search/result?fromId={from_id}&toId={to_id}&when={forward_date}&return_date={backward_date}'
        return template.format(
            from_id=from_id,
            to_id=to_id,
            forward_date=when,
            backward_date=return_date,
        )

    def _format_date(self, date):
        """ Format a date as YYYY-MM-DD for the query string. """
        return date.strftime('%Y-%m-%d')


class GoogleFlightsGenerator(IRivalWithIATARecordGenerator):
    """ Build google.com/flights links from IATA codes """
    def create_url_from_iata(self, from_iata, to_iata, forward_date, backward_date):
        outbound_leg = '.'.join((from_iata, to_iata, self._format_date(forward_date)))
        hash_params = {'c': 'RUB'}  # currency

        if backward_date is not None:
            # Round trip: the reversed leg is appended after '*'.
            return_leg = '.'.join((to_iata, from_iata, self._format_date(backward_date)))
            flt = '{}*{}'.format(outbound_leg, return_leg)
        else:
            flt = outbound_leg
            hash_params['tt'] = 'o'  # presumably marks a one-way trip - confirm

        # gl - country, hl - language
        query = urllib.urlencode({'gl': 'ru', 'hl': 'ru'})
        fragment_tail = ';'.join('{}:{}'.format(key, value) for key, value in hash_params.iteritems())

        return 'https://www.google.com/flights?{get_params}#flt={flt};{hash_params}'.format(
            get_params=query,
            flt=flt,
            hash_params=fragment_tail,
        )

    def _format_date(self, date):
        """ Format a date as YYYY-MM-DD for the URL fragment. """
        return date.strftime('%Y-%m-%d')


class RivalTableGenrator(object):
    """ Produce the records of all rivals for all direction types.

    Dates are derived from a base day by the two configured shifts.
    """
    def __init__(self, rival_generators, date_forward_shift, date_backward_shift):
        self.date_backward_shift = date_backward_shift
        self.date_forward_shift = date_forward_shift
        self.rival_generators = rival_generators

    def do(self, today, directions_by_type):
        """ Yield records rival by rival and, within a rival, type by type. """
        dates = self.generate_dates(today)
        record_streams = (
            generator.do(redirect_type, dates, directions)
            for generator in self.rival_generators
            for redirect_type, directions in directions_by_type.iteritems()
        )
        for record in itertools.chain.from_iterable(record_streams):
            yield record

    def generate_dates(self, today):
        """ Return two (forward, backward) pairs: a one-way trip and a round trip. """
        departure = today + self.date_forward_shift
        comeback = today + self.date_backward_shift
        return [(departure, None), (departure, comeback)]


class AviaGenerateDirectionsForToloka(AviaBaseTask):
    """
    Generating directions for tolokers to compare prices
    based on redirects number

    The task rates directions by the search and redirect logs, distributes them
    over named buckets (popular, vvl, cis, mvl, abroad, random), builds a search
    URL for every rival, date pair and direction, and writes the records to a
    dated YT table that is also copied to a fixed path.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = (
            environments.PipEnvironment('yandex-yt', version='0.10.8'),
            environments.PipEnvironment('ujson'),
            environments.PipEnvironment('yandex-yt-yson-bindings-skynet', version='0.3.32-0'),
            environments.PipEnvironment('raven'),
        )

        # https://wiki.yandex-team.ru/sandbox/clients/#client-tags-multislot
        cores = 1  # exactly 1 core
        ram = 8192  # 8GiB or less

        class Caches(sdk2.Requirements.Caches):
            pass  # means that task do not use any shared caches

    class Parameters(sdk2.Task.Parameters):

        with sdk2.parameters.Group('Task settings'):
            day_count = sdk2.parameters.Integer('Day count', default=7, required=True)
            popular_directions = sdk2.parameters.Integer('Number of popular directions', default=20, required=True)
            random_directions = sdk2.parameters.Integer('Number of random directions', default=20, required=True)
            vvl_directions = sdk2.parameters.Integer('Number of Russia - Russia directions', default=20, required=True)
            cis_directions = sdk2.parameters.Integer('Number of Russia - CIS directions', default=20, required=True)
            mvl_directions = sdk2.parameters.Integer('Number of Russia - out CIS directions', default=20, required=True)
            abroad_directions = sdk2.parameters.Integer('Number of out CIS - out CIS directions', default=20, required=True)
            redirect_threshold = sdk2.parameters.Integer('Minum number of redirects to get to random sample', default=100, required=True)
            output_prefix = sdk2.parameters.String('Output directory', default='//home/avia/toloka/rivals', required=True)
            yt_proxy = sdk2.parameters.String('YT cluster', default='hahn', required=True)
            fixed_path = sdk2.parameters.String('Fixes YT path of renewable table', default='parcing_directions')
            reference_yt_path = sdk2.parameters.String(
                'Reference YT path', required=True, default_value='//home/rasp/reference'
            )

        with sdk2.parameters.Group('Vault settings'):
            vaults_owner = sdk2.parameters.String('Vault owner', required=True)
            yt_token_vault_name = sdk2.parameters.String('YT token vault name', default='YT_TOKEN', required=True)

    REDIR_LOG_DIRECTORY = '//home/avia/logs/avia-redir-balance-by-day-log'
    SEARCH_LOG_DIRECTORY = '//home/avia/logs/avia-users-search-log'
    _yt_client = None

    @property
    def yt_client(self):
        """ YT client, created on first access from the proxy parameter and the vault token. """
        if self._yt_client is None:
            import yt.wrapper as yt
            self._yt_client = yt.YtClient(
                proxy=self.Parameters.yt_proxy,
                token=sdk2.Vault.data(self.Parameters.vaults_owner, self.Parameters.yt_token_vault_name)
            )

        return self._yt_client

    def build_redirect_rating(self, point_filter):
        """ Count redirect log rows per (FROMID, TOID) pair.

        Reads the tables returned by ``yth.last_logs_tables`` for the redirect log
        directory and ``day_count``. A row is counted when its FILTER column is 0
        and both ids are non-empty and accepted by ``point_filter``.
        """
        return Counter(
            (r['FROMID'], r['TOID'])
            for table in yth.last_logs_tables(self.yt_client, self.REDIR_LOG_DIRECTORY, self.Parameters.day_count)
            for r in self.yt_client.read_table(table)
            if r['FILTER'] == 0 and r['FROMID'] and r['TOID'] and point_filter(r['FROMID']) and point_filter(r['TOID'])
        )

    def build_search_rating(self, point_filter):
        """ Count search log rows per (fromId, toId) pair.

        Reads the tables returned by ``yth.last_logs_tables`` for the search log
        directory and ``day_count``. A row is counted when both ids are non-empty
        and accepted by ``point_filter``.
        """
        return Counter(
            (r['fromId'], r['toId'])
            for table in yth.last_logs_tables(self.yt_client, self.SEARCH_LOG_DIRECTORY, self.Parameters.day_count)
            for r in self.yt_client.read_table(table)
            if r['fromId'] and r['toId'] and point_filter(r['fromId']) and point_filter(r['toId'])
        )

    def get_point_to_iata(self):
        """ Return the point id -> IATA code lookup backed by the YT client. """
        return PointToIATACache(self.yt_client)

    def get_point_to_country(self):
        """ Build a dict of point id -> country id from the reference tables.

        Rows of the ``settlement`` and ``station`` tables under ``reference_yt_path``
        are merged; on equal ids the ``station`` row wins because it is read last.
        """
        import yt.wrapper as yt

        point_to_country = {}
        for path in (yt.ypath_join(self.Parameters.reference_yt_path, model) for model in ('settlement', 'station')):
            point_to_country.update({
                r['id']: r['country_id']
                for r in self.yt_client.read_table(path)
            })

        return point_to_country

    def on_prepare(self):
        """ Create the destination cache used later for the from/to titles. """
        self._destination_cache = destination_cache.DestinationCache(self.yt_client, self.Parameters.reference_yt_path)

        super(AviaGenerateDirectionsForToloka, self).on_prepare()

    def _collect_bucket_candidates(self, bucket_description, big_directions, all_directions, used_directions):
        """ Return the candidate directions one bucket is sampled from.

        Candidates are the ``big_directions`` accepted by the bucket filter and
        not yet in ``used_directions``. When there are fewer of them than the
        bucket size, the list is topped up from ``all_directions`` in the given
        order, skipping directions that are already candidates.
        """
        def is_eligible(direction):
            return direction not in used_directions and bucket_description.filter(direction[0], direction[1])

        candidates = [r for r in big_directions if is_eligible(r)]
        missing = bucket_description.size - len(candidates)
        if missing > 0:
            logging.info('No big directions for bucket %s. Add from all directions', bucket_description.name)
            # The big directions are part of all_directions too, so without this
            # check the top-up would add the same directions a second time and
            # the bucket (and the output table) would contain duplicates.
            already_chosen = set(candidates)
            candidates.extend(itertools.islice(
                (r for r in all_directions if r not in already_chosen and is_eligible(r)),
                missing,
            ))

        return candidates

    def on_execute(self):
        """ Pick directions for every bucket, generate rival records and write them to YT. """
        logs.configure_logging(logs.get_sentry_dsn(self))
        point_to_iata = self.get_point_to_iata()

        point_to_country = self.get_point_to_country()

        def point_filter(point_code):
            # We need only to and from cities with IATA codes
            return point_code[0] == 'c' and point_to_iata.get(point_code)

        top_directions_by_search = {r for r, _ in self.build_search_rating(point_filter).most_common(self.Parameters.popular_directions)}
        logging.info('Got top directions')

        redirect_rating = self.build_redirect_rating(point_filter)
        big_directions_by_redirects = [
            r for r, c in redirect_rating.iteritems()
            if c > self.Parameters.redirect_threshold and r not in top_directions_by_search
        ]

        # All directions seen in the redirect log, most redirected first.
        # NOTE(review): unlike big_directions_by_redirects, this list still contains
        # the popular directions, so the top-up may put a popular direction into
        # another bucket as well - confirm whether that is intended.
        directions_by_redirects = [r for r, _ in redirect_rating.most_common()]

        logging.info('%r', directions_by_redirects)

        buckets = {
            'popular': top_directions_by_search,
        }

        used_directions = set()
        bucket_descriptions = [
            BucketDescription('vvl', VVLFilter(point_to_country), self.Parameters.vvl_directions),
            BucketDescription('cis', CISFilter(point_to_country), self.Parameters.cis_directions),
            BucketDescription('mvl', MVLFilter(point_to_country), self.Parameters.mvl_directions),
            BucketDescription('abroad', AbroadFilter(point_to_country), self.Parameters.abroad_directions),
            BucketDescription('random', lambda x, y: True, self.Parameters.random_directions),  # All records can be in random bucket
        ]

        # Buckets are filled in the order above; a direction taken by an earlier
        # bucket is not offered to the later ones.
        for bucket_description in bucket_descriptions:
            logging.info('Generating directions for bucket %s', bucket_description.name)
            possible_directions = self._collect_bucket_candidates(
                bucket_description,
                big_directions_by_redirects,
                directions_by_redirects,
                used_directions,
            )

            random_directions = safe_random_sample(possible_directions, bucket_description.size)

            buckets[bucket_description.name] = random_directions
            used_directions.update(random_directions)
            logging.info('Got %d from %d possible', len(random_directions), len(possible_directions))

        today = datetime.now()
        rival_table_generator = RivalTableGenrator(
            rival_generators=[
                YandexGenerator('yandex', self._destination_cache.title_by_id),
                AviasalesGenerator('aviasales', self._destination_cache.title_by_id, point_to_iata),
                KayakGenerator('kayak', self._destination_cache.title_by_id, point_to_iata),
                SkyscannerGenerator('skyscanner', self._destination_cache.title_by_id, point_to_iata),
                GoogleFlightsGenerator('google', self._destination_cache.title_by_id, point_to_iata),
            ],
            date_forward_shift=timedelta(days=2),
            date_backward_shift=timedelta(days=9),
        )

        rival_records = rival_table_generator.do(today, buckets)
        self.write_result(rival_records)

    def write_result(self, rival_records):
        """ Write the records to today's table and refresh the fixed-path copy.

        Inside one YT transaction the table ``<output_prefix>/<YYYY-MM-DD>`` is
        removed if present, recreated with the schema below and filled with
        ``rival_records``; then ``<output_prefix>/<fixed_path>`` is replaced by a
        copy of it.
        """
        import yt.wrapper as yt

        logging.info('Writing result')
        today = datetime.now()
        output_table = yt.ypath_join(self.Parameters.output_prefix, today.strftime('%Y-%m-%d'))
        logging.info('Output table: %s', output_table)

        with self.yt_client.Transaction():
            if self.yt_client.exists(output_table):
                self.yt_client.remove(output_table)

            self.yt_client.create(
                'table',
                output_table,
                attributes={
                    'optimize_for': 'scan',
                    'schema': [
                        {'type': 'string', 'name': 'rival'},
                        {'type': 'string', 'name': 'fromId'},
                        {'type': 'string', 'name': 'toId'},
                        {'type': 'string', 'name': 'forward_date'},
                        {'type': 'string', 'name': 'backward_date'},
                        {'type': 'string', 'name': 'url'},
                        {'type': 'string', 'name': 'type'},

                        {'type': 'string', 'name': 'from_name'},
                        {'type': 'string', 'name': 'to_name'},
                        {'type': 'string', 'name': 'class'},
                        {'type': 'uint8', 'name': 'adult_seats'},
                        {'type': 'uint8', 'name': 'children_seats'},
                        {'type': 'uint8', 'name': 'infant_seats'},
                    ],
                },
                recursive=True,
            )

            self.yt_client.write_table(
                output_table,
                rival_records,
            )

            fixed_path = yt.ypath_join(self.Parameters.output_prefix, self.Parameters.fixed_path)
            logging.info('Copy table to %s', fixed_path)
            if self.yt_client.exists(fixed_path):
                self.yt_client.remove(fixed_path)

            self.yt_client.copy(output_table, fixed_path)
