# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import datetime
import logging
import math

from requests import HTTPError
from sqlalchemy import func, cast
from sqlalchemy.dialects.postgresql import JSON

from travel.rasp.bus.db import session_scope
from travel.rasp.bus.db.models.order import Order
from travel.rasp.bus.db.models.supplier import Supplier
from travel.library.python.solomon_push_client import SolomonPushClient

# Module-level logger shared by every class below.
log = logging.getLogger(__name__)

# Maps a boolean check result to the status string pushed to Solomon.
STATUS_MAPPER = {True: 'OK', False: 'CRIT'}


def median(values):
    """Return the median of a sized sequence of numbers.

    An empty sequence yields ``0``. A single-element sequence returns that
    element unchanged. For an even number of elements, the mean of the two
    middle values (as a float) is returned.
    """
    # len() is taken first on purpose: a non-sized argument such as None
    # raises TypeError rather than being treated as empty.
    count = len(values)
    if not values:
        return 0
    if count == 1:
        return values[0]

    ordered = sorted(values)
    middle = count // 2
    if count % 2:
        return ordered[middle]
    return (ordered[middle] + ordered[middle - 1]) / 2.0


class SalesLevel:
    """String constants naming a supplier's sales-volume category.

    The category is assigned by ``SalesMonitor._get_supplier_sales_level``
    by comparing average daily sales against ``small_day_sales``.
    """

    # Average daily sales above ``small_day_sales``.
    NORM = 'normal'
    # Average daily sales at or below ``small_day_sales``.
    SMALL = 'small'


class SupplierSales:
    """Mutable record of one supplier's sales statistics.

    Fields are filled in after construction by ``SalesMonitor._get_supplier_data``.
    """

    def __init__(self, supplier):
        # Supplier code, as given by the caller.
        self.name = supplier
        # Counters start at zero until sales are loaded.
        self.sales = self.hour_rate = 0
        # One of the SalesLevel constants once classified; None until then.
        self.sales_level = None

    def __repr__(self):
        # hour_rate is rounded to one decimal place to keep logs readable.
        fields = (self.name, self.sales, round(self.hour_rate, 1), self.sales_level)
        template = "SupplierSales(name: {}, sales: {}, hour_rate: {}, sales_level: {})"
        return template.format(*fields)


class SalesMonitor:
    """Check order volumes and push an OK/CRIT status per source to Solomon.

    Total sales, and every supplier classified as ``SalesLevel.NORM``, are
    checked by comparing the last hour's orders against the median of the same
    hour in previous weeks. Suppliers classified as ``SalesLevel.SMALL`` are
    instead checked for at least one sale within an interval derived from
    their average sales rate.
    """

    def __init__(self, token, suppliers, threshold, threshold_sensitivity, history_values_num,
                 small_day_sales, sales_estimation_days, small_sales_threshold_ratio,
                 environment='testing', check_time=None, dry_run=False):
        """Configure the monitor and eagerly load per-supplier statistics.

        :param token: token passed to ``SolomonPushClient``.
        :param suppliers: supplier codes to monitor; if empty/None, all
            non-hidden suppliers from the DB are used.
        :param threshold: a source is CRIT when its current sales fall below
            ``median(history) * threshold`` (see ``_check_sales``).
        :param threshold_sensitivity: the threshold check only fires when the
            historical median exceeds this value.
        :param history_values_num: how many previous weeks to compare against.
        :param small_day_sales: suppliers averaging at most this many sales
            per day are treated as ``SalesLevel.SMALL``.
        :param sales_estimation_days: look-back window, in days, used to
            estimate each supplier's sales rate; it is used as a divisor, so it
            must be non-zero.
        :param small_sales_threshold_ratio: multiplier applied to a small
            supplier's expected interval between sales.
        :param environment: environment name passed to ``SolomonPushClient``.
        :param check_time: moment to evaluate at; defaults to
            ``datetime.utcnow()`` (so presumably naive UTC).
        :param dry_run: when true, statuses are only logged, never pushed.
        :raises ValueError: if an explicitly listed supplier is unknown or hidden.
        """
        # NOTE(review): this changes the level of the shared module logger,
        # which affects every other user of this module as well.
        log.setLevel(logging.INFO)
        self.history_values_num = history_values_num
        self.threshold = threshold
        self.threshold_sensitivity = threshold_sensitivity
        self.small_day_sales = small_day_sales
        self.sales_estimation_days = sales_estimation_days
        self.check_datetime = check_time or datetime.datetime.utcnow()
        self.small_sales_threshold_ratio = small_sales_threshold_ratio
        self.solomon = SolomonPushClient('bus', environment, 'salesmon', token)
        # Hits the DB for every supplier (validation + sales volume).
        self.suppliers = self._get_suppliers(suppliers)
        self.dry_run = dry_run

    def _get_supplier_sales_volume(self, supplier):
        """Return the supplier's sales over the last ``sales_estimation_days`` days."""
        from_datetime = self.check_datetime - datetime.timedelta(days=self.sales_estimation_days)
        return BusesDB.get_supplier_sales(supplier.name, from_datetime, self.check_datetime)

    @staticmethod
    def _week_shifted_periods(check_datetime, history_values_num):
        """Return ``history_values_num + 1`` one-hour ``(from, to)`` windows.

        The first window is the hour ending at *check_datetime*. Each following
        window is that same hour shifted back by one more week, newest first.
        """
        hour = datetime.timedelta(hours=1)
        week = datetime.timedelta(weeks=1)
        return [
            (check_datetime - hour - week * weeks_back, check_datetime - week * weeks_back)
            for weeks_back in range(history_values_num + 1)
        ]

    def _get_period_sales(self, supplier=None):
        """Return sales for the current hour followed by the same hour in past weeks.

        :param supplier: a ``SupplierSales``; when None, total sales are counted.
        :returns: list of ``history_values_num + 1`` counts, newest first.
        """
        periods = self._week_shifted_periods(self.check_datetime, self.history_values_num)
        res = []
        for from_period, to_period in periods:
            if supplier:
                sales = BusesDB.get_supplier_sales(supplier.name, from_period, to_period)
            else:
                sales = BusesDB.get_sales(from_period, to_period)
            res.append(sales)
        # Report the start of the oldest window that was actually queried.
        earliest_from = periods[-1][0] if periods else None
        log.info("got %s sales: %s, from time: %s, to time: %s",
                 supplier.name if supplier else 'total', res, earliest_from, self.check_datetime)
        return res

    def _check_all_sales(self):
        """Return True if total sales for the current hour look healthy."""
        sales = self._get_period_sales()
        current_sales, history = sales[0], sales[1:]
        log.info("check_all_sales. sales before: %s, current: %d", history, current_sales)
        return self._check_sales(history, current_sales)

    def _check_supplier_sales(self, supplier):
        """Return True if *supplier*'s sales for the current hour look healthy."""
        log.info("_check_supplier_sales. checking %s", supplier.name)
        sales = self._get_period_sales(supplier)
        current_sales, history = sales[0], sales[1:]
        return self._check_sales(history, current_sales)

    def _check_sales(self, sales, current_sales):
        """Compare *current_sales* with the median of historical *sales*.

        Returns False (CRIT) only when *current_sales* is below
        ``median(sales) * self.threshold`` AND that median exceeds
        ``self.threshold_sensitivity`` — presumably so low-traffic hours do not
        alert. Otherwise returns True.
        """
        median_sales = median(sales)
        threshold = median_sales * self.threshold
        log.info("_check_sales. threshold_sensitivity: %d, median: %d, sales: %d, threshold: %d",
                 self.threshold_sensitivity, median_sales, current_sales, threshold)
        if threshold > current_sales and median_sales > self.threshold_sensitivity:
            return False
        return True

    def _get_supplier_data(self, supplier_name):
        """Build a ``SupplierSales`` with sales volume, hourly rate and level filled in."""
        supplier = SupplierSales(supplier_name)
        supplier.sales = self._get_supplier_sales_volume(supplier)
        # Average sales per hour over the estimation window.
        supplier.hour_rate = supplier.sales / (self.sales_estimation_days * 24)
        supplier.sales_level = self._get_supplier_sales_level(supplier)
        log.info("supplier data loaded: %s", supplier)
        return supplier

    def _get_suppliers(self, supplier_names):
        """Resolve supplier codes into loaded ``SupplierSales`` objects.

        :param supplier_names: codes to use; if falsy, all non-hidden suppliers
            are fetched from the DB.
        :raises ValueError: if an explicitly given code is unknown or hidden.
        """
        if not supplier_names:
            supplier_names = BusesDB.get_supplier_names()
        else:
            valid_names = []
            for supplier_name in supplier_names:
                if not BusesDB.is_supplier(supplier_name):
                    raise ValueError("unknown supplier name or supplier is hidden: {}".format(supplier_name))
                valid_names.append(supplier_name)
            supplier_names = valid_names

        return [self._get_supplier_data(supplier_name) for supplier_name in supplier_names]

    def _check_small_sales(self, supplier):
        """Return True if a low-volume supplier made a sale recently enough.

        The look-back window is the expected number of days between sales
        (rounded up), scaled by ``small_sales_threshold_ratio`` and capped at
        ``sales_estimation_days``.
        """
        if supplier.sales <= 0:
            log.info("_check_small_sales. supplier %s - zero sales", supplier.name)
            return False

        # hour_rate is positive here because sales > 0 was checked above;
        # 1 / (hour_rate * 24) is the expected number of days between sales.
        must_do_sale_interval = math.ceil(1.0 / (supplier.hour_rate * 24)) * self.small_sales_threshold_ratio
        days_to_check = min(must_do_sale_interval, self.sales_estimation_days)

        from_period = self.check_datetime - datetime.timedelta(days=days_to_check)
        supplier_sales = BusesDB.get_supplier_sales(supplier.name, from_period, self.check_datetime)
        log.info("_check_small_sales. %s, must_do_sale_interval: %.2f day(s),"
                 " sales in period: %d, days to check: %d",
                 supplier, must_do_sale_interval, supplier_sales, days_to_check)
        return supplier_sales > 0

    def run(self):
        """Check total and per-supplier sales and push statuses unless in dry-run mode."""
        if self.dry_run:
            log.info("Dry run!")
        log.info("""parameters:
\thistory_values_num: %s
\tthreshold: %f
\tthreshold_sensitivity: %d
\tsmall_day_sales: %d
\tsales_estimation_days: %d
\tcheck_datetime: %s
\tsmall_sales_threshold_ratio: %f""",
                 self.history_values_num,
                 self.threshold,
                 self.threshold_sensitivity,
                 self.small_day_sales,
                 self.sales_estimation_days,
                 self.check_datetime,
                 self.small_sales_threshold_ratio,
                 )
        log.info("suppliers: %s", self.suppliers)
        # monitor total sales
        sales_ok = self._check_all_sales()
        if not self.dry_run:
            self.push_to_solomon('total', sales_ok)
        log.info("Total sales is OK: %s", sales_ok)

        # monitor suppliers sales: hourly comparison for normal suppliers,
        # "any recent sale" check for small ones
        for supplier in self.suppliers:
            if supplier.sales_level == SalesLevel.NORM:
                supplier_sales_ok = self._check_supplier_sales(supplier)
            else:
                supplier_sales_ok = self._check_small_sales(supplier)
            if not self.dry_run:
                self.push_to_solomon(supplier.name, supplier_sales_ok)
            log.info("%s sales is Ok: %s", supplier.name, supplier_sales_ok)

    def _get_supplier_sales_level(self, supplier):
        """Classify *supplier* as SMALL or NORM by its average daily sales."""
        day_sales = supplier.sales / self.sales_estimation_days
        if day_sales <= self.small_day_sales:
            return SalesLevel.SMALL
        return SalesLevel.NORM

    def push_to_solomon(self, source, status):
        """Send the OK/CRIT status for *source* and upload it to Solomon.

        Only ``HTTPError`` from the upload is caught (and logged); errors
        raised by ``send`` propagate to the caller.
        """
        self.solomon.send(STATUS_MAPPER[status], 1.0, source=source)
        try:
            self.solomon.upload()
        except HTTPError:
            log.exception("Error pushing to solomon")


class BusesDB:
    """Static read-only query helpers over orders and suppliers.

    Every method opens its own ``session_scope()``.
    """

    @staticmethod
    def _count_confirmed_orders(from_datetime, to_datetime, *extra_criteria):
        """Count confirmed orders created within [from_datetime, to_datetime].

        ``Order.creation_ts`` is passed through ``timezone('UTC', ...)`` before
        comparison, and both bounds are inclusive. *extra_criteria* are placed
        between the status filter and the time-range filters.
        """
        with session_scope() as session:
            criteria = [Order.status == 'confirmed']
            criteria.extend(extra_criteria)
            criteria.append(func.timezone('UTC', Order.creation_ts) >= from_datetime)
            criteria.append(func.timezone('UTC', Order.creation_ts) <= to_datetime)
            return session.query(func.count()).filter(*criteria).scalar()

    @staticmethod
    def get_sales(from_datetime, to_datetime):
        """Return the number of confirmed orders in the time range, across all partners."""
        return BusesDB._count_confirmed_orders(from_datetime, to_datetime)

    @staticmethod
    def get_supplier_sales(supplier, from_datetime, to_datetime):
        """Return the number of confirmed orders in the time range for one partner.

        The partner is read from the ``partner`` key of ``Order.booking`` cast to JSON.
        """
        partner_matches = cast(Order.booking, JSON)['partner'].astext == supplier
        return BusesDB._count_confirmed_orders(from_datetime, to_datetime, partner_matches)

    @staticmethod
    def is_supplier(supplier_name):
        """Return whether a non-hidden supplier with this code exists."""
        with session_scope() as session:
            visible_match = session.query(Supplier.code).filter(
                Supplier.hidden.isnot(True),
                Supplier.code == supplier_name,
            )
            found = session.query(visible_match.exists()).scalar()
        return found

    @staticmethod
    def get_supplier_names():
        """Return the codes of all suppliers that are not hidden."""
        with session_scope() as session:
            rows = session.query(Supplier.code).filter(Supplier.hidden.isnot(True)).all()
        return [row[0] for row in rows]
