import csv
import logging

from functools import partial

from travel.rasp.bus.library.carrier import CarrierType
from travel.rasp.bus.library.carrier.matching import create_carrier_and_matching, CarrierMatchingError
from travel.rasp.bus.spark_api.spark_search import find_entrepreneur_carriers, find_company_carriers, \
    CarrierSearchError, CarrierSearchErrorTooMany, create_spark_client
from travel.rasp.bus.db import session_scope
from travel.rasp.bus.db.models.supplier import Supplier
from travel.rasp.bus.spark_api.client import IN_MEMORY_CACHE

logger = logging.getLogger('carrier_loader')


class CarrierMatchingLoader:
    """Load carrier matchings from a CSV file, look carriers up in Spark and store the matchings in the DB.

    Every input row must contain the columns ``INPUT:carrier``, ``INPUT:supplier``,
    ``OUTPUT:registration_type``, ``OUTPUT:registration_number``, ``OUTPUT:legal_name``,
    ``OUTPUT:short_legal_name``, ``OUTPUT:last_name``, ``OUTPUT:first_name`` and ``OUTPUT:patronymic``.
    Rows whose carrier was found are written to ``output_fn``: the original row values followed by
    ``carrier_id``, ``carrier_full_name`` and ``carrier_inn``.
    """

    STOP_WORDS_COMPANY = ('ООО', 'ОБЩЕСТВО С ОГРАНИЧЕННОЙ ОТВЕТСТВЕННОСТЬЮ', 'АКЦИОНЕРНОЕ ОБЩЕСТВО', 'АО',
                          'ОТКРЫТОЕ АКЦИОНЕРНОЕ ОБЩЕСТВО', 'ОАО', 'ПАО')
    # Misspelled name kept as an alias so existing references keep working.
    STOP_WORDS_COMAPNY = STOP_WORDS_COMPANY
    STOP_WORDS_ENTREPRENEUR = ('ИП', 'ИНДИВИДУАЛЬНЫЙ ПРЕДПРИНИМАТЕЛЬ')

    def __init__(self, matching_fn, output_fn, csv_delimeter=';', csv_quotechar='"', dry_run=False,
                 encoding='utf-8'):
        """
        :param matching_fn: path of the input CSV with matchings.
        :param output_fn: path of the output CSV with found carriers.
        :param csv_delimeter: delimiter of the input CSV.
        :param csv_quotechar: quote character of the input CSV.
        :param dry_run: if true, DB changes are rolled back instead of committed.
        :param encoding: text encoding of both the input and the output CSV files.
        """
        self._matching_fn = matching_fn
        self._output_fn = output_fn
        self._csv_quotechar = csv_quotechar
        self._csv_delimeter = csv_delimeter
        self._encoding = encoding
        self._carriers_found = 0
        self._carriers_not_found = 0
        self._carriers_matched = 0
        self._carriers_dump = []
        # Input CSV field names; filled in by ``_load_matchings`` and used as the output header prefix.
        self._dump_header = None
        self._suppliers = None
        self._dry_run = dry_run

    def run(self):
        """Process all matchings, commit (or roll back on dry run) and write the output CSV."""
        logging.basicConfig(level=logging.INFO)
        if self._dry_run:
            logger.info('Dry run!')
        self._suppliers = self._load_suppliers()
        with create_spark_client(cache_db=IN_MEMORY_CACHE) as spark, session_scope() as db_session:
            self._search_attempt(spark, db_session, self._load_matchings())
            logger.info('Carriers found: %d', self._carriers_found)
            logger.info('Carriers not found : %d', self._carriers_not_found)
            logger.info('Carriers matched: %d', self._carriers_matched)
            # Commit / roll back while the session scope is still open: once the ``with`` exits the
            # scope has already finalized the session (NOTE(review): presumably commits on exit —
            # confirm against session_scope), so a rollback afterwards would not undo a dry run.
            if not self._dry_run:
                db_session.commit()
            else:
                logger.info('Dry run. Rollback')
                db_session.rollback()
        self._write_dump()

    def _write_dump(self):
        """Write the header (if known) and every collected dump row to the output CSV."""
        with open(self._output_fn, 'w', newline='', encoding=self._encoding) as outf:
            writer = csv.writer(outf)
            if self._dump_header:
                writer.writerow(list(self._dump_header) + ['carrier_id', 'carrier_full_name', 'carrier_inn'])
            writer.writerows(self._carriers_dump)

    def _load_matchings(self):
        """Yield input CSV rows as dicts and remember the field names for the output header."""
        # newline='' is required by the csv module for correct handling of quoted newlines.
        with open(self._matching_fn, newline='', encoding=self._encoding) as matching_rows:
            matchings = csv.DictReader(matching_rows, delimiter=self._csv_delimeter, quotechar=self._csv_quotechar)
            self._dump_header = list(matchings.fieldnames or [])
            for matching in matchings:
                yield matching

    def _search_attempt(self, spark, db_session, matchings):
        """Find a carrier for every matching row, dump it and create the DB matching; update counters."""
        for matching in matchings:
            carrier_code = matching['INPUT:carrier']
            supplier_code = matching['INPUT:supplier']
            carrier_type = self._detect_matching_type(matching)
            logger.info('process %s, detected type: %s', carrier_code, carrier_type)
            carrier = self._find_carrier(spark, matching, carrier_type)
            if not carrier:
                self._carriers_not_found += 1
                logger.info('not found: %s', carrier_code)
                continue
            logger.info('found carrier! %s, %s, %s', carrier.id, carrier.full_name, carrier.inn)
            self._carriers_found += 1
            # ``matching`` is a dict: dump its values (not its keys) so rows line up with the header.
            self._carriers_dump.append((*matching.values(), carrier.id, carrier.full_name, carrier.inn))
            supplier_id = self._suppliers.get(supplier_code)
            if supplier_id is None:
                # An unknown supplier must not abort (and roll back) the whole run.
                logger.warning('carrier NOT matched - unknown supplier: %s', supplier_code)
                continue
            try:
                cid, matching_ids = create_carrier_and_matching(
                    carrier_type, carrier_code, carrier, (supplier_id,), db_session=db_session)
            except CarrierMatchingError:
                logger.info('carrier NOT matched - already exists')
                continue
            logger.info('carrier matched! carrier_id: %d, matching ids: %s', cid, matching_ids)
            self._carriers_matched += 1

    def _find_carrier(self, spark, matching, carrier_type):
        """Dispatch the Spark search by carrier type; return the carrier or ``None``."""
        carrier_code = matching['INPUT:carrier']
        reg_num = matching['OUTPUT:registration_number']
        legal_name = matching['OUTPUT:legal_name']
        short_name = matching['OUTPUT:short_legal_name']
        last_name = matching['OUTPUT:last_name']
        first_name = matching['OUTPUT:first_name']
        patronymic = matching['OUTPUT:patronymic']
        carrier = None
        if carrier_type == CarrierType.COMPANY:
            carrier = self._find_company(spark, reg_num, legal_name, short_name, carrier_code)
        elif carrier_type == CarrierType.ENTREPRENEUR:
            carrier = self._find_entrepreneur(spark, reg_num, last_name, first_name, patronymic, carrier_code)
        return carrier

    def _search_single(self, search_fn, spark, value, is_name_search, kind):
        """Run one Spark lookup that must yield exactly one carrier; return it or ``None``.

        ``kind`` ('company' / 'entrepreneur') is only used in log messages. Zero or several hits,
        whether reported through ``CarrierSearchError*`` or through the result length, are logged and
        turned into ``None`` so the caller can try the next candidate value.
        """
        search_by = 'name' if is_name_search else 'code'
        error_by = 'name' if is_name_search else 'reg number'
        if not value:
            logger.info('skip %s search by %s: empty value', kind, search_by)
            return None
        logger.info('search %s by %s: %s', kind, search_by, value)
        try:
            carriers = list(search_fn(value, is_name_search=is_name_search, records_limit=2, spark_client=spark))
        except CarrierSearchErrorTooMany:
            logger.info('too many carriers found as %s by %s: %s', kind, error_by, value)
            return None
        except CarrierSearchError:
            logger.info('carrier not found as %s by %s: %s', kind, error_by, value)
            return None
        # Guard the result length explicitly instead of letting tuple unpacking raise ValueError.
        if len(carriers) > 1:
            logger.info('too many carriers found as %s by %s: %s', kind, error_by, value)
            return None
        if not carriers or not carriers[0]:
            logger.info('carrier not found as %s by %s: %s', kind, error_by, value)
            return None
        return carriers[0]

    def _name_candidates(self, stop_words, values):
        """Yield cleaned, non-empty, distinct search values in their original order."""
        seen = set()
        for value in values:
            cleaned = self._clean_name(stop_words, value)
            if cleaned and cleaned not in seen:
                seen.add(cleaned)
                yield cleaned

    def _find_company(self, spark, reg_num, name, short_name, code):
        """Search a company by registration number, then by legal name, short name and carrier code."""
        carrier = self._search_single(find_company_carriers, spark, reg_num, False, 'company')
        if carrier:
            return carrier
        for search_value in self._name_candidates(self.STOP_WORDS_COMPANY, (name, short_name, code)):
            carrier = self._search_single(find_company_carriers, spark, search_value, True, 'company')
            if carrier:
                return carrier
        return None

    def _find_entrepreneur(self, spark, reg_num, last_name, first_name, middle_name, code):
        """Search an entrepreneur by registration number, then by full name and carrier code."""
        carrier = self._search_single(find_entrepreneur_carriers, spark, reg_num, False, 'entrepreneur')
        if carrier:
            return carrier
        name = self._get_entrepreneur_name(last_name, first_name, middle_name)
        for search_value in self._name_candidates(self.STOP_WORDS_ENTREPRENEUR, (name, code)):
            carrier = self._search_single(find_entrepreneur_carriers, spark, search_value, True, 'entrepreneur')
            if carrier:
                return carrier
        return None

    def _get_entrepreneur_name(self, last_name, first_name, middle_name):
        """Build "last first middle" for the name search.

        If ``last_name`` already holds three words it is treated as the full name. Only the first word
        of ``first_name`` is used. Empty parts are dropped so no double spaces appear.
        """
        last_name = last_name or ''
        last_name_parts = last_name.split()
        if len(last_name_parts) == 3:
            return ' '.join(last_name_parts)
        first_name_parts = (first_name or '').split()
        first_name_value = first_name_parts[0] if first_name_parts else ''
        parts = (last_name.strip(), first_name_value, (middle_name or '').strip())
        return ' '.join(part for part in parts if part)

    def _clean_name(self, stop_words, name):
        """Strip one leading or trailing legal-form stop word (case-insensitive) from ``name``.

        Longer stop words are tried first so that e.g. 'ОТКРЫТОЕ АКЦИОНЕРНОЕ ОБЩЕСТВО' is removed
        whole instead of leaving 'ОТКРЫТОЕ' behind after matching 'АКЦИОНЕРНОЕ ОБЩЕСТВО'.
        """
        name = (name or '').strip()
        upper_name = name.upper()
        for word in sorted(stop_words, key=len, reverse=True):
            upper_word = word.upper()
            if upper_name.startswith(upper_word + ' '):
                cleaned = name[len(word) + 1:].strip()
                logger.info('name clearing: %s -> %s', name, cleaned)
                return cleaned
            if upper_name.endswith(' ' + upper_word):
                cleaned = name[:-len(word) - 1].strip()
                logger.info('name clearing: %s -> %s', name, cleaned)
                return cleaned
        logger.info('name clearing: %s = %s', name, name)
        return name

    def _detect_matching_type(self, matching):
        """Return the carrier type from the registration type, falling back to the last-name heuristic."""
        reg_type = matching['OUTPUT:registration_type']
        # Explicit returns avoid relying on the truthiness of CarrierType members.
        if reg_type == 'OGRN':
            return CarrierType.COMPANY
        if reg_type == 'OGRNIP':
            return CarrierType.ENTREPRENEUR
        # Unknown registration type: a filled-in last name implies an individual entrepreneur.
        if matching['OUTPUT:last_name']:
            return CarrierType.ENTREPRENEUR
        return CarrierType.COMPANY

    def _load_suppliers(self):
        """Return a mapping of supplier code -> supplier id loaded from the DB."""
        with session_scope() as session:
            suppliers = session.query(Supplier.id, Supplier.code).all()
            return {code: sid for sid, code in suppliers}
