import logging
from datetime import datetime, timedelta

from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment

from sandbox.projects.avia.base import AviaBaseTask
from sandbox.projects.avia.lib import (
    statface_helpers as sth,
)


def post_df_to_stat(name, df, token, scale='d', beta=False, _append_mode=1):
    """Upload a DataFrame to the Statface report *name*.

    The frame is flattened into a list of per-row dicts and handed, together
    with every other argument, to ``statface_helpers.post_data_to_stat``.
    Whatever that helper returns is returned unchanged.
    """
    records = df.to_dict(orient='records')
    options = dict(
        token=token,
        scale=scale,
        beta=beta,
        _append_mode=_append_mode,
    )
    return sth.post_data_to_stat(name, records, **options)


# YQL (syntax_version=1) query template producing per-partner, per-day rows.
#
# It is rendered with str.format(left_date=..., right_date=...) in
# AviaProfitReport.on_execute, so literal braces in the YQL must be doubled
# ({{ and }}) to survive formatting.
#
# Pipeline:
#   * $redir_stat: reads the daily tables under avia-redir-balance-by-day-log
#     via RANGE(left_date, right_date), keeps rows with FILTER == 0 and a "ru"
#     national version, and groups by route (from_id, to_id), billing order,
#     one-way flag, month (chars 5..6 of `WHEN` -- presumably the month,
#     assuming `WHEN` is 'YYYY-MM-DD') and fielddate (first 10 chars of
#     ISO_EVENTTIME). It counts redirects and passengers and sums RUR offer
#     prices.
#   * $almost_done: LEFT JOINs the fixed 2019-08-25 pricelist snapshot on
#     (from_id, to_id, is_one_way, month), then groups by (fielddate,
#     billing_order_id).
#       - new_revenue charges price_list.price * $vat per passenger, falling
#         back to $default_price per redirect when that product is NULL
#         (e.g. no matching pricelist row).
#       - old_revenue charges $old_price per redirect, and profit is
#         new_revenue / old_revenue.
#       - mean_price is the summed RUR offer price over RUR redirects;
#         mean_passangers is passengers per redirect.
#   * The final SELECT attaches the partner code via billing_order_id.
#
# NOTE(review): $default_price and $old_price are the same expression
# (23.0 / $old_vat * $vat) -- confirm this is intended.
# NOTE(review): mean_price divides by rur_redirects, which is 0 for groups
# without RUR offers -- check what value YQL returns in that case.
# The output column names (including the "n_passangers"/"mean_passangers"
# spelling) are read by aggregate_results and on_execute; keep them in sync.
QUERY_TEMPLATE = '''
PRAGMA SimpleColumns;
PRAGMA yt.DefaultMemoryLimit = '2G';

USE hahn;

$pass_in_row = ($adult_seats, $children_seats) -> {{
    return CAST($adult_seats AS Int64) + CAST($children_seats AS Int64);
}};

$redir_stat = (
    SELECT
        from_id,
        to_id,
        billing_order_id,
        is_one_way,
        fielddate,
        month,
        SUM($pass_in_row(ADULT_SEATS, CHILDREN_SEATS)) AS n_passangers,
        SUM_IF(CAST(OFFER_PRICE AS Double), OFFER_CURRENCY= "RUR") AS price,
        COUNT_IF(OFFER_CURRENCY="RUR") AS rur_redirects,
        COUNT(*) AS n_redirects
    FROM RANGE(
        `//home/avia/logs/avia-redir-balance-by-day-log`,
        "{left_date}",
        "{right_date}"
    )
    WHERE
        CAST(FILTER AS Int64) == 0 AND
        (NATIONAL_VERSION == "ru" OR NATIONAL == "ru")
    GROUP BY
        FROMID AS from_id,
        TOID AS to_id,
        CAST(BILLING_ORDER_ID AS Int64) AS billing_order_id,
        RETURN_DATE == "" AS is_one_way,
        CAST(SUBSTRING(`WHEN`, 5, 2) AS Int64) AS month,
        SUBSTRING(ISO_EVENTTIME, 0, 10) AS fielddate
);

$old_vat = 1.18;
$vat = 1.20;
$default_price = 23.0 / $old_vat * $vat;
$old_price = 23.0 / $old_vat * $vat;


$almost_done = (
    SELECT
        CAST(SUM(COALESCE(
            price_list.price * $vat * redir_stat.n_passangers,
            $default_price * redir_stat.n_redirects
        )) AS Float) AS new_revenue,

        $old_price * SUM(redir_stat.n_redirects) AS old_revenue,

        SUM(redir_stat.n_redirects) AS n_redirects,

        CAST(SUM(COALESCE(
            price_list.price * $vat * redir_stat.n_passangers,
            $default_price * redir_stat.n_redirects
        )) AS Float) / ($old_price * SUM(redir_stat.n_redirects)) AS profit,

        CAST(SUM(redir_stat.price) AS Float) / SUM(redir_stat.rur_redirects) AS mean_price,
        CAST(SUM(redir_stat.n_passangers) AS Float) / SUM(redir_stat.n_redirects) AS mean_passangers,
        fielddate,
        billing_order_id
    FROM $redir_stat AS redir_stat
    LEFT JOIN hahn.`//home/avia/order_pricelist/2019-08-25` AS price_list
    USING (from_id, to_id, is_one_way, month)
    GROUP BY
        redir_stat.fielddate AS fielddate,
        redir_stat.billing_order_id AS billing_order_id
);

SELECT
    partners.code AS partner,
    report.fielddate AS fielddate,
    report.new_revenue AS new_revenue,
    report.old_revenue AS old_revenue,
    report.n_redirects AS n_redirects,
    report.profit AS profit,
    report.mean_price AS mean_price,
    report.mean_passangers AS mean_passangers
FROM $almost_done AS report
LEFT JOIN `//home/rasp/reference/partner` AS partners
USING (billing_order_id)
'''


def weighted_mean(value, weight):
    """Return ``sum(value * weight) / sum(weight)``.

    *value* and *weight* can be anything that supports element-wise ``*`` and
    a ``.sum()`` method, such as numpy arrays or pandas Series of equal length.
    """
    weighted_total = (value * weight).sum()
    total_weight = weight.sum()
    return weighted_total / total_weight


def aggregate_results(results, key='fielddate'):
    """Aggregate per-partner report rows into one row per *key* value.

    Args:
        results: DataFrame with the columns produced by the report query
            (``partner``, ``old_revenue``, ``new_revenue``, ``n_redirects``,
            ``mean_price``, ``mean_passangers``) plus the *key* column(s).
        key: column name (or list of names) to group by.

    Returns:
        A DataFrame with *key* as regular column(s), followed by:
          * ``old_revenue``, ``new_revenue``, ``n_redirects``: plain sums;
          * the same three with a ``_without_dohop`` suffix: sums over rows
            whose ``partner`` is not ``'dohop'``. These are 0 for groups that
            contain only dohop rows;
          * ``profit`` / ``profit_without_dohop``: new_revenue / old_revenue;
          * ``mean_price`` / ``mean_passangers``: means weighted by
            ``n_redirects``.
    """
    sum_columns = ['old_revenue', 'new_revenue', 'n_redirects']
    mean_columns = ['mean_price', 'mean_passangers']

    grouped_results = results.groupby(key)[sum_columns].sum()

    without_dohop = results[results['partner'] != 'dohop'].groupby(key)[sum_columns].sum()
    # A group made only of dohop rows is absent from without_dohop. Without the
    # reindex, index alignment would turn its sums into NaN instead of 0.
    without_dohop = without_dohop.reindex(grouped_results.index, fill_value=0)
    for column in sum_columns:
        grouped_results[column + '_without_dohop'] = without_dohop[column]

    grouped_results['profit'] = grouped_results['new_revenue'] / grouped_results['old_revenue']
    grouped_results['profit_without_dohop'] = (
        grouped_results['new_revenue_without_dohop'] / grouped_results['old_revenue_without_dohop']
    )

    # Weighted mean per group: sum(value * n_redirects) / sum(n_redirects).
    # The denominator is the n_redirects sum already computed above.
    weighted_sums = results.assign(**{
        column: results[column] * results['n_redirects'] for column in mean_columns
    }).groupby(key)[mean_columns].sum()
    for column in mean_columns:
        grouped_results[column] = weighted_sums[column] / grouped_results['n_redirects']

    return grouped_results.reset_index()


class AviaProfitReport(AviaBaseTask):
    """Report for Avia pricing policy (2018 year).

    Runs ``QUERY_TEMPLATE`` through YQL for a date range and uploads two
    Statface reports:
      * the per-partner/per-day rows, to ``report_name``;
      * their per-day aggregate from ``aggregate_results``, to ``agg_name``.
    """
    # Lazily created clients; see _get_yt_client / _get_yql_client.
    _yt_client = None
    _yql_client = None

    class Requirements(sdk2.Task.Requirements):
        cores = 1
        ram = 1024

        class Caches(sdk2.Requirements.Caches):
            pass  # We do not need caches

        environments = (
            PipEnvironment('requests'),
            PipEnvironment('yandex-yt', version='0.10.8'),
            PipEnvironment('yandex-yt-yson-bindings-skynet', version='0.3.32-0'),
            PipEnvironment('yql'),
        )

    class Parameters(sdk2.Task.Parameters):

        with sdk2.parameters.Group('Vault settings') as vault_block:
            vault_owner = sdk2.parameters.String('Vaults owner', required=True)

        with sdk2.parameters.Group('Map-Reduce settings') as mr_block:
            yt_token_vault = sdk2.parameters.String('Token vault name', required=True)

        with sdk2.parameters.Group('YQL settings') as yql_block:
            yql_token_vault = sdk2.parameters.String('YQL token vault name', required=True)

        with sdk2.parameters.Group('Statface settings') as stat_block:
            stat_vault = sdk2.parameters.String('Statface token vault name', required=True)
            report_name = sdk2.parameters.String('Statface report name', required=True)
            agg_name = sdk2.parameters.String('Staface aggregated report name', required=True)
            use_beta = sdk2.parameters.Bool('Use beta version', required=False, default=False)

        # Both dates are 'YYYY-MM-DD' strings (see _parse_date); empty means default.
        with sdk2.parameters.Group('Date settings') as date_block:
            left_date = sdk2.parameters.String('Start date (deafult day before yesterday)', required=False)
            right_date = sdk2.parameters.String('End date (deafult yesterday)', required=False)

    def _get_yt_client(self):
        """Return a cached YT client for the 'hahn' proxy, creating it on first use."""
        if self._yt_client is None:
            # Imported lazily: the package comes from Requirements.environments
            # (presumably installed only when the task runs).
            import yt.wrapper as yt
            self._yt_client = yt.YtClient(
                proxy='hahn',
                token=self._get_token('yt'),
            )

        return self._yt_client

    def _get_yql_client(self):
        """Return a cached YQL client, creating it on first use."""
        if self._yql_client is None:
            from yql.api.v1.client import YqlClient
            self._yql_client = YqlClient(
                token=self._get_token('yql'),
            )

        return self._yql_client

    def on_execute(self):
        """Build the report for [left_date, right_date] and upload it to Statface.

        By default the range starts the day before yesterday and ends yesterday.
        If the YQL query fails, the errors are logged and the method returns
        without uploading anything.
        """
        logging.info('Start')
        # Use a single "now" so both defaults are derived from the same moment.
        today = datetime.today()
        yesterday = today - timedelta(days=1)
        day_before_yesterday = today - timedelta(days=2)

        left_date = self._parse_date(self.Parameters.left_date, day_before_yesterday).date()
        logging.info('Left date: %s', left_date.strftime('%Y-%m-%d'))
        right_date = self._parse_date(self.Parameters.right_date, yesterday).date()
        logging.info('Right date: %s', right_date.strftime('%Y-%m-%d'))

        query = QUERY_TEMPLATE.format(
            left_date=left_date.strftime('%Y-%m-%d'),
            right_date=right_date.strftime('%Y-%m-%d'),
        )

        request = self._get_yql_client().query(query, syntax_version=1)
        request.run()
        # The return value was never used (the data is read from full_dataframe
        # below). The call is kept because it presumably waits for the operation
        # to complete -- confirm against the yql client before removing it.
        request.get_results()

        if not request.is_success:
            for error in request.errors:
                logging.error(' - %s', str(error))
            # logging.error rather than logging.exception: there is no active
            # exception here, so exception() would only append a bogus traceback.
            logging.error('ERROR: %r', request.status)
            # NOTE(review): nothing is raised, so the failure is visible only in
            # the logs -- consider failing the task instead.
            return 0

        results = request.full_dataframe.fillna(0)
        logging.info('Len: %d', len(results))
        for column in ['old_revenue', 'new_revenue', 'mean_price', 'mean_passangers', 'n_redirects']:
            # Replace the whole column. `results.loc[:, column] = ...` may write
            # in place and keep the old dtype on newer pandas versions.
            results[column] = results[column].astype('float')

        grouped_results = aggregate_results(results)

        # _append_mode is 0 for beta reports and 1 otherwise (meaning defined
        # by statface_helpers).
        append_mode = 0 if self.Parameters.use_beta else 1
        for report_name, frame in (
            (self.Parameters.report_name, results),
            (self.Parameters.agg_name, grouped_results),
        ):
            response = post_df_to_stat(
                report_name,
                frame,
                self._get_token('statface'),
                beta=self.Parameters.use_beta,
                _append_mode=append_mode,
            )

            if response.status_code != 200:
                logging.error('Stat code: %d, content: %r', response.status_code, response.content)

    def _parse_date(self, s, default):
        """Parse *s* as 'YYYY-MM-DD' into a datetime, or return *default* if *s* is empty.

        Raises:
            ValueError: if *s* is non-empty and does not match the format.
        """
        return datetime.strptime(s, '%Y-%m-%d') if s else default

    def _get_token(self, system):
        """Read the token for *system* ('statface', 'yt' or 'yql') from Sandbox Vault.

        The vault entry name comes from the matching task parameter, and the
        owner comes from ``vault_owner``.

        Raises:
            ValueError: if *system* is not one of the known names.
        """
        vault_parameters = {
            'statface': 'stat_vault',
            'yt': 'yt_token_vault',
            'yql': 'yql_token_vault',
        }
        parameter = vault_parameters.get(system)
        if parameter is None:
            raise ValueError('Unknown system: {}'.format(system))

        name = getattr(self.Parameters, parameter)
        return sdk2.Vault.data(self.Parameters.vault_owner, name)
