import logging
import math
import operator
from collections import defaultdict
from datetime import datetime, timedelta

from sandbox import sdk2
from sandbox.projects.avia.lib.logs import configure_logging, get_sentry_dsn
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

from sandbox.projects.avia.base import AviaBaseTask
from sandbox.projects.avia.lib import (
    yql_helpers as yqlh,
    statface_helpers as sth,
)


# Fallback date used when a task date parameter is left empty.
# NOTE: evaluated once, at module import time, from the local clock (it keeps
# the time of day; callers take `.date()` of it).
DAY_BEFORE_YESTERDAY = datetime.today() - timedelta(days=2)


# Service name -> numeric service id.
# The YQL templates in this module carry their own copy of this mapping
# ($SERVICE_TO_ID); keep them in sync when adding or renumbering services.
SERVICE_TO_ID = {
    'ticket': 1,
    'm_avia': 2,
    'api_avia': 3,
    'sovetnik': 4,
    '42': 5,
    'tours': 6,
    'welcome2018': 7,
    'rasp_api_public': 8,
    'rasp_mobile': 9,
    'rasp_morda': 10,
    'rasp_morda_backend': 11,
    'rasp_touch': 12,
    'wizard': 13,
    'email_subscriptions': 14,
    'yeah': 15,
    'avia-travel': 18,
    'mavia-travel': 19,
    'wizard_user': 998,
    'rasp': 999,
}


# Reverse mapping: numeric service id -> service name.
# `.items()` (rather than `.iteritems()`) works on both Python 2 and 3.
SERVICE_ID_TO_NAME = {
    service_id: name
    for name, service_id in SERVICE_TO_ID.items()
}


def sort_by_timestamp(lst):
    """Return a new list holding the items of *lst* ordered by ``.timestamp``.

    *lst* may be any iterable (a generator is fine) and is left untouched.
    The sort is stable: items with equal timestamps keep their input order.
    """
    ordered = list(lst)
    ordered.sort(key=lambda item: item.timestamp)
    return ordered


class Query(object):
    """A search or partner query reduced to the fields used for matching.

    ``key()`` (route, passenger counts and dates) identifies which queries
    refer to the same search. ``source`` is the row's service id and
    ``timestamp`` the query time, or ``None`` when it is unknown.
    """

    __slots__ = [
        'from_id', 'to_id',
        'adults', 'children', 'infants',
        'forward_date', 'backward_date',
        'source', 'timestamp',
    ]

    def __init__(
        self, from_id, to_id, adults, children, infants, forward_date, backward_date, source, timestamp=None,
    ):
        self.from_id = from_id
        self.to_id = to_id
        self.adults = adults
        self.children = children
        self.infants = infants
        self.forward_date = forward_date
        self.backward_date = backward_date
        self.source = source
        self.timestamp = timestamp

    @classmethod
    def from_merged(cls, record):
        """Build a Query from a table row (a mapping).

        ``from_id`` and ``to_id`` are required (``KeyError`` otherwise); any
        other missing field becomes ``None``. ``forward_date`` and
        ``backward_date`` are parsed from ``'%Y-%m-%d'`` strings into
        ``datetime``. ``unixtime`` (seconds since the epoch) becomes a naive
        UTC ``datetime``. ``service`` is stored as ``source``.
        """
        return cls(
            record['from_id'],
            record['to_id'],
            record.get('adults'),
            record.get('children'),
            record.get('infants'),
            cls._parse_date(record.get('forward_date')),
            cls._parse_date(record.get('backward_date')),
            record.get('service'),
            cls._parse_unixtime(record.get('unixtime')),
        )

    def key(self):
        """Return the tuple identifying the search this query belongs to."""
        return (
            self.from_id,
            self.to_id,
            self.adults,
            self.children,
            self.infants,
            self.forward_date,
            self.backward_date,
        )

    def __repr__(self):
        # Key parts are joined with '_'; the closing '>' was missing before.
        return '<Query key={}, unixtime={}>'.format(
            '_'.join(map(str, self.key())),
            self.timestamp,
        )

    def __str__(self):
        return repr(self)

    @staticmethod
    def _parse_unixtime(unixtime):
        """Convert epoch seconds to a naive UTC datetime; ``None`` passes through."""
        if unixtime is None:
            return None

        return datetime.utcfromtimestamp(unixtime)

    @staticmethod
    def _parse_date(date):
        """Parse a ``'%Y-%m-%d'`` string; empty values yield ``None``."""
        return datetime.strptime(date, '%Y-%m-%d') if date else None


# YQL (syntax v1) query that collects user searches and keeps only those whose
# key matches a row of the heater table.
# Sources, each read with RANGE(..., '{left_date}', '{right_date}'):
#   * logs/avia-users-search-log/1d;
#   * logs/avia-wizard-query-log/1d INNER JOINed with
#     logs/avia-wizard-point-parse-log/1d on job_id. The WHERE clause keeps only
#     rows with a departure_date, so the 'wizard_user' / NULL branches of the
#     IF(...) expressions in that SELECT are never taken;
#   * logs/rasp-users-search-log/1d, restricted to service 'rasp', transport type
#     'plane' and a defined `when`. The string 'null' in passenger fields becomes
#     the given default, and in return_date becomes NULL.
# The union is LEFT SEMI JOINed with `{heater_table}` on (from_id, to_id, adults,
# children, infants, forward_date, backward_date). The result overwrites
# (WITH TRUNCATE) `{output_table}`.
# The placeholders are filled with str.format, so literal YQL braces are doubled.
# $SERVICE_TO_ID mirrors the Python SERVICE_TO_ID mapping; keep them in sync.
SEARCHES_QUERY_TEMPLATE = '''
USE hahn;

$SERVICE_TO_ID = ToDict(AsList(
    AsTuple('ticket', 1),
    AsTuple('m_avia', 2),
    AsTuple('api_avia', 3),
    AsTuple('sovetnik', 4),
    AsTuple('42', 5),
    AsTuple('tours', 6),
    AsTuple('welcome2018', 7),
    AsTuple('rasp_api_public', 8),
    AsTuple('rasp_mobile', 9),
    AsTuple('rasp_morda', 10),
    AsTuple('rasp_morda_backend', 11),
    AsTuple('rasp_touch', 12),
    AsTuple('wizard', 13),
    AsTuple('email_subscriptions', 14),
    AsTuple('yeah', 15),
    AsTuple('avia-travel', 18),
    AsTuple('mavia-travel', 19),
    AsTuple('wizard_user', 998),
    AsTuple('rasp', 999)
));

$ParseRaspPassenger = ($pass, $default) -> {{
    RETURN IF(
        $pass == 'null',
        $default,
        CAST($pass AS Int64)
    )
}};

$all_searches = (
SELECT
    fromId AS from_id,
    toId AS to_id,
    adult_seats AS adults,
    children_seats AS children,
    infant_seats AS infants,
    `when` AS forward_date,
    `return_date` AS backward_date,
    $SERVICE_TO_ID[service] AS service,
    unixtime AS unixtime
FROM RANGE(
    `logs/avia-users-search-log/1d`,
    '{left_date}',
    '{right_date}'
)
UNION ALL

SELECT
    parsing.from_point_key AS from_id,
    parsing.to_point_key AS to_id,
    1 AS adults,
    0 AS children,
    0 AS infants,
    query.departure_date AS forward_date,
    IF(
        query.departure_date IS NULL,
        NULL,
        query.return_date
    ) AS backward_date,
    $SERVICE_TO_ID[IF(query.departure_date IS NULL, 'wizard_user', 'wizard')] AS service,
    query.unixtime AS unixtime
FROM RANGE(
    `logs/avia-wizard-query-log/1d`,
    '{left_date}',
    '{right_date}'
) AS query
INNER JOIN RANGE(
    `logs/avia-wizard-point-parse-log/1d`,
    '{left_date}',
    '{right_date}'
) AS parsing
USING (job_id)
WHERE query.departure_date IS NOT NULL
UNION ALL

SELECT
    from_id,
    to_id,
    $ParseRaspPassenger(adults, 1) AS adults,
    $ParseRaspPassenger(children, 0) AS children,
    $ParseRaspPassenger(infants, 0) AS infants,
    `when` AS forward_date,
    IF(return_date == 'null', NULL, return_date) AS backward_date,
    $SERVICE_TO_ID[service] AS service,
    CAST(unixtime AS UInt64) AS unixtime
FROM RANGE(
    `logs/rasp-users-search-log/1d`,
    '{left_date}',
    '{right_date}'
)
WHERE service == 'rasp' AND transport_type = 'plane' AND `when` != 'undefined'
);

INSERT INTO
    `{output_table}`
WITH TRUNCATE

SELECT
    from_id,
    to_id,
    adults,
    children,
    infants,
    forward_date,
    backward_date,
    service,
    unixtime
FROM $all_searches AS searches
LEFT SEMI JOIN `{heater_table}` AS heater_queries
USING (from_id, to_id, adults, children, infants, forward_date, backward_date)
'''


# YQL (syntax v1) query that selects partner queries sharing a key with a heater
# query.
# Processing steps:
#   * Rows of //logs/avia-variants-log/1d between '{start_date}' and '{end_date}'
#     are collapsed to one row per query_id, taking SOME() of each column.
#   * Point ids become point keys: 's' + airport id, or 'c' + settlement id when
#     the airport id is NULL.
#   * forward_date / backward_date, taken as epoch seconds, become 'YYYY-MM-DD'
#     strings.
#   * Every collapsed query whose (from_id, to_id, adults, children, infants,
#     forward_date, backward_date) also occurs with
#     service = $SERVICE_TO_ID['yeah'] is kept. The heater queries themselves
#     are included.
#   * The result overwrites (WITH TRUNCATE) `{output_table}`.
# The placeholders are filled with str.format, so literal YQL braces are doubled.
# NOTE(review): SOME() picks each column independently. This assumes all rows of
# one query_id carry the same values; confirm.
# NOTE: this copy of $SERVICE_TO_ID has no 'wizard_user' entry. Only 'yeah' is
# looked up here.
HEATER_SELECT_QUERY_TEMPLATE = '''
USE hahn;

$SERVICE_TO_ID = ToDict(AsList(
    AsTuple('ticket', 1),
    AsTuple('m_avia', 2),
    AsTuple('api_avia', 3),
    AsTuple('sovetnik', 4),
    AsTuple('42', 5),
    AsTuple('tours', 6),
    AsTuple('welcome2018', 7),
    AsTuple('rasp_api_public', 8),
    AsTuple('rasp_mobile', 9),
    AsTuple('rasp_morda', 10),
    AsTuple('rasp_morda_backend', 11),
    AsTuple('rasp_touch', 12),
    AsTuple('wizard', 13),
    AsTuple('email_subscriptions', 14),
    AsTuple('yeah', 15),
    AsTuple('avia-travel', 18),
    AsTuple('mavia-travel', 19),
    AsTuple('rasp', 999)
));

$GetPointKey = ($airportId, $cityId) -> {{
    RETURN IF(
        $airportId IS NULL,
        "c" || CAST($cityId AS String),
        "s" || CAST($airportId AS String)
    )
}};
$format = DateTime::Format("%Y-%m-%d");
$UnixtimeToDate = ($unixtime) -> {{
    RETURN IF($unixtime IS NULL,
        NULL,
        $format(DateTime::FromSeconds(CAST($unixtime AS Uint32)))
    )
}};

$all_partner_queries = (
SELECT
    $GetPointKey(SOME(from_airport_id), SOME(from_settlement_id)) AS from_id,
    $GetPointKey(SOME(to_airport_id), SOME(to_settlement_id)) AS to_id,
    SOME(adults) AS adults,
    SOME(children) AS children,
    SOME(infants) AS infants,
    $UnixtimeToDate(SOME(forward_date)) AS forward_date,
    $UnixtimeToDate(SOME(backward_date)) AS backward_date,
    SOME(service_id) AS service,
    SOME(unixtime) AS unixtime
FROM RANGE(
    '//logs/avia-variants-log/1d',
    '{start_date}',
    '{end_date}'
)
GROUP BY query_id
);


INSERT INTO
    `{output_table}`
WITH TRUNCATE

SELECT
    from_id,
    to_id,
    adults,
    children,
    infants,
    forward_date,
    backward_date,
    service,
    unixtime
FROM $all_partner_queries AS queries
LEFT SEMI JOIN (
    SELECT
        from_id,
        to_id,
        adults,
        children,
        infants,
        forward_date,
        backward_date
    FROM $all_partner_queries
    WHERE service = $SERVICE_TO_ID['yeah']
) AS heater_queries
USING (from_id, to_id, adults, children, infants, forward_date, backward_date)
'''


class AviaHeaterStat(AviaBaseTask):
    """Sandbox task measuring how often heater partner queries served a user search.

    Steps:
      * Partner queries sharing a key with a heater query (service ``'yeah'``)
        are matched against user searches with the same key.
      * For each configured time delta, it counts per search service the
        searches that were preceded within that delta by a heater query but
        by no other partner query for the same key.
      * The counts are uploaded to Statface.
    """

    _yt_client = None  # created lazily by _get_yt_client
    _yql_client = None  # created lazily by _get_yql_client

    class Requirements(sdk2.Task.Requirements):
        cores = 1
        ram = 8192

        class Caches(sdk2.Requirements.Caches):
            pass  # We do not need caches

        environments = (
            PipEnvironment('requests'),
            PipEnvironment('yandex-yt', version='0.10.8'),
            PipEnvironment('yandex-yt-yson-bindings-skynet', version='0.3.32-0'),
            PipEnvironment('yql'),
        )

    class Parameters(sdk2.Task.Parameters):

        with sdk2.parameters.Group('Vault settings') as vault_block:
            vault_owner = sdk2.parameters.String('Vaults owner', required=True)

        with sdk2.parameters.Group('Map-Reduce settings') as mr_block:
            yt_token_vault = sdk2.parameters.String('Token vault name', required=True)

        with sdk2.parameters.Group('YQL settings') as yql_block:
            yql_token_vault = sdk2.parameters.String('YQL token vault name', required=True)

        with sdk2.parameters.Group('Statface settings') as stat_block:
            stat_vault = sdk2.parameters.String('Statface token vault name', required=True)
            report_name = sdk2.parameters.String('Statface report name', required=True)
            use_beta = sdk2.parameters.Bool('Use beta version', required=False, default=False)

        with sdk2.parameters.Group('Date settings') as date_block:
            left_date = sdk2.parameters.String('Start date (deafult day before yesterday)', required=False)
            right_date = sdk2.parameters.String('End date (deafult day before yesterday)', required=False)

        with sdk2.parameters.Group('Task settings') as t_block:
            times = sdk2.parameters.List('Time deltas (in minutes)', value_type=sdk2.parameters.Integer)

    def _get_yt_client(self):
        """Return a cached YT client for the ``hahn`` cluster."""
        if self._yt_client is None:
            # Imported lazily: the package is installed via Requirements.environments.
            import yt.wrapper as yt
            self._yt_client = yt.YtClient(
                proxy='hahn',
                token=self._get_token('yt'),
            )

        return self._yt_client

    def get_search_right_border(self, right_border):
        """Return *right_border* extended by the largest time delta, rounded up to days.

        The extension presumably lets searches made shortly after
        *right_border* still be matched with partner queries from that day.
        With no time deltas configured, *right_border* is returned unchanged.
        """
        times = self.Parameters.times or [0]
        max_delta = AviaHeaterStat.time_parameter_to_timedelta(max(times))
        days = int(math.ceil(max_delta.total_seconds() / (60 * 60 * 24)))
        return right_border + timedelta(days=days)

    def _get_yql_client(self):
        """Return a cached YQL client."""
        if self._yql_client is None:
            from yql.api.v1.client import YqlClient
            self._yql_client = YqlClient(
                token=self._get_token('yql'),
            )

        return self._yql_client

    def _run_yql_query(self, query):
        """Run *query* (YQL syntax v1) and block until it finishes.

        Raises:
            SandboxTaskFailureError: if the operation did not succeed. Its
                status and errors are logged first.
        """
        request = self._get_yql_client().query(query, syntax_version=1)

        request.run()
        logging.info('YQL Operation: %s', yqlh.get_yql_operation_url(request))
        request.wait_progress()

        if not request.is_success:
            # The placeholder was missing before, so the status was never logged.
            logging.error('Operation failed. Status: %s', request.status)
            for error in request.errors or ():
                logging.error(' - %s', str(error))
            raise SandboxTaskFailureError('YQL query failed')

    def _read_queries(self, table):
        """Read every row of YT *table* as a Query, sorted by timestamp."""
        return sort_by_timestamp(
            Query.from_merged(record)
            for record in self._get_yt_client().read_table(table)
        )

    def extract_queries_to_partner(self, left_date, right_date):
        """Collect partner queries whose key was also queried by the heater.

        Returns:
            A pair ``(queries, table)``:
              * ``queries``: those partner queries, sorted by timestamp;
              * ``table``: the YT temp table holding them, later used as the
                join filter for searches.
        """
        temp_table = self._get_yt_client().create_temp_table()
        self._run_yql_query(HEATER_SELECT_QUERY_TEMPLATE.format(
            start_date=left_date.strftime('%Y-%m-%d'),
            end_date=right_date.strftime('%Y-%m-%d'),
            output_table=temp_table,
        ))
        return self._read_queries(temp_table), temp_table

    @staticmethod
    def time_parameter_to_timedelta(time):
        """Convert a ``times`` parameter value (minutes) to a timedelta."""
        return timedelta(minutes=time)

    def get_useful_queries_to_partners(self, queries_to_partners, searches):
        """Count searches that a heater query could have served.

        Both arguments must be sorted by timestamp. Take one search and a
        configured delta ``time``, and call the window
        ``[search.timestamp - time, search.timestamp]``. The search counts as
        useful for ``search.source`` when both hold for the same key:
          * the latest heater partner query (service ``'yeah'``) made no
            later than the search lies inside the window;
          * no non-heater partner query lies inside the window.

        Returns:
            A mapping ``service id -> time -> count``. Searches with a
            ``None`` service are dropped, with a warning.
        """
        deltas = [
            (time, AviaHeaterStat.time_parameter_to_timedelta(time))
            for time in (self.Parameters.times or [])
        ]
        heater_source = SERVICE_TO_ID['yeah']

        useful_queries = defaultdict(lambda: defaultdict(int))  # Service -> time -> count

        # Latest timestamp per query key among the partner queries swept so far.
        last_heater_queries = {}
        last_non_heater_queries = {}
        partner_query_ind = 0
        n_partner_queries = len(queries_to_partners)

        for search in searches:
            # Consume every partner query made no later than this search.
            # Searches are sorted, so the pointer only moves forward.
            # (The old condition `timestamp >= min_timestamp` consumed either
            # all queries, including future ones, or none at all.)
            while (
                partner_query_ind < n_partner_queries and
                queries_to_partners[partner_query_ind].timestamp <= search.timestamp
            ):
                partner_query = queries_to_partners[partner_query_ind]
                if partner_query.source == heater_source:
                    last_heater_queries[partner_query.key()] = partner_query.timestamp
                else:
                    last_non_heater_queries[partner_query.key()] = partner_query.timestamp
                partner_query_ind += 1

            search_key = search.key()
            last_heater_timestamp = last_heater_queries.get(search_key)
            if last_heater_timestamp is None:
                continue  # No heater query for this key yet.
            last_non_heater_timestamp = last_non_heater_queries.get(search_key)

            for time, delta in deltas:
                min_timestamp = search.timestamp - delta
                if (
                    last_heater_timestamp >= min_timestamp and  # A heater query inside the window
                    (last_non_heater_timestamp is None or last_non_heater_timestamp < min_timestamp)
                    # ...and no other partner query inside it
                ):
                    useful_queries[search.source][time] += 1

        if None in useful_queries:
            logging.warning('Unknown service in report')
            useful_queries.pop(None)

        return useful_queries

    def extract_searches(self, left_date, right_date, heater_table):
        """Collect user searches whose key occurs in *heater_table*, sorted by timestamp."""
        temp_table = self._get_yt_client().create_temp_table()
        self._run_yql_query(SEARCHES_QUERY_TEMPLATE.format(
            left_date=left_date.strftime('%Y-%m-%d'),
            right_date=right_date.strftime('%Y-%m-%d'),
            output_table=temp_table,
            heater_table=heater_table,
        ))
        return self._read_queries(temp_table)

    def on_prepare(self):
        configure_logging(
            sentry_dsn=get_sentry_dsn(self)
        )
        super(AviaHeaterStat, self).on_prepare()

    def on_execute(self):
        """Build the report for the configured dates and upload it to Statface."""
        logging.info('Start')

        left_date = self._parse_date(self.Parameters.left_date).date()
        logging.info('Left date: %s', left_date.strftime('%Y-%m-%d'))
        right_date = self._parse_date(self.Parameters.right_date).date()
        logging.info('Right date: %s', right_date.strftime('%Y-%m-%d'))

        search_right_border = self.get_search_right_border(right_date)
        logging.info('Search right date: %s', search_right_border.strftime('%Y-%m-%d'))

        heater_queries, table = self.extract_queries_to_partner(left_date, right_date)
        searches = self.extract_searches(left_date, search_right_border, table)
        report = self.get_useful_queries_to_partners(heater_queries, searches)

        logging.info('Useful queries: %r', report)
        self.send_report_to_stat(report)
        logging.info('End')

    def send_report_to_stat(self, report):
        """Upload *report* to the configured Statface report (daily scale).

        An empty report is skipped.

        Raises:
            SandboxTaskFailureError: if Statface answers with a status other
                than 200.
        """
        logging.info('Start uploading report')
        if not report:
            logging.info('Nothing to upload: report is empty: %r', report)
            return

        response = sth.post_data_to_stat(
            name=self.Parameters.report_name,
            data=self._prepare_report_to_stat(report),
            token=self._get_token('statface'),
            scale='d',
            beta=self.Parameters.use_beta,
        )

        if response.status_code != 200:
            raise SandboxTaskFailureError(
                'Can not upload to Statface. Status: {}, message: {}'.format(
                    response.status_code,
                    response.content,
                )
            )

        logging.info('Upload succeed')

    def _prepare_report_to_stat(self, report):
        """Flatten ``service id -> time -> count`` into Statface rows."""
        # NOTE(review): fielddate is the run date, not the analysed date range;
        # confirm this is intended.
        now = datetime.now().date().strftime('%Y-%m-%d')
        return [
            {
                'fielddate': now,
                'source': SERVICE_ID_TO_NAME[source],
                'time': time,
                'n_queries': n_queries,
            }
            for source, source_dct in report.items()
            for time, n_queries in source_dct.items()
        ]

    def _parse_date(self, s):
        """Parse a ``'%Y-%m-%d'`` parameter; empty values fall back to DAY_BEFORE_YESTERDAY."""
        return datetime.strptime(s, '%Y-%m-%d') if s else DAY_BEFORE_YESTERDAY

    def _get_token(self, system):
        """Read the vault secret for *system* ('statface', 'yt' or 'yql').

        Raises:
            ValueError: for any other *system*.
        """
        if system == 'statface':
            name = self.Parameters.stat_vault

        elif system == 'yt':
            name = self.Parameters.yt_token_vault

        elif system == 'yql':
            name = self.Parameters.yql_token_vault

        else:
            raise ValueError('Unknown system: {}'.format(system))

        return sdk2.Vault.data(self.Parameters.vault_owner, name)
