# -*- coding: utf-8 -*-

import logging
from datetime import datetime, timedelta

import sandbox.common.types.client as ctc
from sandbox import sdk2
from sandbox.common.errors import TemporaryError
from sandbox.projects.rasp.bus.BusBaseTask import BusBaseTaskAutoResource, BusAutoResourceParameters
from sandbox.projects.rasp.utils import ISO_FORMAT
from sandbox.projects.rasp.utils.try_hard import try_hard
from sandbox.sandboxsdk import environments


class RaspBusPartnersStat(BusBaseTaskAutoResource):
    """Aggregate bus orders per partner per hour and upload them to a Stat report.

    For every hour in ``[dt_from, dt_to]`` the task queries the ``orders``
    table in Postgres for confirmed and not-confirmed orders, groups them by
    partner, adds a ``total`` row, and uploads the result to the configured
    Statface report with ``scale='hourly'``.
    """

    class Requirements(sdk2.Task.Requirements):
        client_tags = ctc.Tag.LXC
        ram = 2 * 1024
        environments = [
            environments.PipEnvironment('python-statface-client', use_wheel=False),
            environments.PipEnvironment('psycopg2-binary')
        ]

    class Parameters(BusAutoResourceParameters):
        with sdk2.parameters.Group('Postgres parameters') as postgres_params:
            postgres_user = sdk2.parameters.String('Postgres user', default='yandex_bus', required=True)
            postgres_password_vault_name = sdk2.parameters.String('Postgres vault name', required=True)
            postgres_host = sdk2.parameters.String(
                'Postgres database host, or list of hosts (for ex host1,host2,host3)',
                default='yandex-busdb01f.db.yandex.net,yandex-busdb01h.db.yandex.net,yandex-busdb01i.db.yandex.net',
                required=True,
            )
            postgres_port = sdk2.parameters.String('Postgres port', default=6432, required=True)
            postgres_dbname = sdk2.parameters.String('Postgres dbname', default='yandex_busdb', required=True)

        with sdk2.parameters.Group('Stat parameters') as stat_params:
            stat_token_vault_name = sdk2.parameters.String('statface robot oauth token vault name',
                                                           default='robot-rasp-statbox-token', required=True)
            stat_host = sdk2.parameters.String('stat host', default='upload.stat.yandex-team.ru', required=True)
            stat_report = sdk2.parameters.String(
                'stat report',
                default='Raspisanie/dapavlov/bus_partners',
                required=True,
            )
        with sdk2.parameters.Group('Data query parameters') as data_params:
            dt_from = sdk2.parameters.String('ISO datetime from (default=dt_to - 1h)', required=False, default='')
            dt_to = sdk2.parameters.String('ISO datetime to (default=now)', required=False, default='')

    @property
    def stat_report(self):
        """Return the Statface report object, creating it on first access.

        The OAuth token is read from the Sandbox vault entry named by
        ``stat_token_vault_name``. The report object is cached on
        ``self._stat_report``.
        """
        if not hasattr(self, '_stat_report'):
            from statface_client import StatfaceClient
            stat_client_config = {
                'oauth_token': sdk2.Vault.data(self.Parameters.stat_token_vault_name),
                'host': self.Parameters.stat_host
            }
            stat_client = StatfaceClient(client_config=stat_client_config)
            report = stat_client.get_old_report(self.Parameters.stat_report)
            self._stat_report = report
        return self._stat_report

    @property
    def connection_string(self):
        """Build the libpq connection string for the orders database.

        SSL is verified against ``allCAs.pem`` in the working directory,
        which ``prepare_certs`` downloads.
        NOTE(review): values are interpolated unquoted, so a password that
        contains spaces or quotes would break libpq parsing.
        """
        return "dbname={dbname} user={dbuser} port={dbport} host={dbhost} sslmode=verify-full password={dbpassword} sslrootcert=allCAs.pem".format(
            dbname=self.Parameters.postgres_dbname,
            dbuser=self.Parameters.postgres_user,
            dbport=self.Parameters.postgres_port,
            dbhost=self.Parameters.postgres_host,
            dbpassword=sdk2.Vault.data(self.Parameters.postgres_password_vault_name))

    def prepare_certs(self):
        """Download the CA bundle used for ``sslmode=verify-full`` into ``allCAs.pem``.

        ``--fail`` makes curl exit non-zero on an HTTP error status, so
        ``check_call`` raises. Otherwise an error page would be saved as the
        certificate file.
        """
        sdk2.helpers.subprocess.check_call(
            ["curl", "--fail", "https://crls.yandex.net/allCAs.pem", "-o", "allCAs.pem"]
        )

    def upload_data(self, data):
        """Upload ``data`` (a list of row dicts) to the Stat report at hourly scale.

        ``StatfaceClientRetriableError`` is retried through ``try_hard``
        (``max_retries=5``, ``sleep_duration=5``). Any ``StatfaceClientError``
        that escapes is logged and re-raised as ``TemporaryError``.
        """
        from statface_client import StatfaceClientRetriableError, StatfaceClientError
        try:
            @try_hard(max_retries=5, sleep_duration=5, retriable_exceptions=[StatfaceClientRetriableError])
            def _upload():
                self.stat_report.upload_data(scale='hourly', data=data)
            _upload()
        except StatfaceClientError as err:
            logging.exception(str(err))
            raise TemporaryError('StatFaceError')

    def _fetch_rows(self, query):
        """Run a read-only ``query`` against the orders database and return all rows.

        Rows are ``DictCursor`` rows, addressable by column name. psycopg2's
        own connection context manager only ends the transaction and leaves
        the connection open, so ``closing`` is used to close it.
        """
        from contextlib import closing
        import psycopg2
        import psycopg2.extras
        with closing(psycopg2.connect(self.connection_string)) as db_connection:
            with db_connection.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
                curs.execute(query)
                return curs.fetchall()

    def get_confirmed_orders_by_partners(self, min_dt, max_dt):
        """Aggregate confirmed orders created in ``(min_dt, max_dt]`` by partner.

        :param min_dt: exclusive lower bound on ``orders.creation_ts``.
        :param max_dt: inclusive upper bound on ``orders.creation_ts``.
        :returns: ``{partner: {'tickets', 'orders', 'amount'}}`` plus a
            ``'total'`` entry summing all partners, or ``{}`` if there are no
            matching orders. The partner key is the ``::text`` cast of
            ``booking -> 'partner'``, so JSON string values keep their double
            quotes.
        """
        # min_dt / max_dt are datetime objects computed by this task, so str()
        # formatting cannot inject quotes into the SQL.
        # coalesce() keeps sums numeric when every value in a group is NULL.
        query = '''
        select
            coalesce(sum(sub.order_price), 0) as amount,
            coalesce(sum(sub.tickets_count), 0)::int as tickets,
            count(sub.*)::int as orders,
            sub.partner as partner
        FROM
            (
            SELECT
               (booking->>'price')::float  as order_price,
               json_array_length(booking -> 'tickets') AS tickets_count,
               (booking -> 'partner')::text as partner
            FROM orders

            WHERE
                orders.status = 'confirmed'
                AND orders.creation_ts > '{}'
                AND orders.creation_ts <= '{}'
            ) as sub
        GROUP BY sub.partner
        '''.format(min_dt, max_dt)
        rows = self._fetch_rows(query)
        if not rows:
            return {}

        by_partner = {}
        total_tickets = 0
        total_orders = 0
        total_amount = 0
        for row in rows:
            total_tickets += row['tickets']
            total_orders += row['orders']
            total_amount += row['amount']
            by_partner[str(row['partner'])] = {
                'tickets': row['tickets'],
                'orders': row['orders'],
                'amount': row['amount']
            }
        by_partner['total'] = {
            'tickets': total_tickets,
            'orders': total_orders,
            'amount': total_amount
        }
        return by_partner

    def get_not_confirmed_orders_by_partners(self, min_dt, max_dt):
        """Count orders with a status other than ``'confirmed'`` in ``(min_dt, max_dt]`` by partner.

        :param min_dt: exclusive lower bound on ``orders.creation_ts``.
        :param max_dt: inclusive upper bound on ``orders.creation_ts``.
        :returns: ``{partner: {'failed_orders': n}}`` plus a ``'total'``
            entry, or ``{}`` if there are no matching orders. Partner keys
            are built the same way as in ``get_confirmed_orders_by_partners``.
        """
        query = '''
        select
            count(sub.*)::int as failed_orders,
            sub.partner as partner
        FROM
            (
            SELECT
               (booking -> 'partner')::text as partner
            FROM orders

            WHERE
                orders.status != 'confirmed'
                AND orders.creation_ts > '{}'
                AND orders.creation_ts <= '{}'
            ) as sub
        GROUP BY sub.partner
        '''.format(min_dt, max_dt)
        rows = self._fetch_rows(query)
        if not rows:
            return {}

        by_partner = {}
        total_failed_orders = 0
        for row in rows:
            total_failed_orders += row['failed_orders']
            by_partner[str(row['partner'])] = {
                'failed_orders': row['failed_orders']
            }
        by_partner['total'] = {
            'failed_orders': total_failed_orders
        }
        return by_partner

    def make_stat_dict(self, dt, delta=timedelta(hours=1)):
        """Build Stat rows for the interval ``(dt, dt + delta]``.

        Emits one row per partner that has confirmed OR failed orders, plus
        the ``'total'`` row. A partner missing from one side gets zeros for
        that side's metrics.

        :returns: list of dicts with keys ``fielddate``, ``partner``,
            ``tickets``, ``orders``, ``price``, ``failed_orders``.
        """
        min_dt = dt
        max_dt = dt + delta
        done_orders = self.get_confirmed_orders_by_partners(min_dt, max_dt)
        failed_orders = self.get_not_confirmed_orders_by_partners(min_dt, max_dt)

        # NOTE(review): fielddate is shifted by +3h, presumably converting to
        # Moscow time (UTC+3) for the report — confirm the DB/host timezone.
        fielddate = str(dt + timedelta(hours=3))

        # Union of keys: partners that only have failed orders must not be dropped.
        partners = list(done_orders)
        partners.extend(key for key in failed_orders if key not in done_orders)

        stat_data = []
        for key in partners:
            done = done_orders.get(key, {})
            failed = failed_orders.get(key, {})
            stat_data.append({
                'fielddate': fielddate,
                'partner': key,
                'tickets': done.get('tickets', 0),
                'orders': done.get('orders', 0),
                'price': done.get('amount', 0),
                'failed_orders': failed.get('failed_orders', 0)
            })
        return stat_data

    def on_execute(self):
        """Collect hourly partner stats for ``[from_hour, to_hour]`` and upload them.

        ``to_hour`` defaults to the start of the current hour
        (``datetime.today()``, naive local time). ``from_hour`` defaults to
        one hour earlier. Both bounds are inclusive, so by default the
        current, still in-progress hour is also reported.
        NOTE(review): this looks intended so that later runs overwrite the
        partial bucket — confirm.
        """
        super(RaspBusPartnersStat, self).on_execute()
        self.prepare_certs()
        if self.Parameters.dt_to:
            to_hour = datetime.strptime(self.Parameters.dt_to, ISO_FORMAT)
        else:
            to_hour = datetime.today()
        to_hour = to_hour.replace(minute=0, second=0, microsecond=0)
        if self.Parameters.dt_from:
            from_hour = datetime.strptime(self.Parameters.dt_from, ISO_FORMAT).replace(minute=0, second=0, microsecond=0)
        else:
            from_hour = to_hour - timedelta(hours=1)

        data = []
        hour = from_hour
        while hour <= to_hour:
            data += self.make_stat_dict(hour)
            hour += timedelta(hours=1)

        if data:
            self.upload_data(data=data)
