#coding=utf-8
from nile.api.v1 import clusters
from nile.api.v1 import filters as nf
from nile.api.v1 import aggregators as na
from nile.api.v1 import extractors as ne
from nile.api.v1 import Record

import json, pickle
from collections import Counter, defaultdict, deque, OrderedDict
import numpy as np
import heapq
import time
import random
import datetime
# from itertools import imap, izip, product
from dateutil.rrule import rrule, MONTHLY
import pytz
import re, sys, argparse
import statsmodels.api as sm
import pandas as pd
from scipy import stats
import warnings
from dateutil import rrule
# import itertools
import os

from projects.efficiency_metrics.project_config import get_project_cluster

def list_to_dict(value):
    """Recursively convert a list of ``(key, value)`` pairs into a dict.

    Only objects whose type is exactly ``list`` are converted; anything else
    (scalars, tuples, dicts, list subclasses) is returned unchanged. Each
    element of a converted list must unpack into exactly two items, and the
    second item is converted recursively in the same way.
    """
    if type(value) is not list:
        return value

    converted = {}
    for key, nested in value:
        converted[key] = list_to_dict(nested)
    return converted


def parse_datetime(value):
    """Parse a date/time string into a naive ``datetime.datetime``.

    The accepted layouts are tried from most to least specific:
    ``YYYY-MM-DD HH:MM:SS.ffffff``, ``YYYY-MM-DD HH:MM:SS``, ``YYYY-MM-DD``.

    Raises:
        ValueError: if ``value`` matches none of the layouts (the error comes
            from the final, date-only attempt).
    """
    timed_formats = ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S")
    for fmt in timed_formats:
        try:
            return datetime.datetime.strptime(value, fmt)
        except ValueError:
            pass
    # Last resort: a bare date. A failure here is allowed to propagate.
    return datetime.datetime.strptime(value, "%Y-%m-%d")


def _order_amounts(doc_cost, coupon_percent, discount_price):
    """Derive the user-paid cost, list price and total discount of one order.

    Args:
        doc_cost: the order's ``cost`` field.
        coupon_percent: coupon size in percent (``10`` means 10%); values
            ``<= 0`` mean no coupon was applied.
        discount_price: the ``discount.price`` field; it only counts as a
            pre-discount price when it is positive and greater than
            ``doc_cost``.

    Returns:
        Tuple ``(user_cost, price, discount)`` where ``user_cost`` is the
        cost minus the coupon part, ``price`` is the larger of
        ``discount_price`` and ``doc_cost``, and ``discount`` is the sum of
        the price discount and the coupon discount.
    """
    if coupon_percent > 0:
        coupon_discount = coupon_percent * doc_cost / 100.0
    else:
        coupon_discount = 0

    if discount_price > 0 and discount_price > doc_cost:
        price_discount = discount_price - doc_cost
    else:
        price_discount = 0

    # The coupon covers ``coupon_percent`` percent of the cost, so the user
    # pays the remainder. (The percent must be divided by 100 *before* it is
    # subtracted from 1, otherwise the result goes negative.)
    user_cost = doc_cost - coupon_discount
    price = discount_price if discount_price > doc_cost else doc_cost
    discount = price_discount + coupon_discount
    return user_cost, price, discount


def orders_mapper(tz_offset_hours=3):
    """Build a mapper that emits one record per discount class of an order.

    Only orders with ``status == 'finished'`` and
    ``taxi_status == 'complete'`` are processed. For each entry of
    ``doc['discount']['by_classes']`` the mapper yields a ``Record`` with the
    entry's ``reason``/``description`` (as ``reason``/``tag``), the order's
    money amounts, user/order ids, a ``ya_plus`` flag (1 when any price
    modifier has reason ``'ya_plus'``) and the creation hour shifted by
    ``tz_offset_hours``. Orders without ``by_classes`` entries yield nothing.

    Args:
        tz_offset_hours: hours added to ``record.created`` to get local
            time. Defaults to 3, the value previously hard-coded.

    Returns:
        A generator function taking an iterable of input records.
    """
    def mapper(records):
        local_shift = datetime.timedelta(hours=tz_offset_hours)

        for record in records:
            doc = record.doc
            if doc['status'] != 'finished' or doc['taxi_status'] != 'complete':
                continue

            discount_doc = doc.get('discount', {})
            by_classes = discount_doc.get('by_classes', [])
            if not by_classes:
                # Nothing to emit, so skip the per-order work below.
                continue

            user_cost, price, discount = _order_amounts(
                doc.get('cost', 0),
                doc.get('coupon', {}).get('percent', 0),
                discount_doc.get('price', 0),
            )

            # Everything below is the same for every by_classes entry of the
            # order, so it is computed once per order.
            price_modifiers = doc.get('price_modifiers', {}).get('items', [])
            ya_plus = int(any(
                modifier.get('reason') == 'ya_plus'
                for modifier in price_modifiers
            ))
            local_dt_h = (
                parse_datetime(record.created) + local_shift
            ).strftime('%Y-%m-%d %H')
            doc_class_list = doc.get('request', {}).get('class')

            for by_classes_dict in by_classes:
                yield Record(
                    reason=by_classes_dict.get('reason'),
                    tag=by_classes_dict.get('description'),
                    user_cost=user_cost,
                    price=price,
                    discount=discount,
                    user_phone_id=doc['user_phone_id'],
                    user_id=doc['user_id'],
                    order_id=doc['_id'],
                    ya_plus=ya_plus,
                    local_dt_h=local_dt_h,
                    doc_class_list=doc_class_list
                )

    return mapper


def datetime_to_timestamp(dt):
    """Return whole seconds elapsed between 1970-01-01 00:00:00 and ``dt``.

    ``dt`` must be a naive ``datetime.datetime`` because it is subtracted
    from a naive epoch; fractional seconds are truncated toward zero. The
    result is a Unix timestamp only if ``dt`` is expressed in UTC.
    """
    epoch = datetime.datetime(1970, 1, 1)
    elapsed = dt - epoch
    return int(elapsed.total_seconds())


def dict_to_list(d):
    """Return the items of ``d`` as a list of ``[key, value]`` lists.

    The conversion is shallow: nested dict values are left as they are.
    A real list is always returned (on Python 3 a bare ``map`` would be a
    one-shot iterator that cannot be indexed, measured or re-iterated).
    """
    return [list(item) for item in d.items()]


def dm_order_range(dt_start, dt_finish):
    """List the monthly ``dm_order`` table paths covering a date range.

    Args:
        dt_start: first date of the range (anything with ``year``/``month``).
        dt_finish: last date of the range; its month is included.

    Returns:
        Paths ``//home/taxi-dwh/summary/dm_order/YYYY-MM`` for every month
        from ``dt_start``'s month through ``dt_finish``'s month, in
        chronological order. Empty if ``dt_finish`` is in an earlier month
        than ``dt_start``.
    """
    # Work on (year, month) only, so the day and time of day of the bounds
    # can never decide whether the final month is included.
    first_month = dt_start.year * 12 + (dt_start.month - 1)
    last_month = dt_finish.year * 12 + (dt_finish.month - 1)

    paths = []
    for month_index in range(first_month, last_month + 1):
        year, month_zero_based = divmod(month_index, 12)
        paths.append(
            '//home/taxi-dwh/summary/dm_order/{:04d}-{:02d}'.format(
                year, month_zero_based + 1)
        )
    return paths


# Output directory for the tables produced by this script.
HAHN_DIR = '//home/taxi_ml/comfort/feats/'

cluster = get_project_cluster()

# Reporting window, compared as strings against the ``created`` column:
# START_DATE is inclusive, FINISH_DATE is exclusive.
START_DATE = '2021-04-01'
FINISH_DATE = '2021-04-07'

# Temporary table that the job writes and that is read back below.
TMP_TABLE = HAHN_DIR + 'tmp_{}_{}'.format(START_DATE, FINISH_DATE)

cities_set = {'Москва'}
cities_2_tz = {'Москва' : 3}


# The job name carries the current time so that every run gets its own name.
job = cluster.job('Discounts moscow ' + str(time.time()))
job = job.env(
    bytes_decode_mode='strict',
    yt_spec_defaults={'max_failed_job_count': 1000}
)

# Input tables are listed explicitly. A list previously built with
# dm_order_range() was overwritten by this one before being used, so it is
# no longer computed.
m_range = ['//home/taxi-dwh/raw/mdb/orders/2021-03-01',
           '//home/taxi-dwh/raw/mdb/orders/2021-04-01']

# '{a,b}' brace syntax addresses all listed tables as a single input.
dm_orders_table = job.table('{{{}}}'.format(','.join(m_range)), ignore_missing=True)

# NOTE(review): the window filter relies on ``created`` being a string whose
# lexicographic order matches chronological order — confirm.
orders_table = dm_orders_table\
    .filter(nf.and_(
        nf.custom(lambda x: START_DATE <= x < FINISH_DATE, "created")
    ))\
    .map(orders_mapper())\
    .put(TMP_TABLE)

job.run()

# Pull the result locally and print the total discount per tag.
df = cluster.read(TMP_TABLE).as_dataframe()

print(df.groupby('tag')['discount'].sum())