# -*- coding: utf-8 -*-

import logging
from datetime import datetime, timedelta
import pandas as pd

import sandbox.common.types.client as ctc
from sandbox import sdk2
from sandbox.common.errors import TemporaryError
from sandbox.projects.rasp.utils import ISO_FORMAT
from sandbox.projects.rasp.utils.email_notifications import EmailNotificationMixin, use_email_notification_params
from sandbox.projects.rasp.utils.try_hard import try_hard
from sandbox.sandboxsdk import environments

# Sources reported as separate Statface rows for every hour and platform.
# 'None' is the key produced by str(None), i.e. orders whose utm_source is
# null/missing; 'gclid' and 'total' are synthetic buckets built by the task.
UTM_SOURCE_LIST = [
    'rasp',
    'suburbans',
    'wizard',
    'zhd_google',
    'email',
    'total',
    'None',
    'yamain',
    'gclid',
]


class RaspZDtoStat(sdk2.Task, EmailNotificationMixin):
    """Upload hourly train-order statistics from Mongo to a Statface report.

    For every hour in [dt_from, dt_to] and for the 'desktop' and 'touch'
    platforms, the task aggregates the following per UTM source:

    * failed orders;
    * done orders and tickets, with their amounts and fees;
    * insurance figures.

    It then uploads all rows to the configured Statface report with the
    'hourly' scale.
    """

    class Requirements(sdk2.Task.Requirements):
        client_tags = ctc.Tag.LXC
        ram = 2 * 1024
        environments = [
            environments.PipEnvironment('python-statface-client', use_wheel=False),
        ]

    class Parameters(sdk2.Task.Parameters):
        with sdk2.parameters.Group('Stat parameters') as stat_params:
            stat_token_vault_name = sdk2.parameters.String('statface robot oauth token vault name',
                                                           default='robot-rasp-statbox-token', required=True)
            stat_host = sdk2.parameters.String('stat host', default='upload.stat.yandex-team.ru', required=True)
            stat_report = sdk2.parameters.String(
                'stat report',
                default='Raspisanie/dapavlov/zd_tickets_details',
                required=True,
            )
        with sdk2.parameters.Group('Data query parameters') as data_params:
            dt_from = sdk2.parameters.String('ISO datetime from (default=dt_to - 1h)', required=False, default='')
            dt_to = sdk2.parameters.String('ISO datetime to (default=now)', required=False, default='')

        _email_notification_params = use_email_notification_params()

    @property
    def db(self):
        """Lazily obtain and cache the train-purchase Mongo database handle."""
        if not hasattr(self, '_db'):
            from sandbox.projects.rasp.analytics.utils import get_train_purchase_db
            self._db = get_train_purchase_db()
        return self._db

    @property
    def stat_report(self):
        """Lazily create and cache the Statface report named by `Parameters.stat_report`."""
        if not hasattr(self, '_stat_report'):
            from statface_client import StatfaceClient
            stat_client_config = {
                'oauth_token': sdk2.Vault.data(self.Parameters.stat_token_vault_name),
                'host': self.Parameters.stat_host
            }
            stat_client = StatfaceClient(client_config=stat_client_config)
            self._stat_report = stat_client.get_old_report(self.Parameters.stat_report)
        return self._stat_report

    def upload_data(self, data):
        """Upload `data` rows to the Statface report using the 'hourly' scale.

        StatfaceClientRetriableError is retried by `try_hard`
        (max_retries=5, sleep_duration=5).

        :raises TemporaryError: if any StatfaceClientError is left after the
            retries. The original error is logged first.
        """
        from statface_client import StatfaceClientRetriableError, StatfaceClientError
        try:
            @try_hard(max_retries=5, sleep_duration=5, retriable_exceptions=[StatfaceClientRetriableError])
            def _upload():
                self.stat_report.upload_data(scale='hourly', data=data)
            _upload()
        except StatfaceClientError as err:
            logging.exception(str(err))
            raise TemporaryError('StatFaceError')

    @staticmethod
    def _source_key(utm_source):
        """Return the dict key under which an aggregated utm_source is stored.

        This behaves like ``str(utm_source)``, so a null/missing source becomes
        'None'. When ``str()`` cannot encode the value, the raw value is
        returned instead. On Python 2, a non-ASCII unicode source would
        otherwise raise UnicodeEncodeError and abort the whole run. Such
        sources never appear in UTM_SOURCE_LIST; they only contribute to the
        totals.
        """
        try:
            return str(utm_source)
        except UnicodeError:
            return utm_source

    @staticmethod
    def _device_condition(platform):
        """Return the `source.device` filter.

        'desktop' selects desktop orders only; any other platform selects all
        non-desktop orders.
        """
        return 'desktop' if platform == 'desktop' else {'$ne': 'desktop'}

    @staticmethod
    def _unwind_payments():
        """Return the stages that emit one document per passenger ticket payment."""
        return [
            {'$unwind': '$passengers'},
            {'$unwind': '$passengers.tickets'},
            {'$unwind': '$passengers.tickets.payment'},
        ]

    def _aggregate_by_gclid(self, build_pipeline):
        """Run a `train_order` aggregation twice: without and with a gclid.

        :param build_pipeline: callable that takes the `source.gclid` match
            condition and returns the aggregation pipeline.
        :return: tuple (rows for orders without gclid, rows for orders with gclid).
        """
        collection = self.db.train_order
        # In a Mongo match, None matches null or missing fields;
        # {'$ne': None} matches fields that are present and non-null.
        items = list(collection.aggregate(build_pipeline(None)))
        items_gclid = list(collection.aggregate(build_pipeline({'$ne': None})))
        return items, items_gclid

    def _summarize(self, items, items_gclid, fields):
        """Index per-source rows and add the synthetic 'gclid' and 'total' rows.

        Rows for orders without a gclid are stored as-is, keyed by their
        utm_source (`_id`). Rows for orders with a gclid are not broken down by
        source. Their `fields` are summed into a single 'gclid' row, which is
        present only if such rows exist. The 'total' row sums `fields` over
        both groups.
        """
        summary = {}
        totals = dict.fromkeys(fields, 0)
        for item in items:
            for field in fields:
                totals[field] += item[field]
            summary[self._source_key(item['_id'])] = item
        if items_gclid:
            gclid_totals = dict.fromkeys(fields, 0)
            for item in items_gclid:
                for field in fields:
                    gclid_totals[field] += item[field]
                    totals[field] += item[field]
            summary['gclid'] = gclid_totals
        summary['total'] = totals
        return summary

    def get_failed_orders_from_mongo(self, min_dt, max_dt, platform):
        """Count orders not in status 'done' whose `reserved_to` is in [min_dt, max_dt).

        :param platform: 'desktop' selects desktop orders; anything else
            selects non-desktop orders.
        :return: dict mapping source -> row with 'count_failed_orders', plus
            the 'gclid' and 'total' rows built by `_summarize`.
        """
        def build_pipeline(gclid_condition):
            return [
                {'$match': {
                    'status': {'$ne': 'done'},
                    'source.device': self._device_condition(platform),
                    'reserved_to': {'$gte': min_dt, '$lt': max_dt},
                    'source.gclid': gclid_condition,
                }},
                {'$group': {
                    '_id': '$source.utm_source',
                    'count_failed_orders': {'$sum': 1},
                }},
            ]

        items, items_gclid = self._aggregate_by_gclid(build_pipeline)
        return self._summarize(items, items_gclid, ('count_failed_orders',))

    def get_done_orders_from_mongo(self, min_dt, max_dt, platform):
        """Aggregate orders in status 'done' whose `finished_at` is in [min_dt, max_dt).

        For each source the rows contain:

        * ticket count (one per unwound payment);
        * distinct order count;
        * sums of the payment `amount`, `fee` and `partner_fee`.

        :param platform: 'desktop' selects desktop orders; anything else
            selects non-desktop orders.
        :return: dict mapping source -> row, plus the 'gclid' and 'total' rows
            built by `_summarize`.
        """
        def build_pipeline(gclid_condition):
            return [
                {'$match': {
                    'status': 'done',
                    'source.device': self._device_condition(platform),
                    'finished_at': {'$gte': min_dt, '$lt': max_dt},
                    'source.gclid': gclid_condition,
                }},
            ] + self._unwind_payments() + [
                {'$group': {
                    '_id': '$source.utm_source',
                    'count_tickets': {'$sum': 1},
                    'order_ids': {'$addToSet': '$_id'},
                    'sum_fee': {'$sum': '$passengers.tickets.payment.fee'},
                    'sum_amount': {'$sum': '$passengers.tickets.payment.amount'},
                    'sum_partner_fee': {'$sum': '$passengers.tickets.payment.partner_fee'},
                }},
                # Count distinct orders: emit one document per collected order id.
                {'$unwind': '$order_ids'},
                {'$group': {
                    '_id': '$_id',
                    'count_orders': {'$sum': 1},
                    'count_tickets': {'$first': '$count_tickets'},
                    'sum_amount': {'$first': '$sum_amount'},
                    'sum_fee': {'$first': '$sum_fee'},
                    'sum_partner_fee': {'$first': '$sum_partner_fee'},
                }},
            ]

        items, items_gclid = self._aggregate_by_gclid(build_pipeline)
        return self._summarize(items, items_gclid, (
            'count_tickets', 'count_orders', 'sum_fee', 'sum_partner_fee', 'sum_amount'))

    def get_insurance_info_from_mongo(self, min_dt, max_dt, platform):
        """Aggregate insured done orders whose `finished_at` is in [min_dt, max_dt).

        Orders are filtered on `passengers.insurance.trust_order_id` being
        non-null.

        NOTE(review): on an array path, Mongo's `$ne: null` likely excludes an
        order if any passenger's value is null or missing. Confirm that
        "every passenger insured" is the intended filter. Also note that
        `insurance.amount` is summed once per unwound ticket payment.

        :param platform: 'desktop' selects desktop orders; anything else
            selects non-desktop orders.
        :return: dict mapping source -> row with 'ins_tickets', 'ins_orders'
            and 'ins_amount', plus the 'gclid' and 'total' rows built by
            `_summarize`.
        """
        def build_pipeline(gclid_condition):
            return [
                {'$match': {
                    'status': 'done',
                    'source.device': self._device_condition(platform),
                    'passengers.insurance.trust_order_id': {'$ne': None},
                    'finished_at': {'$gte': min_dt, '$lt': max_dt},
                    'source.gclid': gclid_condition,
                }},
            ] + self._unwind_payments() + [
                {'$group': {
                    '_id': '$source.utm_source',
                    'order_ids': {'$addToSet': '$_id'},
                    'count_tickets': {'$sum': 1},
                    'sum_insurance_amount': {'$sum': '$passengers.insurance.amount'},
                }},
                # Count distinct orders: emit one document per collected order id.
                {'$unwind': '$order_ids'},
                {'$group': {
                    '_id': '$_id',
                    'ins_orders': {'$sum': 1},
                    'ins_tickets': {'$first': '$count_tickets'},
                    'ins_amount': {'$first': '$sum_insurance_amount'},
                }},
            ]

        items, items_gclid = self._aggregate_by_gclid(build_pipeline)
        return self._summarize(items, items_gclid, ('ins_tickets', 'ins_orders', 'ins_amount'))

    def get_mongo_info_dict_by_platform(self, source_device, is_mobile,
                                        dt, delta=timedelta(hours=1)):
        """Build the Statface rows for one platform and the window [dt, dt + delta).

        :param source_device: 'desktop' selects desktop orders; any other value
            selects non-desktop orders. The value is also written to the `ui`
            column.
        :param is_mobile: unused; kept for compatibility with existing callers.
        :param dt: start of the window; `str(dt)` becomes the `fielddate` column.
        :param delta: window length (default one hour).
        :return: one row per entry of UTM_SOURCE_LIST. Metrics with no Mongo
            data are reported as 0.
        """
        # Mongo timestamps are queried 3 hours earlier than `dt`. Presumably
        # `dt` is Moscow time and Mongo stores UTC -- TODO confirm.
        offset = timedelta(hours=3)
        min_dt = dt - offset
        max_dt = dt + delta - offset
        failed_orders = self.get_failed_orders_from_mongo(min_dt, max_dt, source_device)
        done_orders = self.get_done_orders_from_mongo(min_dt, max_dt, source_device)
        insurance = self.get_insurance_info_from_mongo(min_dt, max_dt, source_device)

        # (Statface column, per-source summary dict, field inside a summary row)
        columns = (
            ('failed_orders', failed_orders, 'count_failed_orders'),
            ('tickets', done_orders, 'count_tickets'),
            ('orders', done_orders, 'count_orders'),
            ('price', done_orders, 'sum_amount'),
            ('fee', done_orders, 'sum_fee'),
            ('partner_fee', done_orders, 'sum_partner_fee'),
            ('insurance_amount', insurance, 'ins_amount'),
            ('tickets_w_insurance', insurance, 'ins_tickets'),
            ('orders_w_insurance', insurance, 'ins_orders'),
        )
        fielddate = str(dt)
        stat_data = []
        for source in UTM_SOURCE_LIST:
            row = {'fielddate': fielddate, 'ui': source_device, 'source': source}
            for column, stats, field in columns:
                row[column] = (stats.get(source) or {}).get(field, 0)
            stat_data.append(row)
        return stat_data

    def on_execute(self):
        """Collect the statistics hour by hour and upload them in one batch.

        Both ends of [dt_from, dt_to] are truncated to the start of the hour and
        both are inclusive. `dt_to` defaults to the current hour and `dt_from`
        to one hour before `dt_to`, so a default run covers the previous hour
        and the current hour.
        """
        if self.Parameters.dt_to:
            to_hour = datetime.strptime(self.Parameters.dt_to, ISO_FORMAT)
        else:
            to_hour = datetime.today()
        to_hour = to_hour.replace(minute=0, second=0, microsecond=0)
        if self.Parameters.dt_from:
            from_hour = datetime.strptime(self.Parameters.dt_from, ISO_FORMAT).replace(
                minute=0, second=0, microsecond=0)
        else:
            from_hour = to_hour - timedelta(hours=1)

        data = []
        hour = from_hour
        while hour <= to_hour:
            data += self.get_mongo_info_dict_by_platform(source_device='desktop', is_mobile=False, dt=hour)
            data += self.get_mongo_info_dict_by_platform(source_device='touch', is_mobile=True, dt=hour)
            hour += timedelta(hours=1)

        self.upload_data(data=data)

    def on_save(self):
        """Run the base save hook, then register the e-mail notifications."""
        super(RaspZDtoStat, self).on_save()
        self.add_email_notifications()
