import dataclasses
import logging
from contextlib import nullcontext

from travel.rasp.bus.library.carrier import CarrierType
from travel.rasp.bus.settings import Settings
from travel.rasp.bus.spark_api.client import SparkClient, SparkConfig

log = logging.getLogger(__name__)


def create_spark_client(cache_db=Settings.Spark.CACHE_DB):
    return SparkClient.create(SparkConfig(
        host=Settings.Spark.HOST,
        login=Settings.Spark.LOGIN,
        password=Settings.Spark.PASSWORD,
        cache_db=cache_db
    ))


class CarrierSearchError(Exception):
    pass


class CarrierSearchErrorTooMany(CarrierSearchError):
    pass


# 49.3 - Деятельность прочего сухопутного пассажирского транспорта
# Эта группировка включает:
# - пассажирские перевозки наземным транспортом, кроме перевозок, осуществляемых железнодорожным транспортом
# Эта группировка также включает:
# - железнодорожный транспорт, если он является частью городских или пригородных транспортных систем
OKVED_TRANSPORT_PREFIX = '49.3'

# 12300 — Общества с ограниченной ответственностью
OKOPF_OOO_CODE = '12300'


def _okved_predicate(okveds, okved_prefix=OKVED_TRANSPORT_PREFIX):
    okved_codes = set(okved.code for okved in okveds)
    valid_codes = set(okved_code for okved_code in okved_codes if okved_code.startswith(okved_prefix))
    log.debug('Valid okveds: %r; Invalid okveds: %r', valid_codes, okved_codes - valid_codes)
    return bool(valid_codes)


def find_entrepreneur_carriers(value, is_name_search, records_limit=10, spark_client=None):
    log.debug('Processing entrepreneur %r', value)
    spark_cm = nullcontext(spark_client) if spark_client else create_spark_client()
    with spark_cm as client:
        if is_name_search:
            entrepreneurs = client.find_entrepreneurs_by_name(value)
        else:
            entrepreneur = client.find_entrepreneur_by_code(value)
            entrepreneurs = (entrepreneur,) if entrepreneur is not None else ()

        if not entrepreneurs:
            raise CarrierSearchError('Entrepreneurs not found')
        if len(entrepreneurs) >= records_limit:
            raise CarrierSearchErrorTooMany(f'Too many entrepreneur variants: {len(entrepreneurs)}')

        carriers = []
        for entrepreneur in entrepreneurs:
            report = client.get_entrepreneur_report(inn=entrepreneur.inn, ogrnip=entrepreneur.ogrnip)
            if report and _okved_predicate(report.okveds):
                carriers.append(dataclasses.replace(entrepreneur, report=report))

        if not carriers:
            raise CarrierSearchError('Carriers not found')

        return carriers


def find_company_carriers(value, is_name_search, records_limit=10, spark_client=None):
    log.debug('Processing company %r', value)
    spark_cm = nullcontext(spark_client) if spark_client else create_spark_client()
    with spark_cm as client:
        if is_name_search:
            companies = client.find_companies_by_name(value, okopf_code=OKOPF_OOO_CODE)
        else:
            company = client.find_company_by_code(value)
            companies = (company,) if company is not None and company.okopf_code == OKOPF_OOO_CODE else ()

        if not companies:
            raise CarrierSearchError('Companies not found')
        if len(companies) >= records_limit:
            raise CarrierSearchErrorTooMany(f'Too many company variants: {len(companies)}')

        carriers = []
        for company in companies:
            report = client.get_company_report(spark_id=company.id, inn=company.inn, ogrn=company.ogrn)
            if report and _okved_predicate(report.okveds):
                carriers.append(dataclasses.replace(company, report=report))

        if not carriers:
            raise CarrierSearchError('Carriers not found')

        return carriers


CARRIER_SEARCHERS = {
    CarrierType.COMPANY: find_company_carriers,
    CarrierType.ENTREPRENEUR: find_entrepreneur_carriers
}


def find_carriers(carrier_type, value, **kwargs):
    return CARRIER_SEARCHERS[carrier_type](value, **kwargs)
