# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

from travel.rasp.bus.db import session_scope
from travel.rasp.bus.db.models.supplier import Supplier

from travel.rasp.bus.scripts.automatcher.policy import MatchedPolicy


class BaseScenario(object):
    """Common base class for automatcher scenarios.

    Subclasses must set ``name``, ``matched_policy`` and ``point_type_policy``
    and implement ``_run``; ``run`` refuses to execute otherwise.
    """

    name = None  # scenario identifier; also the default key looked up by get_config
    matched_policy = None  # returned (with point_type_policy) by get_policy
    supplier = None  # optional Supplier.code this scenario is bound to
    _supplier_id = None  # Supplier.id resolved from ``supplier``, cached per instance
    point_type_policy = None  # returned (with matched_policy) by get_policy
    report = None  # per-instance list, (re)created in __init__

    def __init__(self, **params):
        # ``params`` is not used by the base class.
        self.report = []

    def _run(self, point):
        """Scenario-specific logic for a single point; subclasses must override."""
        raise NotImplementedError

    def run(self, point):
        """Validate the class configuration, then delegate to ``_run(point)``.

        Raises:
            ValueError: if ``matched_policy``, ``point_type_policy`` or ``name``
                is None.
        """
        if self.matched_policy is None or self.point_type_policy is None or self.name is None:
            raise ValueError("class {} has invalid base values. policy: {}, point_type_policy: {}, name: {}".
                             format(self.__class__.__name__, self.matched_policy, self.point_type_policy, self.name))
        return self._run(point)

    def get_name(self):
        """Return ``self.name``; raise ValueError if it is empty or None."""
        if not self.name:
            raise ValueError('invalid name for scenario: {}'.format(self.name))
        return self.name

    def get_policy(self):
        """Return the ``(matched_policy, point_type_policy)`` pair."""
        return self.matched_policy, self.point_type_policy

    def get_scenario_supplier_id(self):
        """Return the DB id of ``self.supplier``, resolving it at most once.

        When ``self.supplier`` is None, the current ``_supplier_id`` (None by
        default) is returned without touching the DB. After the first
        successful lookup the id is cached on the instance.

        Raises:
            ValueError: if ``self.supplier`` is set but not found in the DB.
        """
        if self.supplier is not None and self._supplier_id is None:
            # Reuse the shared lookup so both code paths query and fail the same way.
            self._supplier_id = self.get_supplier_id(self.supplier)
        return self._supplier_id

    def get_supplier_id(self, supplier_name):
        """Return the id of the supplier whose ``code`` equals ``supplier_name``.

        Raises:
            ValueError: if no supplier with that code exists.
        """
        with session_scope() as session:
            row = session.query(Supplier.id).filter(Supplier.code == supplier_name).one_or_none()
            if row is None:
                # Report the name that was actually looked up, not ``self.supplier``.
                raise ValueError("cannot load supplier from db: {}".format(supplier_name))
        return row.id

    def get_suppliers_list(self):
        """Return ``(id, code)`` rows for every supplier whose ``hidden`` is False.

        Raises:
            ValueError: if the query returns no rows.
        """
        with session_scope() as session:
            suppliers = session.query(Supplier.id, Supplier.code).filter(Supplier.hidden.is_(False)).all()
            if not suppliers:
                raise ValueError("cannot load suppliers from db")
        return suppliers

    def get_report(self):
        """Return the report list accumulated on this instance."""
        return self.report

    def postprocess(self):
        """Hook called after processing; a no-op in the base class."""
        pass

    def get_config(self, params, cfg_name=None):
        """Return ``params[cfg_name or self.name]``.

        ``params`` is expected to be a mapping (it is accessed with ``.get``).

        Raises:
            ValueError: if the key is missing or maps to a falsy value.
        """
        key = cfg_name or self.name
        config = params.get(key)
        if not config:
            raise ValueError('got empty params, unable to get a config by key "{}" for the scenario: {}'.format(
                key, self.name))
        return config


class BaseMatcher(BaseScenario):
    """Scenario base class whose ``matched_policy`` is ``MatchedPolicy.NOT_MATCHED``.

    NOTE(review): presumably this selects points that are not yet matched;
    confirm against ``automatcher.policy``.
    """

    matched_policy = MatchedPolicy.NOT_MATCHED  # reported via get_policy()


class BaseUnmatcher(BaseScenario):
    """Scenario base class whose ``matched_policy`` is ``MatchedPolicy.MATCHED``.

    NOTE(review): presumably this selects already-matched points; confirm
    against ``automatcher.policy``.
    """

    matched_policy = MatchedPolicy.MATCHED  # reported via get_policy()
