import datetime
import json
import logging

from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import MultipleResultsFound
from yql.api.v1.client import YqlClient
from yql.client.parameter_value_builder import YqlParameterValueBuilder as ValueBuilder

from travel.library.python.rasp_vault.api import get_secret
from travel.rasp.bus.spark_api.client import IN_MEMORY_CACHE
from travel.rasp.bus.spark_api.spark_search import find_company_carriers, find_entrepreneur_carriers, \
    CarrierSearchError, CarrierSearchErrorTooMany, create_spark_client
from travel.rasp.bus.library.carrier import CarrierType
from travel.rasp.bus.library.carrier.matching import create_carrier_and_matching, CarrierMatchingError
from travel.rasp.bus.db import session_scope
from travel.rasp.bus.db.models.carrier import Carrier
from travel.rasp.bus.db.models.carrier_matching import CarrierMatching
from travel.rasp.bus.db.models.supplier import Supplier


# Module-wide logger: every message emitted by this file goes to the 'carrier_matcher' logger.
logger = logging.getLogger(name='carrier_matcher')


class CarrierMatcher:
    """Match supplier carrier codes seen in search logs to carriers in the DB.

    Reads successful worker search responses from YT logs via YQL, extracts the
    carrier attributes (INN / register number / register type) of every ride and,
    for each (supplier, carrier code) pair that has no CarrierMatching yet, either
    matches it to an existing Carrier found by INN or register number, or looks the
    carrier up in SPARK and creates the carrier together with its matching.
    """

    # Successful, non-empty LRT_WORKER_SEARCH responses of the selected suppliers,
    # read from the daily log tables between $start_date and $end_date.
    SEARCH_LOGS_QUERY = """
DECLARE $start_date AS String;
DECLARE $end_date AS String;
DECLARE $supplier_ids AS List<String>;

SELECT Response
FROM RANGE(`home/logfeller/logs/travel-buses-production-worker-communication-log/1d`,
    $start_date, $end_date)
WHERE Code = 'EC_OK'
  AND Type = 'LRT_WORKER_SEARCH'
  AND Response != '{}'
  AND JSON_VALUE(CAST(Request AS JSON), '$.SupplierId') in $supplier_ids;
"""

    # Keys of a ride's 'Carrier' model that _attributes_from_carrier extracts.
    INN = 'Inn'
    REGISTER_NUMBER = 'RegisterNumber'
    REGISTER_TYPE = 'RegisterType'
    CARRIER_ATTRIBUTES = (INN, REGISTER_NUMBER, REGISTER_TYPE)

    class CarrierMatcherReport:
        """Counters collected during one run; report() writes them to the log."""

        def __init__(self):
            self.processed = 0
            self.matched = 0
            self.found_by_spark = 0
            self.created_and_matched = 0
            self.no_carrier_code = 0
            self.twins = 0
            self.already_matched = 0
            self.empty_carrier_model = 0
            self.empty_carrier_attributes = 0

        def report(self):
            """Log every counter at INFO level."""
            logger.info('Processed carriers with data: %d', self.processed)
            logger.info('Existing carriers matched: %d', self.matched)
            logger.info('Carriers found by spark: %d', self.found_by_spark)
            logger.info('Carriers created and matched: %d', self.created_and_matched)
            logger.info('Rides without carrierCode: %d', self.no_carrier_code)
            logger.info('Carriers doubles rows: %d', self.twins)
            logger.info('Carriers already matched: %d', self.already_matched)
            logger.info('Rides with empty carrierModel: %d', self.empty_carrier_model)
            logger.info('Rides with invalid attributes in carrierModel: %d', self.empty_carrier_attributes)

    def __init__(self, yt_proxy, yql_token, suppliers, interval_end_date, days_interval, dry_run):
        """
        :param yt_proxy: YT cluster passed to YqlClient as ``db``.
        :param yql_token: secret reference resolved through get_secret() into the YQL token.
        :param suppliers: supplier codes to process; an empty/falsy value means all suppliers.
        :param interval_end_date: last date (inclusive) of the log interval; must support isoformat().
        :param days_interval: number of days in the interval, counting interval_end_date itself.
        :param dry_run: when true, all DB changes are rolled back at the end instead of committed.
        """
        self._yt_proxy = yt_proxy
        self._yql_token = get_secret(yql_token)
        self._supplier_names = suppliers
        self._suppliers = []
        self._interval_end_date = interval_end_date
        # days_interval - 1 because the end date itself is part of the interval.
        self._interval_start_date = self._interval_end_date - datetime.timedelta(days=days_interval - 1)
        self._dry_run = dry_run

        # (supplier_id, carrier_code) pairs already handled in this run.
        self._processed_carriers = set()
        self.stats = CarrierMatcher.CarrierMatcherReport()

    def run(self):
        """Entry point: load suppliers, read the logs and match every carrier found there."""
        logging.basicConfig(level=logging.INFO)
        logger.info('Start carrier matching')
        if self._dry_run:
            logger.info('Dry run!')
        logger.info('Parameters. suppliers: %s, interval: %s - %s', self._supplier_names, self._interval_start_date,
                    self._interval_end_date)
        self._suppliers = self._load_suppliers(self._supplier_names)
        with session_scope() as session:
            self._process_carriers(session, self._carrier_data(session, self._search_data()))

    def _search_data(self):
        """Run SEARCH_LOGS_QUERY for the configured interval and suppliers; yield result rows lazily."""
        parameters = {
            '$start_date': ValueBuilder.make_string(self._interval_start_date.isoformat()),
            '$end_date': ValueBuilder.make_string(self._interval_end_date.isoformat()),
            '$supplier_ids': ValueBuilder.make_list([ValueBuilder.make_string(str(sid)) for sid, _ in self._suppliers])
        }
        with YqlClient(db=self._yt_proxy, token=self._yql_token) as yql_client:
            request = yql_client.query(self.SEARCH_LOGS_QUERY, syntax_version=1)
            request.run(parameters=ValueBuilder.build_json_map(parameters))

            for table in request:
                for row in table.get_iterator():
                    yield row

    def _carrier_data(self, session, search_data):
        """Turn raw log rows into ``(supplier_id, carrier_code, carrier_attributes)`` tuples.

        Rides are skipped (and counted in self.stats) when they have no carrier code,
        were already handled in this run, are already matched in the DB, have no
        carrier model, or have no usable INN / register number.

        Note: the duplicate check relies on _process_carriers recording each yielded
        pair before the next one is requested, i.e. on lazy consumption.
        """
        for row in search_data:
            # The query selects only Response, taken here as the row's last column.
            rides = json.loads(row.pop())
            for ride in rides.get('Rides', []):
                supplier_id = ride['SupplierId']
                carrier_code = ride.get('CarrierCode')
                if not carrier_code:
                    self.stats.no_carrier_code += 1
                    continue
                if (supplier_id, carrier_code) in self._processed_carriers:
                    self.stats.twins += 1
                    continue
                if self._is_matched_carrier(session, supplier_id, carrier_code):
                    self.stats.already_matched += 1
                    continue
                carrier_model = ride.get('Carrier')
                if not carrier_model:
                    self.stats.empty_carrier_model += 1
                    continue
                carrier_attributes = self._attributes_from_carrier(carrier_model)
                if not carrier_attributes:
                    self.stats.empty_carrier_attributes += 1
                    continue
                yield supplier_id, carrier_code, carrier_attributes

    def _lookup_carrier_in_db(self, session, inn, register_number):
        """Find the Carrier whose INN or register number matches.

        :return: the Carrier, or None when both identifiers are empty, nothing
            matches, or more than one carrier matches (the latter is logged as an error).
        """
        if not inn and not register_number:
            return None
        filters = []
        if inn:
            filters.append(Carrier.inn == inn)
        if register_number:
            filters.append(Carrier.register_number == register_number)
        try:
            carrier = session.query(Carrier).filter(or_(*filters)).one_or_none()
        except MultipleResultsFound:
            logger.error('unexpected multiple carriers by inn, regnumber: %s, %s', inn, register_number)
            return None
        return carrier

    def _process_carriers(self, session, carriers_data):
        """Match every carrier from carriers_data, then commit (or roll back on dry run).

        Each carrier is first looked up in the DB by INN / register number. If it is
        not found, it is searched in SPARK and, when found there, created together
        with its matching.
        """
        with create_spark_client(cache_db=IN_MEMORY_CACHE) as spark:
            for supplier_id, carrier_code, carriers_attrs in carriers_data:
                # Record the pair before handling it so that _carrier_data (consumed
                # lazily) counts later duplicates as twins.
                self._processed_carriers.add((supplier_id, carrier_code))
                self.stats.processed += 1
                register_type = carriers_attrs.get(self.REGISTER_TYPE)
                inn = carriers_attrs.get(self.INN)
                register_number = carriers_attrs.get(self.REGISTER_NUMBER)
                unique_identifier = inn or register_number
                logger.info('look up carrier in DB. code=%s, inn=%s, regnumber=%s', carrier_code, inn, register_number)
                carrier = self._lookup_carrier_in_db(session, inn, register_number)
                if carrier:
                    self._match_existing_carrier(session, supplier_id, carrier_code, carrier)
                    continue
                logger.info('carrier NOT found in DB by identifier: %s', unique_identifier)
                carrier_type, carrier = self._find_carrier_in_spark(spark, supplier_id, carrier_code, register_type,
                                                                    unique_identifier)
                if carrier:
                    self._create_carrier_from_spark(session, supplier_id, carrier_code, carrier_type, carrier)

            self.stats.report()
            if not self._dry_run:
                session.commit()
            else:
                logger.info('Dry run. Rollback')
                session.rollback()

    def _match_existing_carrier(self, session, supplier_id, carrier_code, carrier):
        """Create a CarrierMatching linking (supplier_id, carrier_code) to an existing DB carrier."""
        logger.info('carrier found in DB')
        try:
            # Insert inside a SAVEPOINT: if the flush violates a constraint only this
            # insert is rolled back. Without it the session would stay in a failed
            # state and every later query and the final commit would raise.
            with session.begin_nested():
                session.add(CarrierMatching(supplier_id=supplier_id, code=carrier_code, carrier_id=carrier.id))
                session.flush()
        except IntegrityError:
            logger.info('matching already exists')
            return
        self.stats.matched += 1
        logger.info('matching created. supplier = %s, carrier = %s, carrier_id=%d',
                    supplier_id, carrier_code, carrier.id)

    def _find_carrier_in_spark(self, spark, supplier_id, carrier_code, register_type, unique_identifier):
        """Search SPARK for the carrier identified by unique_identifier (INN or register number).

        A known register type selects a single search: company for
        'REGISTER_TYPE_COMPANY', entrepreneur otherwise. A missing or
        'REGISTER_TYPE_UNKNOWN' type tries the company search first and falls back
        to the entrepreneur search.

        :return: ``(carrier_type, carrier)``; carrier is None if no single match was found.
        """
        if register_type and register_type != 'REGISTER_TYPE_UNKNOWN':
            logger.info('searching for known register type: %s, %s, %s, %s', supplier_id, carrier_code,
                        register_type, unique_identifier)
            if register_type == 'REGISTER_TYPE_COMPANY':
                return CarrierType.COMPANY, self._search_spark(find_company_carriers, 'company', spark,
                                                               unique_identifier)
            return CarrierType.ENTREPRENEUR, self._search_spark(find_entrepreneur_carriers, 'entrepreneur', spark,
                                                                unique_identifier)

        logger.info('searching for unknown register type: %s, %s, %s,',
                    supplier_id, carrier_code, unique_identifier)
        logger.info('searching for company')
        carrier = self._search_spark(find_company_carriers, 'company', spark, unique_identifier)
        if carrier:
            return CarrierType.COMPANY, carrier
        logger.info('searching for entrepreneur')
        return CarrierType.ENTREPRENEUR, self._search_spark(find_entrepreneur_carriers, 'entrepreneur', spark,
                                                            unique_identifier)

    @staticmethod
    def _search_spark(search_func, kind, spark, unique_identifier):
        """Run one SPARK search by identifier and return its single result, or None.

        :param search_func: find_company_carriers or find_entrepreneur_carriers.
        :param kind: 'company' or 'entrepreneur'; used only in log messages.
        """
        try:
            # records_limit=2 presumably lets the search detect ambiguity (raising
            # CarrierSearchErrorTooMany); the unpack expects exactly one result.
            (carrier,) = search_func(unique_identifier, is_name_search=False, records_limit=2, spark_client=spark)
        except CarrierSearchErrorTooMany:
            logger.info('too many carriers found as %s', kind)
        except CarrierSearchError:
            logger.info('carrier not found as %s', kind)
        else:
            return carrier
        return None

    def _create_carrier_from_spark(self, session, supplier_id, carrier_code, carrier_type, carrier):
        """Create a DB carrier from a SPARK result together with its matching for supplier_id."""
        self.stats.found_by_spark += 1
        logger.info('got carrier from spark: %s %s', carrier, carrier.report)
        try:
            carrier_id, _ = create_carrier_and_matching(carrier_type, carrier_code, carrier, [supplier_id],
                                                        db_session=session)
        except CarrierMatchingError:
            logger.info('carrier matching already exists')
            return
        self.stats.created_and_matched += 1
        logger.info('carrier and matching created. supplier = %s, carrier = %s, carrier_id=%d',
                    supplier_id, carrier_code, carrier_id)

    def _load_suppliers(self, supplier_names):
        """Return ``(id, code)`` rows for the given supplier codes, or for all suppliers if none given.

        :raises ValueError: if any requested supplier code does not exist.
        """
        with session_scope() as session:
            if not supplier_names:
                suppliers = session.query(Supplier.id, Supplier.code).all()
            else:
                suppliers = session.query(Supplier.id, Supplier.code).filter(Supplier.code.in_(supplier_names)).all()
                unknown_suppliers = set(supplier_names) - {s for _, s in suppliers}
                if unknown_suppliers:
                    raise ValueError('unknown suppliers in parameters: {}'.format(', '.join(unknown_suppliers)))
            return suppliers

    def _is_matched_carrier(self, session, supplier_id, carrier_code):
        """Return whether a CarrierMatching already exists for (supplier_id, carrier_code)."""
        is_matched = session.query(CarrierMatching).filter(CarrierMatching.supplier_id == supplier_id,
                                                           CarrierMatching.code == carrier_code)
        return session.query(is_matched.exists()).scalar()

    def _attributes_from_carrier(self, carrier_model):
        """Extract the non-empty CARRIER_ATTRIBUTES from a ride's carrier model.

        INN and register number are kept only if they consist of digits.

        :return: dict of the kept attributes, or None if neither a valid INN nor a
            valid register number is present.
        """
        carrier_attributes = {}
        has_number = False
        for attr_name in self.CARRIER_ATTRIBUTES:
            attr = carrier_model.get(attr_name)
            if not attr:
                continue
            if attr_name in (self.INN, self.REGISTER_NUMBER):
                if not attr.isdigit():
                    continue
                has_number = True
            carrier_attributes[attr_name] = attr

        if not has_number:
            return None
        return carrier_attributes
