import logging
import os
import random
from collections import defaultdict
from datetime import datetime, timedelta

from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment

from sandbox.projects.avia.base import AviaBaseTask

# Seconds per hour; used to turn a timedelta into a whole number of hours.
SECONDS_IN_HOUR = 60 * 60
# Logfeller directory holding the 30-minute tables of avia-json-redir-log.
LOG_ROOT_DIR = '//home/logfeller/logs/avia-json-redir-log/30min'


# Partner codes whose log records are never turned into tasks.
PARTNER_BLACKLIST = {
    'charterok',
    'davs',
    'dohop',
    'justtravel',
    'nabortu',
    'pilotua',
    'tez_tour',
}


def safe_sample(objects, size):
    """Return ``size`` randomly chosen elements of ``objects``.

    When ``size`` is larger than ``len(objects)`` the original ``objects``
    sequence itself is returned (not a copy, not shuffled). Otherwise the
    result is a new list produced by ``random.sample``.
    """
    enough_items = size <= len(objects)
    if enough_items:
        return random.sample(objects, size)
    return objects


def get_hours_from_td(td):
    """Return the number of whole hours in timedelta ``td`` (floored) as an int."""
    whole_hours, _remainder = divmod(td.total_seconds(), SECONDS_IN_HOUR)
    return int(whole_hours)


class AviaReference(object):
    """In-memory lookup dictionaries loaded from the rasp reference tables on YT.

    Attributes:
        settlements: dict mapping the settlement table's ``id`` column to its
            ``title`` column.
        stations: dict mapping the station table's ``id`` column to its
            ``title`` column.
        partners: dict mapping the partner table's ``billing_client_id``
            column to its ``code`` column.
    """

    def __init__(self, yt_client):
        import yt.wrapper as yt

        def load(path, key_column, value_column):
            # A fresh JsonFormat per table read, as each lookup is independent.
            return self._build_dict_from_yt_table(
                yt_client, path, key_column, value_column,
                format=yt.JsonFormat(),
            )

        self.settlements = load('//home/rasp/reference/settlement', 'id', 'title')
        self.stations = load('//home/rasp/reference/station', 'id', 'title')
        self.partners = load('//home/rasp/reference/partner', 'billing_client_id', 'code')

    @staticmethod
    def _build_dict_from_yt_table(yt_client, path, key_column, value_column, format):
        """Read table ``path`` and map each row's ``key_column`` to its ``value_column``.

        If several rows share a key, the row read last wins.
        """
        table_rows = yt_client.read_table(path, format=format)
        return dict((row[key_column], row[value_column]) for row in table_rows)


class GenerateAviaTolokaReviseTasks(AviaBaseTask):
    """
    Generate deeplink checking toloka tasks.

    Reads the log tables under ``LOG_ROOT_DIR`` and keeps records that:
    - were fetched no more than ``MIN_TIME`` before the task started,
    - match the configured national version and currency,
    - belong to a known, non-blacklisted partner.

    It then samples up to ``Parameters.size`` records per partner and writes
    them to a per-date table under ``Parameters.mr_dir``.

    See https://st.yandex-team.ru/RASPTICKETS-13036
    """

    class Requirements(sdk2.Task.Requirements):
        cores = 1
        ram = 8192

        class Caches(sdk2.Requirements.Caches):
            pass  # We do not need caches

        environments = (
            PipEnvironment('yandex-yt', version='0.10.8'),
            PipEnvironment('pytz'),
        )

    class Parameters(sdk2.Task.Parameters):

        with sdk2.parameters.Group('Task generation settings'):
            size = sdk2.parameters.Integer('Number of tasks for each partner', default=30, required=True)
            nv = sdk2.parameters.String('Nation version', default='ru', required=True)
            currency = sdk2.parameters.String('Currency', default='RUR')

        with sdk2.parameters.Group('Map reduce settings') as mr_block:
            mr_cluster = sdk2.parameters.String('MapReduce cluster', default='hahn', required=True)
            mr_user = sdk2.parameters.String('MapReduce user', required=True)
            mr_dir = sdk2.parameters.String('Directory', required=True, default='//home/avia/reports/avia-revise-deeplinks')

    # Cached YT client; created lazily by _get_yt_client().
    _yt_client = None

    # Offsets from a record's fetch time. Each one produces a
    # ``check_<N>_hours`` timestamp column in the output rows.
    TIMES = [timedelta(hours=x) for x in [4, 8, 12, 24]]
    # Records fetched more than MIN_TIME before the task started are skipped.
    MIN_TIME = min(TIMES)

    def get_output_table(self, date):
        """Return the YT path of the output table for ``date``: ``<mr_dir>/YYYY-MM-DD``."""
        return os.path.join(self.Parameters.mr_dir, date.strftime('%Y-%m-%d'))

    def _get_point_from_record(self, record, direction):
        """Return ``<direction>_airport_id`` from ``record``.

        Falls back to ``<direction>_settlement_id`` when the airport id is
        missing or falsy. Not called from within this class.
        """
        point_column = '{}_airport_id'.format(direction)
        settlement_column = '{}_settlement_id'.format(direction)

        return record.get(point_column) or record.get(settlement_column)

    def get_records(self):
        """Collect recent, matching log records grouped by partner code.

        Requires ``self.reference`` (an ``AviaReference``), which is set in
        ``on_execute``.

        Returns:
            defaultdict(list) mapping partner code -> list of output rows
            built by ``_build_task_record``.
        """
        import pytz
        import yt.wrapper as yt
        yt_client = self._get_yt_client()

        records_by_partner = defaultdict(list)
        start_time = pytz.UTC.localize(datetime.utcnow())
        logging.info('Now (in UTC): %s', start_time.strftime('%Y-%m-%d %H:%M:%S'))

        moscow_tz = pytz.timezone('Europe/Moscow')

        # Loop invariants: read them once instead of once per log record.
        min_time = self.MIN_TIME
        nv = self.Parameters.nv
        currency = self.Parameters.currency
        partners = self.reference.partners

        # NOTE(review): every table under LOG_ROOT_DIR is read in full even
        # though only records from the last MIN_TIME are kept.
        for table in yt_client.search(LOG_ROOT_DIR, node_type='table'):
            # NOTE(review): if search() yields absolute paths, os.path.join
            # returns ``table`` unchanged. Confirm which form it yields.
            table_path = yt.TablePath(
                os.path.join(LOG_ROOT_DIR, table),
            )
            logging.info('Read %s', str(table_path))
            for record in yt_client.read_table(table_path, format=yt.JsonFormat()):
                fetch_time = pytz.UTC.localize(
                    datetime.utcfromtimestamp(record['unixtime'])
                ).astimezone(moscow_tz)
                # Both datetimes are tz-aware, so the difference is exact.
                # Records stamped after start_time are kept.
                if start_time - fetch_time > min_time:
                    continue

                if record['national_version'] != nv or record['offer_currency'] != currency:
                    continue

                partner_id = partners.get(record['billing_client_id'])
                if partner_id is None:
                    # Use %s, not %d: a non-integer id (e.g. None or a string)
                    # would make %d fail inside logging and lose the message.
                    logging.warning('Partner not found: %s', record['billing_client_id'])
                    continue

                if partner_id in PARTNER_BLACKLIST:
                    continue

                records_by_partner[partner_id].append(
                    self._build_task_record(record, partner_id, fetch_time)
                )

        return records_by_partner

    def _build_task_record(self, record, partner_id, fetch_time):
        """Build one output row from a log ``record``.

        ``fetch_time`` is the record's fetch moment already converted to
        Moscow time. It is written as a naive local timestamp. Each
        ``check_<N>_hours`` column holds ``fetch_time`` plus one offset from
        ``TIMES``.
        """
        new_record = {
            'partner': partner_id,
            'fetch_time': fetch_time.strftime('%Y-%m-%d %H:%M:%S'),
            'point_from': self._get_point_title(record['fromId']),
            'point_to': self._get_point_title(record['toId']),
            'departure_date': record['when'],
            # Output column is named arrival_date but holds the log's return_date.
            'arrival_date': record['return_date'],
            'adults': record['adult_seats'],
            'children': record['children_seats'],
            'infants': record['infant_seats'],
            'price': record['original_price'],
            'url': record['url'],
        }

        new_record.update({
            'check_{}_hours'.format(get_hours_from_td(delta)): (fetch_time + delta).strftime('%Y-%m-%dT%H:%M:%S')
            for delta in self.TIMES
        })

        return new_record

    def _get_point_title(self, point_key):
        """Return the reference title for a point key.

        Keys starting with 'c' are looked up in ``reference.settlements``.
        Keys starting with 's' are looked up in ``reference.stations``.

        Raises:
            KeyError: the key is absent from the matching reference dict.
            ValueError: the key is empty/None or has any other prefix.
        """
        # Slicing an empty or None key yields '' instead of raising
        # IndexError/TypeError, so it reaches the ValueError below.
        kind = (point_key or '')[:1]

        if kind == 'c':
            return self.reference.settlements[point_key]

        if kind == 's':
            return self.reference.stations[point_key]

        raise ValueError('Unknown point type: {}'.format(point_key))

    def _get_yt_client(self):
        """Return the cached YT client, creating it on first use.

        The token is read from the Vault entry ``YT_TOKEN`` owned by
        ``Parameters.mr_user``. The proxy is ``Parameters.mr_cluster``.
        """
        if self._yt_client is None:
            import yt.wrapper as yt
            self._yt_client = yt.YtClient(config={
                'token': sdk2.Vault.data(self.Parameters.mr_user, 'YT_TOKEN'),
                'proxy': {'url': self.Parameters.mr_cluster},
            })

        return self._yt_client

    def on_execute(self):
        """Task entry point: load the reference, sample records per partner, write them to YT."""
        logging.info('Start')
        logging.info('Build reference')
        self.reference = AviaReference(self._get_yt_client())
        records_by_partner = self.get_records()

        records_to_write = []
        # values() works on both Python 2 and 3 (itervalues is Python 2 only).
        for records in records_by_partner.values():
            records_to_write.extend(safe_sample(records, self.Parameters.size))

        logging.info('Write %d records to YT', len(records_to_write))
        self._write_records(records_to_write)

    def _write_records(self, records):
        """Write ``records`` to today's output table, creating it (and parents) if needed."""
        import yt.wrapper as yt
        # NOTE(review): datetime.now() is the host's local time, while
        # get_records works in UTC/Moscow time. Confirm which date the output
        # table should be named after.
        output_table = self.get_output_table(datetime.now())
        yt_client = self._get_yt_client()

        if not yt_client.exists(output_table):
            self._create_output_table(output_table)

        # NOTE(review): confirm whether a second run on the same date is meant
        # to overwrite or append to the existing table.
        yt_client.write_table(output_table, records, format=yt.JsonFormat())

    def _create_output_table(self, path):
        """Create a YT table at ``path``, creating missing parent nodes (recursive=True)."""
        self._get_yt_client().create(
            'table',
            path,
            recursive=True,
        )
