# -*- coding: utf-8 -*-
from textwrap import dedent
import calendar
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import yt.wrapper as yt
from collections import namedtuple
from enum import Enum

from datacloud.dev_utils.time.patterns import FMT_DATE_YM
from datacloud.money.price_policy_table import PricePolicyTable
from datacloud.money.score_to_contract_table import ScoreToCotractTable

# Delimiter placed between the sorted identifiers when they are joined into
# a single `id_value` string (see requests_mapper).
SEPARATOR = '#'
# VAT rates expressed as fractions; MoneyReducer.get_vat returns the first one
# for years before 2019 and the second one otherwise.
VAT_BEFORE_2019 = 0.18
VAT_AFTER_2019 = 0.20  # used for 2019 and every later year

# Pair of datetimes bounding the period covered by a contract.
DtRange = namedtuple('DtRange', ['first_dt', 'last_dt'])
# Key under which price policies and date ranges are stored.
PolicyKey = namedtuple('PolicyKey', ['partner_id', 'contract_id'])
# Calendar day split into its numeric components.
DayTuple = namedtuple('DayTuple', ['year', 'month', 'day'])


class AggTypes(Enum):
    """Aggregation granularities understood by get_aggregate_query."""
    daily = 'daily'
    monthly = 'monthly'


def requests_mapper(rec):
    """Map a request row to its `request_id` and a combined `id_value`.

    Every identifier list of the row is sorted on its own, the sorted lists
    are concatenated in a fixed field order and the result is joined with
    SEPARATOR.

    :param rec: request row holding `request_id` and the iterable fields
        `emails`, `email_id_vals`, `phones`, `phone_id_vals`, `cookies`.
    :return: generator yielding exactly one dict with `request_id` and
        `id_value`.
    """
    id_fields = ('emails', 'email_id_vals', 'phones', 'phone_id_vals', 'cookies')
    # Sorting happens per field, so the field order itself is preserved.
    ordered_ids = [value for field in id_fields for value in sorted(rec[field])]

    yield {
        'request_id': rec['request_id'],
        'id_value': SEPARATOR.join(ordered_ids)
    }


# Query template (presumably YQL -- confirm) that inner-joins the requests
# table with a RANGE over the responses path on `request_id` and writes the
# distinct rows to the result table, truncating it first.
# Placeholders to fill with str.format: result_table, requests_table,
# responses_table.
join_req_resp = dedent("""\
    pragma yt.ForceInferSchema;
    INSERT INTO `{result_table}` WITH TRUNCATE
    SELECT DISTINCT
        resps.request_id as request_id,
        reqs.id_value as id_value,
        resps.partner_id as partner_id,
        resps.score_name as score_name,
        resps.ts as ts,
        resps.response_time as response_time,
        resps.status as status,
        resps.has_score as has_score
    FROM `{requests_table}` as reqs
    INNER JOIN RANGE(`{responses_table}`) as resps
    ON reqs.request_id == resps.request_id""")


def get_last_month(rec):
    """Return the last datetime for which the record's contract applies.

    The result is never later than "now minus one day" (local clock).

    :param rec: record with a `last_month` field that is either None or a
        string in FMT_DATE_YM format.
    :return: "now minus one day" when `last_month` is None; otherwise the
        earlier of that moment and the last day (at 00:00) of `last_month`.
    """
    yesterday = datetime.now() + relativedelta(days=-1)
    last_month_raw = rec['last_month']
    if last_month_raw is None:
        return yesterday

    month_start = datetime.strptime(last_month_raw, FMT_DATE_YM)
    # First day of the month + 1 month - 1 day == last day of that month.
    month_end = month_start + relativedelta(months=1, days=-1)
    return min(yesterday, month_end)


class ScoreToContract:
    """Mapper that attaches a contract id and date parts to response rows.

    On construction it reads all records of ScoreToCotractTable and builds
    `score2contract`: a dict keyed by (partner_id, score_name, month string)
    whose values are contract ids, with one entry for every month between a
    record's `first_month` and the value returned by get_last_month.
    """

    def __init__(self):
        recs = ScoreToCotractTable().list_records()

        self.score2contract = {}
        for rec in recs:
            month_cursor = datetime.strptime(rec['first_month'], FMT_DATE_YM)
            final_month = get_last_month(rec)
            # Register the contract for each month of its validity period.
            while month_cursor <= final_month:
                self.set_contract(rec, month_cursor.strftime(FMT_DATE_YM))
                month_cursor += relativedelta(months=1)

    def set_contract(self, rec, month):
        """Remember `rec['contract_id']` for the record's partner, score and `month`."""
        lookup_key = (rec['partner_id'], rec['score_name'], month)
        self.score2contract[lookup_key] = rec['contract_id']

    def get_contract(self, rec, month):
        """Return the contract id stored for the record's partner, score and `month`, or None."""
        lookup_key = (rec['partner_id'], rec['score_name'], month)
        return self.score2contract.get(lookup_key)

    def __call__(self, rec):
        """Yield `rec` enriched with contract_id/year/month/day.

        Rows without a timestamp, and rows for which no contract is known,
        produce nothing. The row is modified in place and its `score_name`
        field is removed before it is yielded.
        """
        timestamp = rec['ts']
        if timestamp is None:
            return

        moment = datetime.utcfromtimestamp(timestamp)
        contract_id = self.get_contract(rec, moment.strftime(FMT_DATE_YM))
        if contract_id is None:
            return

        extra_fields = {
            'contract_id': contract_id,
            'year': moment.year,
            'month': moment.month,
            'day': moment.day
        }
        rec.update(extra_fields)
        rec.pop('score_name')
        yield rec


def get_policy_key(rec):
    """Build the PolicyKey (partner_id, contract_id) of a record."""
    partner = rec['partner_id']
    contract = rec['contract_id']
    return PolicyKey(partner_id=partner, contract_id=contract)


def make_date_ranges_dict():
    """Read ScoreToCotractTable and return a dict PolicyKey -> DtRange.

    Each range starts at the record's `first_month` and ends at the value
    returned by get_last_month. When several records share a policy key the
    range of the record listed last is the one that is kept.
    """
    ranges_by_policy = {}
    for contract_rec in ScoreToCotractTable().list_records():
        period = DtRange(
            first_dt=datetime.strptime(contract_rec['first_month'], FMT_DATE_YM),
            last_dt=get_last_month(contract_rec)
        )
        ranges_by_policy[get_policy_key(contract_rec)] = period

    return ranges_by_policy


class DaysFiller:
    """Writes a table holding one row per contract per day of its date range."""

    def __init__(self):
        # PolicyKey -> DtRange, as produced by make_date_ranges_dict().
        self.dt_ranges = make_date_ranges_dict()

    def make_recs_for_contract(self, policy_key, dt_range):
        """Return one row for each day from `first_dt` up to `last_dt`.

        Days are stepped in whole-day increments starting at
        `dt_range.first_dt`; a day is included while it does not exceed
        `dt_range.last_dt`. The result is empty when `first_dt` is later
        than `last_dt`.
        """
        start = dt_range.first_dt
        # Number of whole-day steps that still fit into the range; it is
        # zero or negative when the range is empty.
        days_total = (dt_range.last_dt - start).days + 1

        rows = []
        for offset in range(days_total):
            current = start + timedelta(days=offset)
            rows.append({
                'partner_id': policy_key.partner_id,
                'contract_id': policy_key.contract_id,
                'year': current.year,
                'month': current.month,
                'day': current.day
            })
        return rows

    def __call__(self, yt_client, path):
        """Write the day rows of all contracts to `path` and sort the table.

        :param yt_client: client object providing TablePath, write_table and
            run_sort.
        :param path: destination table path.
        """
        rows = [
            row
            for policy_key, dt_range in self.dt_ranges.items()
            for row in self.make_recs_for_contract(policy_key, dt_range)
        ]

        # (column name, column type) pairs; the same order is used for sorting.
        columns = (
            ('partner_id', 'string'),
            ('contract_id', 'string'),
            ('year', 'uint16'),
            ('month', 'uint8'),
            ('day', 'uint8'),
        )
        res_table = yt_client.TablePath(
            path,
            schema=[{'name': name, 'type': col_type} for name, col_type in columns]
        )
        yt_client.write_table(res_table, rows)
        yt_client.run_sort(res_table, sort_by=[name for name, _ in columns])


@yt.with_context
class MoneyReducer:
    """Reducer producing running usage counters and the money owed for them.

    For every key it receives day rows (table index 0) followed by request
    rows (any other table index) and yields snapshot records: one at the
    start of each day and one after each request row. Counters are never
    reset inside a key, so every snapshot is cumulative for that key.

    Construction reads the active price policies from PricePolicyTable and
    the contract date ranges from make_date_ranges_dict().
    """

    # `bucket_for` values where the bucket is selected by a hit ratio.
    hit_price_types = {'hit', 'hit_unique'}
    # `bucket_for` values where buckets split the number of paid items.
    request_price_types = {'has_score', 'no_score', 'requests', 'answers'}
    # `price_for` value meaning "charge the price once".
    fixed_price_type = 'fixed'

    def __init__(self):
        policies_list = PricePolicyTable().list_records()
        # Only active policies are kept: PolicyKey -> policy description.
        self.policies = {
            get_policy_key(rec): rec['policy']
            for rec in policies_list if rec['is_active']
        }
        self.dt_ranges = make_date_ranges_dict()

    def calc_price_by_hit(self, requests, hit, buckets, prices):
        """Price all `requests` at the rate of the bucket `hit` falls into.

        :param requests: number of items to pay for.
        :param hit: value compared against the bucket bounds.
        :param buckets: upper bounds; the first bound greater than `hit`
            selects the price with the same index.
        :param prices: price per item for every bucket; the last price is
            used when `hit` is not below any bound.
        :return: `requests` multiplied by the selected price.
        """
        for i, bucket in enumerate(buckets):
            if hit < bucket:
                return requests * prices[i]

        return requests * prices[-1]

    def calc_price_by_requests(self, requests, buckets, prices):
        """Price `requests` progressively, bucket by bucket.

        Items up to the first bound cost `prices[0]`, items between the first
        and the second bound cost `prices[1]`, and so on; items above the
        last bound cost `prices[-1]`.

        :param requests: number of items to pay for.
        :param buckets: increasing upper bounds of the buckets.
        :param prices: price per item for every bucket.
        :return: total price.
        """
        money = 0
        prev_bucket = 0
        for i, bucket in enumerate(buckets):
            if requests < bucket:
                # The remainder ends inside this bucket.
                money += (requests - prev_bucket) * prices[i]
                break
            money += (bucket - prev_bucket) * prices[i]
            prev_bucket = bucket
        else:
            # Every bucket was filled: the rest goes at the last price.
            money += (requests - prev_bucket) * prices[-1]

        return money

    def calc_money_for_system(self, rec, system, policy):
        """Compute the money due for one pricing system of a policy.

        :param rec: snapshot record with the counters and ratios referenced
            by the system's `price_for` / `bucket_for`.
        :param system: dict with `pricing` and optional `price_for`
            (defaults to `pricing`), `bucketing` and `bucket_for` (defaults
            to `bucketing`).
        :param policy: dict with `prices` (keyed by pricing name) and an
            optional `buckets` dict (keyed by bucketing name).
        :return: money due for this system.
        :raises ValueError: when buckets are configured but `bucket_for` is
            neither a hit type nor a request type.
        """
        bucketing = system.get('bucketing')
        buckets = policy.get('buckets', {}).get(bucketing, None)

        price_for = system.get('price_for', system['pricing'])
        num_to_pay_for = 1
        if price_for != self.fixed_price_type:
            num_to_pay_for = rec[price_for]

        pricing = system['pricing']
        prices = policy['prices'][pricing]
        bucket_for = system.get('bucket_for', bucketing)

        if buckets is None:
            # Without buckets `prices` is used as a single per-item price.
            return num_to_pay_for * prices
        elif bucket_for in self.hit_price_types:
            return self.calc_price_by_hit(num_to_pay_for, rec[bucket_for], buckets, prices)
        elif bucket_for in self.request_price_types:
            return self.calc_price_by_requests(num_to_pay_for, buckets, prices)
        else:
            raise ValueError('Bad bucket_for: {}'.format(bucket_for))

    def is_no_process_rec(self, rec):
        """Tell whether `rec` must be skipped.

        :return: True when the record's policy key has no active policy or
            no date range, or when its day lies outside that date range.
        """
        policy_key = get_policy_key(rec)
        if policy_key not in self.policies or policy_key not in self.dt_ranges:
            return True

        dt = datetime(rec['year'], rec['month'], rec['day'])
        dt_range = self.dt_ranges[policy_key]
        if dt < dt_range.first_dt or dt > dt_range.last_dt:
            return True

        return False

    def get_vat(self, day_tuple):
        """Return the VAT rate (as a fraction) for the year of `day_tuple`."""
        if day_tuple.year < 2019:
            return VAT_BEFORE_2019
        return VAT_AFTER_2019

    def get_money_dict(self, rec, day_tuple):
        """Compute the money fields of a snapshot record.

        :param rec: snapshot record (counters, ratios and the policy key).
        :param day_tuple: day of the snapshot; its year selects the VAT rate.
        :return: dict of floats with
            `money` -- sum over all systems, raised to the policy's
            `min_monthly_price` when that is larger, excluding VAT;
            `money_no_min` -- the same sum without the minimum, excluding VAT;
            `bill` -- `money` including VAT.
        """
        policy = self.policies[get_policy_key(rec)]
        money_no_min = 0
        for system in policy['systems']:
            money_no_min += self.calc_money_for_system(rec, system, policy)

        money = max(money_no_min, policy.get('min_monthly_price', 0))
        bill = money
        vat_coef = 1 + self.get_vat(day_tuple)
        if policy['VAT_included']:
            # Policy prices already contain VAT: strip it from the net sums.
            money /= vat_coef
            money_no_min /= vat_coef
        else:
            # Policy prices are net: add VAT to the bill only.
            bill *= vat_coef

        return {
            'money': float(money),
            'money_no_min': float(money_no_min),
            'bill': float(bill)
        }

    def prepare_rec(self, key, day, has_score, no_score, timeout, requests,
                    id_vals, id_vals_if_score, ts):
        """Build one output snapshot from the current counters.

        :param key: reduce key providing partner_id, contract_id, year, month.
        :param day: DayTuple of the snapshot.
        :param has_score: number of answers with a score so far.
        :param no_score: number of answers without a score so far.
        :param timeout: number of requests counted as timeouts so far.
        :param requests: number of requests so far.
        :param id_vals: set of id values seen in successful responses.
        :param id_vals_if_score: subset of `id_vals` that received a score.
        :param ts: timestamp stored in the snapshot.
        :return: dict with the counters, the derived ratios (0.0 when the
            denominator is zero) and the fields of get_money_dict.
        """
        answers = has_score + no_score
        rec = dict(
            partner_id=key['partner_id'],
            contract_id=key['contract_id'],
            year=key['year'], month=key['month'], day=day.day,
            has_score=has_score, no_score=no_score, timeout=timeout,
            requests=requests, answers=answers, ts=ts,
            hit_unique=float(len(id_vals_if_score)) / len(id_vals) if len(id_vals) > 0 else 0.0,
            hit=float(has_score) / answers if answers > 0 else 0.0,
            timeout_share=float(timeout) / requests if requests > 0 else 0.0
        )
        rec.update(self.get_money_dict(rec, day))

        return rec

    def __call__(self, key, recs, context):
        """Reduce the day rows and request rows of one key into snapshots.

        :param key: reduce key (partner_id, contract_id, year, month).
        :param recs: rows of the key; rows with table index 0 are day rows
            (year, month, day), the remaining rows are request rows.
        :param context: job context; `context.table_index` is read after
            each row to find out which table the row came from.
        :return: generator of snapshot records built by prepare_rec.
        """
        # NOTE(review): assumes all day rows come before the request rows and
        # that both are ordered by day -- confirm against the operation setup.
        # iter() lets next() below work for any iterable, not only iterators.
        recs = iter(recs)

        days = []
        cur_rec = None
        for rec in recs:
            if context.table_index == 0:
                days.append(DayTuple(rec['year'], rec['month'], rec['day']))
            else:
                # First request row: keep it for the accumulation loop below.
                cur_rec = rec
                break

        has_score, no_score, timeout, requests = (0, 0, 0, 0)
        id_vals, id_vals_if_score = set(), set()
        prev_req_id = None
        for day in days:
            # Snapshot at the start of the day, stamped with its midnight
            # (computed via calendar.timegm, i.e. interpreted as UTC).
            ts = float(calendar.timegm(datetime(day.year, day.month, day.day).timetuple()))
            yield self.prepare_rec(key, day, has_score, no_score, timeout,
                                   requests, id_vals, id_vals_if_score, ts)
            while cur_rec is not None and DayTuple(cur_rec['year'], cur_rec['month'], cur_rec['day']) == day:
                # Consecutive rows of one request count as a single request.
                is_new_request = cur_rec['request_id'] != prev_req_id
                requests += 1 if is_new_request else 0
                if cur_rec['status'][:1] == '2':
                    id_vals.add(cur_rec['id_value'])
                    if cur_rec['has_score']:
                        has_score += 1
                        id_vals_if_score.add(cur_rec['id_value'])
                    else:
                        no_score += 1
                else:
                    # Any non-2xx status is counted as a timeout, once per request.
                    timeout += 1 if is_new_request else 0

                yield self.prepare_rec(key, day, has_score, no_score, timeout,
                                       requests, id_vals, id_vals_if_score, cur_rec['ts'])
                prev_req_id = cur_rec['request_id']
                cur_rec = next(recs, None)


def get_aggregate_query(input_table, output_table, agg_type=AggTypes.daily):
    """Build the query that keeps the latest snapshot of every partition.

    Rows are partitioned per contract and day (daily) or per contract and
    month (monthly), ordered by `ts` descending, and only the first row of
    each partition is kept. The monthly variant additionally keeps only rows
    whose month or year is lower than the current UTC month or year.

    :param input_table: table with snapshot rows.
    :param output_table: table to overwrite with the result.
    :param agg_type: AggTypes member selecting the granularity.
    :return: query text.
    :raises ValueError: for an unsupported `agg_type`.
    """
    if agg_type == AggTypes.daily:
        window_columns = 'partner_id, contract_id, year, month, day'
        row_filter = 'rn == 1'
    elif agg_type == AggTypes.monthly:
        window_columns = 'partner_id, contract_id, year, month'
        row_filter = 'rn == 1 AND (month < DateTime::GetMonth(CurrentUtcDate()) OR year < DateTime::GetYear(CurrentUtcDate()))'
    else:
        raise ValueError('Bad agg_type: {}'.format(agg_type))

    query_template = dedent("""
        pragma yt.ForceInferSchema;
        INSERT INTO `{output_table}` WITH TRUNCATE
        SELECT * WITHOUT rn, ts
        FROM (
            SELECT
                partner_id,
                contract_id,
                year,
                month,
                day,
                FIRST_VALUE(ts) OVER w as ts,
                FIRST_VALUE(money) OVER w as money,
                FIRST_VALUE(money_no_min) OVER w as money_no_min,
                FIRST_VALUE(bill) OVER w as bill,
                FIRST_VALUE(hit_unique) OVER w as hit_unique,
                FIRST_VALUE(hit) OVER w as hit,
                FIRST_VALUE(has_score) OVER w as has_score,
                FIRST_VALUE(no_score) OVER w as no_score,
                FIRST_VALUE(answers) OVER w as answers,
                FIRST_VALUE(requests) OVER w as requests,
                FIRST_VALUE(timeout) OVER w as timeout,
                FIRST_VALUE(timeout_share) OVER w as timeout_share,
                ROW_NUMBER() OVER w as rn
            FROM `{input_table}`
            WINDOW w AS (
                PARTITION BY {partition_by}
                ORDER BY ts DESC
            )
        )
        WHERE {where}
        ORDER BY {partition_by}
        """)

    substitutions = {
        'input_table': input_table,
        'output_table': output_table,
        'partition_by': window_columns,
        'where': row_filter,
    }
    return query_template.format(**substitutions)


# Query template (presumably YQL -- confirm) that keeps, for every
# (partner_id, contract_id), the `n` most recent rows by year/month/day and
# writes them to the output table, truncating it first.
# Placeholders to fill with str.format: output_table, input_table, n.
select_last_days = dedent("""
    pragma yt.ForceInferSchema;
    INSERT INTO `{output_table}` WITH TRUNCATE
    SELECT * WITHOUT rn
    FROM (
        SELECT partner_id, contract_id, year, month, day, money,
            money_no_min, bill, has_score, no_score, answers, requests,
            hit, hit_unique, timeout_share, ROW_NUMBER() OVER w as rn
        FROM `{input_table}`
        WINDOW w AS (
            PARTITION BY partner_id, contract_id
            ORDER BY year DESC, month DESC, day DESC
        )
    )
    WHERE rn <= {n}
    ORDER BY partner_id, contract_id, year, month, day
""")


def filter_fast_logs(key, recs):
    """Keep at most one record from a group of log records.

    The first record with `req_type == 'scores'` and a 2xx status is yielded
    and iteration stops. When there is no such record, the last record of
    the group is yielded if it is a 'scores' record with a 5xx status.
    An empty group yields nothing.

    :param key: reduce key (unused).
    :param recs: iterable of log records with `req_type` and `status`.
    :return: generator yielding zero or one record.
    """
    # NOTE(review): only the LAST record is checked for a 5xx status -- confirm
    # that an earlier 5xx record followed by other records should be dropped.
    last_rec = None
    for last_rec in recs:
        if last_rec['req_type'] == 'scores' and last_rec['status'][:1] == '2':
            yield last_rec
            return

    if last_rec is None:
        # Empty group: there is no record to inspect.
        return
    if last_rec['req_type'] == 'scores' and last_rec['status'][:1] == '5':
        yield last_rec


def _collect_contact_ids(entries, value_key):
    """Split contact entries into raw values (`value_key`) and their `id_value`s."""
    values, id_vals = [], []
    for entry in entries:
        if value_key in entry:
            values.append(entry[value_key])
        if 'id_value' in entry:
            id_vals.append(entry['id_value'])
    return values, id_vals


def fast_logs_format_mapper(rec):
    """Convert one fast-log record into per-score response rows.

    The identifiers found in `rec['req_body']['user_ids']` (emails, phones
    and their id values, cookies as "vendor#cookie") are sorted per kind and
    joined with '#' into `id_value`. A `user_ids` value that is not a dict
    (or a dict subclass) is treated as empty.

    For a 2xx status one row is produced per score of the response body;
    for a 5xx status one row per score of the request body, with
    `has_score` set to None.

    :param rec: log record with `req_body`, `resp_body`, `timestamp`
        ('%Y-%m-%dT%H:%M:%S+00:00'), `status`, `partner_id`, `resp_time`
        and `blnsr_req_id`.
    :return: generator of row dicts.
    :raises ValueError: when the status starts with neither '2' nor '5'.
    """
    user_ids = rec['req_body'].get('user_ids', {})
    # isinstance instead of an exact type comparison, so that dict subclasses
    # are not mistaken for malformed input and silently emptied.
    if not isinstance(user_ids, dict):
        user_ids = {}

    emails, email_id_vals = _collect_contact_ids(user_ids.get('emails', []), 'email')
    phones, phone_id_vals = _collect_contact_ids(user_ids.get('phones', []), 'phone')
    cookies = [
        cookie['cookie_vendor'] + '#' + cookie['cookie']
        for cookie in user_ids.get('cookies', [])
    ]

    id_values = []
    for ids_container in (emails, email_id_vals, phones, phone_id_vals, cookies):
        id_values.extend(sorted(ids_container))
    id_value = '#'.join(id_values)

    dt = datetime.strptime(rec['timestamp'], '%Y-%m-%dT%H:%M:%S+00:00')
    # The timestamp string carries a +00:00 offset, hence timegm (UTC).
    ts = float(calendar.timegm(dt.timetuple()))

    status_class = rec['status'][:1]
    if status_class == '2':
        scores, from_response = rec['resp_body']['scores'], True
    elif status_class == '5':
        # Failed requests have no usable response: list the requested scores.
        scores, from_response = rec['req_body']['scores'], False
    else:
        raise ValueError('Unknown status: {}'.format(rec['status']))

    for score in scores:
        yield {
            'ts': ts,
            'partner_id': rec['partner_id'],
            'id_value': id_value,
            'response_time': float(rec['resp_time']),
            'score_name': score['score_name'],
            'has_score': score['has_score'] if from_response else None,
            'request_id': rec['blnsr_req_id'],
            'status': rec['status']
        }
