# coding=utf-8
import argparse
from collections import Counter, defaultdict
import os
import json
import datetime
import dateutil
import dateutil.parser
from dateutil import rrule
import time
from itertools import chain

from nile.api.v1 import clusters
from nile.api.v1 import filters as nf
from nile.api.v1 import aggregators as na
from nile.api.v1 import extractors as ne
from nile.api.v1 import Record
from qb2.api.v1 import extractors as qe
from qb2.api.v1 import filters as qf
import numpy as np

# from zoo.discounts_calc_experiments.common_config import (get_project_cluster,
#                                                           json_config_file)
# from zoo.discounts_calc_experiments.common_config import (DM_ORDER, ORDER_PROC,
#                                                           HAHN_DIR)
# from zoo.discounts_calc_experiments.tables_mappers import (order_proc_mapper, extract_gmv_mapper)

# from zoo.utils.nile_helpers.dates import range_selector

from projects.common.nile.dates import range_selector

# Tariff classes that get their own per-tariff output columns
# (n_trips_<tariff>, gmv_sum_<tariff>, commission_sum_<tariff>).
TARIFFS = [
    'econom',
    'business',
    'comfortplus',
    'vip',
    'child_tariff',
    'minivan',
    'standart',
    'start',
    'express',
    'courier',
]

# Location of this script; config file paths are resolved against its
# directory (see json_config_file).
ABS_SCRIPT_PATH = os.path.abspath(__file__)
SUGGEST_NIRVANA_DIR = os.path.dirname(ABS_SCRIPT_PATH)

# Working-directory template; `{}` is filled with the experiment name.
HAHN_DIR = "//home/taxi_ml/dev/comfort/experiments/{}/"

# Source table directories.
ORDER_PROC = '//home/taxi-dwh/raw/mdb/order_proc/'
DM_ORDER = '//home/taxi-dwh/summary/dm_order/'


# TODO: also validate that the other required config fields are present
def json_config_file(value):
    """Load a JSON config; usable as an argparse ``type=`` callable.

    :param value: config file path, resolved relative to the directory of
        this script (an absolute path is used as is, per ``os.path.join``).
    :return: the parsed JSON object.
    :raises argparse.ArgumentTypeError: if the JSON root is not an object or
        has no non-null ``experiment_name``.
    """
    config_path = os.path.join(SUGGEST_NIRVANA_DIR, value)
    with open(config_path, 'rb') as f:
        data = json.load(f)
    # An explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, and argparse turns ArgumentTypeError into a regular
    # usage error instead of a traceback.
    if not isinstance(data, dict) or data.get('experiment_name') is None:
        raise argparse.ArgumentTypeError(
            "config {!r} must be a JSON object with a non-null "
            "'experiment_name'".format(config_path))
    return data


class GmvDiscountsCommissionReducer(object):
    """Reducer that aggregates trips, GMV, commission and discounts per period.

    Called with an iterable of ``(key, records)`` groups. For every group it
    yields one record combining ``key`` with one field per period; a period
    field is named ``'<date_from>_<date_to>'`` and holds a dict with
    per-tariff counters, their all-tariff totals and the discount sums.
    """

    # Tariffs whose ``discount_value`` is counted as ``discounts_ml``;
    # discounts on every other tariff go to ``discounts_other``.
    # TODO: the tariff check may need to cover more tariffs
    # (business, comfort, comfort+, child).
    ML_DISCOUNT_TARIFFS = ('business', 'comfortplus')

    # Prefixes of the per-tariff counters; each also gets an all-tariff total
    # stored under the bare prefix.
    COUNTER_PREFIXES = ('n_trips', 'gmv_sum', 'commission_sum')

    def __init__(self, periods_list, ml_tags):
        """
        :param periods_list: list of ``(date_from, date_to)`` pairs. Periods
            may overlap; a record is counted in every period whose half-open
            range ``[date_from, date_to)`` contains its ``moscow_date``.
        :param ml_tags: stored on the instance; not read by this class.
        """
        self.periods_list = periods_list
        self.ml_tags = ml_tags

    def get_sum_of_fields(self, record, fields_list):
        """Sum the listed fields of ``record``.

        Fields that are missing or hold ``None`` contribute 0 (``.get`` with
        a default only covers missing keys, not null values).
        """
        return sum(record.get(field, 0) or 0 for field in fields_list)

    def calc_commission(self, record):
        """Return the sum of commission-type fields minus subsidy/discount fields.

        Not used by ``__call__``, which sums the raw ``order_commission``.
        """
        fields_for_plus = ['order_commission', 'order_commission_discount',
                           'shift_commission_wo_vat', 'subsidy_commission']

        fields_for_minus = ['subsidy_value', 'holded_subsidy_value',
                            'dms_value', 'holded_dms_value',
                            'subsidy_commission_discount',
                            'discount_value', 'coupon_use_value']

        return (self.get_sum_of_fields(record, fields_for_plus)
                - self.get_sum_of_fields(record, fields_for_minus))

    def calc_discount(self, record):
        """Split the record's discount into ``(ml_value, other_value)``.

        The whole ``discount_value`` (0 when missing or ``None``) goes to the
        first element for tariffs in ``ML_DISCOUNT_TARIFFS`` and to the
        second element otherwise.
        """
        discount = record.get('discount_value', 0) or 0
        if record.get('order_tariff') in self.ML_DISCOUNT_TARIFFS:
            return discount, 0
        return 0, discount

    def get_sum_of_filter_fields(self, rec, filter_field):
        """Sum the values of ``rec`` whose key starts with ``filter_field``."""
        return sum(value for entity, value in rec.items() if
                   entity.startswith(filter_field))

    def make_new_record(self, results, periods_list, req_fields):
        """Build the output record with a fixed schema for every period.

        :param results: mapping ``period_name -> {counter_name: value}``.
        :param periods_list: period names to emit; periods absent from
            ``results`` are emitted with zeroed counters.
        :param req_fields: counter names copied from ``results`` (0 when
            absent); counters not listed here are dropped.
        :return: ``Record`` with one dict-valued field per period name.
        """
        rec = {}
        for period_name in periods_list:
            period_results = results.get(period_name, {})
            # A fresh dict per period, so a repeated period name cannot fold
            # a previously computed total back into the prefix sums.
            counters = dict((name, period_results.get(name, 0))
                            for name in req_fields)
            for prefix in self.COUNTER_PREFIXES:
                counters[prefix] = self.get_sum_of_filter_fields(counters,
                                                                 prefix)
            rec[period_name] = counters
        return Record(**rec)

    def ddict_2_dict(self, d):
        """Recursively convert nested (default)dicts into plain dicts.

        Nested values of ``d`` are replaced in place; a plain-dict copy of
        the top level is returned.
        """
        for k, v in d.items():
            if isinstance(v, dict):
                d[k] = self.ddict_2_dict(v)
        return dict(d)

    def __call__(self, groups):
        """Yield one aggregated record per ``(key, records)`` group."""
        # Period names and the output schema do not depend on the data, so
        # they are built once per call instead of per record / per group.
        periods = [(date_from, date_to, '{}_{}'.format(date_from, date_to))
                   for date_from, date_to in self.periods_list]
        period_names = [name for _, _, name in periods]
        req_fields = ['discounts_ml', 'discounts_other']
        for prefix in self.COUNTER_PREFIXES:
            req_fields.extend('{}_{}'.format(prefix, tariff)
                              for tariff in TARIFFS)

        for key, records in groups:
            results_cnt = defaultdict(lambda: defaultdict(float))

            for record in records:
                moscow_date = record.get('moscow_date')
                # Records without a date cannot be assigned to any period.
                if moscow_date is None:
                    continue
                for date_from, date_to, name in periods:
                    if not (date_from <= moscow_date < date_to):
                        continue
                    counters = results_cnt[name]
                    tariff = record.order_tariff
                    counters['n_trips_{}'.format(tariff)] += 1
                    counters['gmv_sum_{}'.format(tariff)] += record.user_cost
                    # Raw order commission; calc_commission() is not applied.
                    counters['commission_sum_{}'.format(
                        tariff)] += record.order_commission
                    discount_ml_value, discount_other_value = \
                        self.calc_discount(record)
                    counters['discounts_ml'] += discount_ml_value
                    counters['discounts_other'] += discount_other_value

            new_record = self.make_new_record(results_cnt, period_names,
                                              req_fields)
            # TODO: add a parameter with the list of extra fields
            # (pred, crypta_pred, city) that may be present.

            yield Record(
                key,
                new_record
            )


def collect_data(job, yt_work_dir, start_date, finish_date):
    """Add to ``job`` the step that extracts the order slice for the experiment.

    Reads the dm_order tables selected by ``range_selector`` for the given
    dates, keeps orders with status ``finished`` / taxi_status ``complete``
    whose ``moscow_order_dt`` lies in ``[start_date, finish_date)``, projects
    the needed fields and writes the result to
    ``<yt_work_dir>tmp/dm_orders_<start_date>_<finish_date>``.

    :return: the same ``job`` object.
    """
    month_range = range_selector(start_date, finish_date, '%Y-%m')
    path_dm_order = '{}{}'.format(DM_ORDER, month_range)
    print(path_dm_order)

    source = job.table(path_dm_order)

    # Completed orders inside the half-open date range.
    completed_in_range = nf.and_(
        nf.equals('status', 'finished'),
        nf.equals('taxi_status', 'complete'),
        nf.custom(lambda x: x >= start_date, 'moscow_order_dt'),
        nf.custom(lambda x: x < finish_date, 'moscow_order_dt'))
    filtered = source.filter(qf.defined('moscow_order_dt'),
                             qf.defined('local_order_dttm'),
                             completed_in_range)

    # Fields copied through under their original names.
    kept_fields = [
        'order_id', 'discount_value', 'user_cost', 'user_phone_id',
        'local_order_dttm', 'order_tariff', 'utc_order_dttm',
        'order_commission', 'order_commission_discount',
        'shift_commission_wo_vat', 'subsidy_value',
        'holded_subsidy_value',
        'dms_value', 'holded_dms_value', 'subsidy_commission',
        'subsidy_commission_discount', 'coupon_use_value',
    ]
    projected = filtered.project(
        *kept_fields,
        moscow_date='moscow_order_dt',
        # The first 10 characters of the datetime strings are kept as dates.
        local_date=ne.custom(lambda x: x[:10], 'local_order_dttm'),
        utc_date=ne.custom(lambda x: x[:10], 'utc_order_dttm'),
        is_order=ne.const(1))

    table_name = 'tmp/dm_orders_{}_{}'.format(start_date, finish_date)
    projected.put('{}{}'.format(yt_work_dir, table_name))
    return job


def run_split_on_users(job, yt_work_dir, group_list, ml_tags, periods_list,
                       start_date, finish_date):
    """Add to ``job`` a per-user aggregation step for every group table.

    Each table in ``group_list`` is reduced to ``user_phone_id``, left-joined
    with the order slice ``<yt_work_dir>tmp/dm_orders_<start>_<finish>``,
    grouped by ``user_phone_id`` and reduced with
    ``GmvDiscountsCommissionReducer``. The output goes to
    ``<yt_work_dir>phone_id_<group_name>_<start>_<finish>``.

    :return: the same ``job`` object.
    """
    orders_table_name = 'tmp/dm_orders_{}_{}'.format(start_date, finish_date)
    mid_table = job.table('{}{}'.format(yt_work_dir, orders_table_name))

    print(group_list)
    for group in group_list:
        # TODO: check that the join does not produce duplicated fields,
        # otherwise they will overwrite each other.
        users = job.table(group).project('user_phone_id')
        group_table = users.join(mid_table, by='user_phone_id', type='left')

        # TODO: group names come from the last path component, so the group
        # tables in the config must have distinct table names.
        group_name = group.split('/')[-1]

        reducer = GmvDiscountsCommissionReducer(periods_list, ml_tags)
        output_path = '{}phone_id_{}_{}_{}'.format(
            yt_work_dir, group_name, start_date, finish_date)
        group_table.groupby('user_phone_id').reduce(reducer).put(output_path)
    return job


def orders_reducer(groups):
    """Aggregate order metrics for every ``(key, records)`` group.

    For each group yields one record that combines ``key`` with:
    ``n_trips_<tariff>``, ``gmv_sum_<tariff>``, ``commission_sum_<tariff>``
    for every tariff seen in the group, the totals ``n_trips``, ``gmv_sum``,
    ``commission_sum``, and ``discounts_ml`` / ``discounts_other``
    (discounts on ``business``/``comfortplus`` orders vs. all other tariffs).
    """
    for key, records in groups:
        results_cnt = Counter()
        for record in records:
            tariff = record.order_tariff
            results_cnt['n_trips_{}'.format(tariff)] += 1
            results_cnt['gmv_sum_{}'.format(tariff)] += record.user_cost
            # Commission goes to its own per-tariff counter; it used to be
            # added to gmv_sum_<tariff>, which inflated per-tariff GMV and
            # left no per-tariff commission in the output.
            results_cnt[
                'commission_sum_{}'.format(tariff)] += record.order_commission
            results_cnt['n_trips'] += 1
            results_cnt['gmv_sum'] += record.user_cost
            results_cnt['commission_sum'] += record.order_commission

            # Multiplying by the 0/1 flag (rather than branching) keeps both
            # discount counters present in every output record.
            flag = int(tariff in ['business', 'comfortplus'])
            results_cnt['discounts_ml'] += flag * record.discount_value
            results_cnt['discounts_other'] += (1 - flag) * record.discount_value
        yield Record(key, **results_cnt)


def run_split_on_dates(job, yt_work_dir, group_list, ml_tags, periods_list,
                       start_date, finish_date):
    """Add to ``job`` a per-date aggregation step for every group table.

    Each table in ``group_list`` is reduced to ``user_phone_id``, left-joined
    with the order slice ``<yt_work_dir>tmp/dm_orders_<start>_<finish>``,
    restricted to ``moscow_date`` in ``[start_date, finish_date)``, grouped
    by ``moscow_date`` and reduced with ``orders_reducer``. The sorted output
    goes to ``<yt_work_dir>date_<group_name>_<start>_<finish>``.

    ``ml_tags`` and ``periods_list`` are accepted but not used here.

    :return: the same ``job`` object.
    """
    mid_table = job.table('{}{}'.format(yt_work_dir,
                                        'tmp/dm_orders_{}_{}'.format(start_date,
                                                                     finish_date)))

    for group in group_list:
        # TODO: check that the join does not produce duplicated fields,
        # otherwise they will overwrite each other.
        group_table = job.table(group).project('user_phone_id').join(mid_table,
                                                                     by='user_phone_id',
                                                                     type='left')
        # TODO: group names come from the last path component, so the group
        # tables in the config must have distinct table names.
        group_name = group.split('/')[-1]

        group_table \
            .filter(
                # Users without orders come out of the left join without
                # order fields; require moscow_date before comparing it, the
                # same way collect_data guards moscow_order_dt.
                # NOTE(review): presumably a missing field reaches nf.custom
                # as None, which cannot be ordered against str on Python 3.
                qf.defined('moscow_date'),
                nf.and_(nf.custom(lambda x: x >= start_date, 'moscow_date'),
                        nf.custom(lambda x: x < finish_date,
                                  'moscow_date'))) \
            .groupby('moscow_date') \
            .reduce(orders_reducer) \
            .sort('moscow_date') \
            .put('{}date_{}_{}_{}'.format(yt_work_dir, group_name, start_date,
                                          finish_date))
    return job


if __name__ == '__main__':

    # Usage: python -m calc_exps --config-path tmp.json

    parser = argparse.ArgumentParser()
    parser.add_argument('--config-path', type=json_config_file, required=True)
    args = parser.parse_args()

    # Read every config field up front, so that a missing key fails here and
    # not after the first cluster job has already run.
    config = args.config_path
    exp_periods = config['periods_list']
    groups = config['groups']
    ml_tags = config['ml_tags']
    yt_work_dir = HAHN_DIR.format(config['experiment_name'])

    # The overall span of all periods decides which dm_order months to read.
    start_exp_date = min(chain.from_iterable(exp_periods))
    finish_exp_date = max(chain.from_iterable(exp_periods))

    from projects.efficiency_metrics.project_config import get_project_cluster
    cluster = get_project_cluster()

    # cluster = clusters.yt.Hahn()  # get_project_cluster()

    # Sanity checks of range_selector before anything is submitted.
    assert range_selector("2018-12-14", "2019-01-14",
                          '%Y-%m') == '{2018-12,2019-01}'
    assert range_selector("2018-12-14", "2019-01-14",
                          '%Y-%m-01') == '{2018-12-01,2019-01-01}'

    def _new_job():
        # Every step runs as a separate job with a timestamped name and the
        # same environment settings.
        return cluster.job('Discounts metrics' + str(time.time())).env(
            bytes_decode_mode='strict',
            yt_spec_defaults={'max_failed_job_count': 1000}
        )

    # 0. Collect the order slice from dm_order (experiment + after
    # experiment); the next step reads the table written here.
    job = collect_data(
        job=_new_job(),
        yt_work_dir=yt_work_dir,
        start_date=start_exp_date,
        finish_date=finish_exp_date
    )
    job.run()

    # 1. Metrics by user_phone_id.
    job = run_split_on_users(
        job=_new_job(),
        yt_work_dir=yt_work_dir,
        group_list=groups,
        ml_tags=ml_tags,
        periods_list=exp_periods,
        start_date=start_exp_date,
        finish_date=finish_exp_date
    )
    job.run()

    # 2. Metrics by date (currently disabled).
    # job = run_split_on_dates(
    #     job=_new_job(),
    #     yt_work_dir=yt_work_dir,
    #     group_list=groups,
    #     ml_tags=ml_tags,
    #     periods_list=exp_periods,
    #     start_date=start_exp_date,
    #     finish_date=finish_exp_date
    # )
    # job.run()