from contextlib import nullcontext

from sqlalchemy.dialects import postgresql
from sqlalchemy.exc import IntegrityError

from travel.rasp.bus.library.carrier import CarrierType
from travel.rasp.bus.db import session_scope
from travel.rasp.bus.db.models.carrier import Carrier
from travel.rasp.bus.db.models.carrier_matching import CarrierMatching


def _make_company_update(company):
    """Build an upsert of a company (legal entity) into the ``Carrier`` table.

    The inserted row takes its register number from ``company.ogrn``. It also
    takes the INN and the actual address from ``company``, and the short legal
    name and legal address from ``company.report``.

    If the insert conflicts on ``carrier_register_unq_index``, only
    ``actual_address``, ``legal_name`` and ``legal_address`` are overwritten
    with the incoming values. The statement returns ``Carrier.id``.
    """
    report = company.report
    row = {
        'register_type_id': CarrierType.COMPANY.value,
        'register_number': company.ogrn,
        'inn': company.inn,
        'actual_address': company.address,
        'legal_name': report.short_name,
        'legal_address': report.legal_address,
    }
    insert_stmt = postgresql.insert(Carrier).values(row).returning(Carrier.id)

    # Columns refreshed from the proposed row when the carrier already exists.
    refreshed_columns = ('actual_address', 'legal_name', 'legal_address')
    return insert_stmt.on_conflict_do_update(
        constraint='carrier_register_unq_index',
        set_={column: getattr(insert_stmt.excluded, column) for column in refreshed_columns},
    )


def _make_entrepreneur_update(entrepreneur):
    """Build an upsert of an individual entrepreneur into the ``Carrier`` table.

    The inserted row takes its register number from ``entrepreneur.ogrnip``,
    along with the entrepreneur's INN and full name (as ``legal_name``).

    If the insert conflicts on ``carrier_register_unq_index``, only
    ``legal_name`` is overwritten with the incoming value. The statement
    returns ``Carrier.id``.
    """
    row = {
        'register_type_id': CarrierType.ENTREPRENEUR.value,
        'register_number': entrepreneur.ogrnip,
        'inn': entrepreneur.inn,
        'legal_name': entrepreneur.full_name,
    }
    insert_stmt = postgresql.insert(Carrier).values(row).returning(Carrier.id)

    # Columns refreshed from the proposed row when the carrier already exists.
    refreshed_columns = ('legal_name',)
    return insert_stmt.on_conflict_do_update(
        constraint='carrier_register_unq_index',
        set_={column: getattr(insert_stmt.excluded, column) for column in refreshed_columns},
    )


# Dispatch table: carrier register type -> function that turns a source record
# into the matching ``Carrier`` upsert statement.
UPDATE_FACTORIES = {}
UPDATE_FACTORIES[CarrierType.COMPANY] = _make_company_update
UPDATE_FACTORIES[CarrierType.ENTREPRENEUR] = _make_entrepreneur_update


class CarrierMatchingError(Exception):
    """Raised when carrier matchings cannot be stored because of a database integrity conflict."""


def create_carrier_and_matching(carrier_type, carrier_code, spark_carrier, supplier_ids, db_session=None):
    """Upsert a carrier and create one ``CarrierMatching`` row per supplier.

    :param carrier_type: key of ``UPDATE_FACTORIES``; selects how
        ``spark_carrier`` is turned into a ``Carrier`` upsert.
    :param carrier_code: code stored on every created ``CarrierMatching``.
    :param spark_carrier: source record handed to the selected upsert factory.
    :param supplier_ids: iterable of supplier ids; one matching is created
        for each id, in iteration order.
    :param db_session: optional caller-owned session. When given, the work
        runs inside a SAVEPOINT. The savepoint is released on success and
        rolled back on any error; the outer transaction is left for the caller.
        When omitted, a new ``session_scope()`` is used (presumably it commits
        on clean exit; confirm in travel.rasp.bus.db).
    :returns: ``(carrier_id, matching_ids)``, where ``matching_ids`` is a tuple
        of the new ``CarrierMatching`` ids in ``supplier_ids`` order.
    :raises CarrierMatchingError: if flushing the matchings raises
        ``IntegrityError``; the original error is chained as the cause.
    :raises KeyError: if ``carrier_type`` has no entry in ``UPDATE_FACTORIES``.
    """
    # Building the statement does not touch the database, so do it first:
    # an unknown carrier_type then fails before any savepoint is opened.
    update_stmt = UPDATE_FACTORIES[carrier_type](spark_carrier)

    if db_session is not None:
        # Keep the savepoint handle and commit/roll back *it*. Calling
        # Session.commit()/rollback() would, under SQLAlchemy 2.0 semantics,
        # act on the caller's outer transaction instead.
        savepoint = db_session.begin_nested()
        session_cm = nullcontext(db_session)
    else:
        savepoint = None
        session_cm = session_scope()

    with session_cm as session:
        try:
            # ON CONFLICT ... DO UPDATE ... RETURNING yields exactly one row,
            # holding the id of the inserted or updated carrier.
            carrier_id = session.execute(update_stmt).scalar()

            carrier_matchings = tuple(
                CarrierMatching(supplier_id=supplier_id, code=carrier_code, carrier_id=carrier_id)
                for supplier_id in supplier_ids
            )
            session.add_all(carrier_matchings)

            try:
                session.flush()
            except IntegrityError as error:
                raise CarrierMatchingError('Carrier matching conflict') from error

            # Primary keys are assigned by the flush; read them before any commit.
            matching_ids = tuple(carrier_matching.id for carrier_matching in carrier_matchings)
        except Exception:
            # Undo only the work done here: the savepoint when nested,
            # otherwise the session this function opened itself.
            if savepoint is not None:
                savepoint.rollback()
            else:
                session.rollback()
            raise

        if savepoint is not None:
            savepoint.commit()
        return carrier_id, matching_ids
