import aioboto3
import logging
import openpyxl
from aiobotocore.config import AioConfig
from collections import defaultdict
from datetime import timezone
from io import BytesIO
from sqlalchemy import and_, or_

from crm.agency_cabinet.common.server.common.config import MdsConfig
from crm.agency_cabinet.common.server.common.structs import TaskStatuses
from crm.agency_cabinet.ord.common.consts import ContractType, ContractActionType, ContractSubjectType, OrganizationType
from crm.agency_cabinet.ord.server.src.celery.tasks.transfer.format import IsVatConverter
from crm.agency_cabinet.ord.server.src.db import db, models
from crm.agency_cabinet.ord.server.src.db.queries import build_get_client_rows_query

# Module-level logger used by the report importers defined below.
LOGGER = logging.getLogger(name='celery.import_reports.base')


class ImportReportException(Exception):
    """Base class for errors raised while importing a report."""


class ValidateReportHeadersException(ImportReportException):
    """Raised when a report's header rows do not match the importer's expected headers."""


class BaseReportImporter:
    """Import an Excel report stored in MDS (S3) into the ORD database.

    Subclasses provide ``HEADERS`` and implement ``_load_row`` / ``after_load``.
    This base class downloads the file, validates the two header rows, feeds every
    data row to ``_load_row`` inside one DB transaction, and offers cached
    update-or-create helpers for organizations, contracts and acts.
    """

    # Expected header layout. Keys are the row-1 (group) headers, in order; each value is a
    # sequence of column groups, each group being a sequence of row-2 headers for consecutive columns.
    HEADERS = {}

    def __init__(self, mds_cfg: MdsConfig, report_mds_config, task_id: int, report_id: int, partner_id: int,
                 mds_filename: str):
        self._task_id = task_id
        self._report_id = report_id
        self._partner_id = partner_id
        self._bucket = report_mds_config['bucket']
        self._mds_cfg = mds_cfg
        self._prefix = report_mds_config['prefix']
        self._mds_filename = mds_filename

        # Cumulative column boundaries: column group N spans columns indexes[N]..indexes[N + 1].
        self.indexes = [0]
        for group_headers in self.HEADERS.values():
            for headers in group_headers:
                self.indexes.append(self.indexes[-1] + len(headers))

    async def import_report(self):
        """Run the whole import and track its status in ``ReportImportInfo``.

        The task is marked in progress, the file is downloaded and loaded, and the task
        is marked ready. On any failure the error is logged, the task is marked as
        errored and the original exception is re-raised.
        """
        try:
            await self._update_task_status(TaskStatuses.in_progress.value)
            await self._load_mds_file()
        except Exception as ex:
            LOGGER.exception('Error during report import: %s', ex)
            await self._update_task_status(TaskStatuses.error.value)
            raise

        await self._update_task_status(TaskStatuses.ready.value)

    async def _load_mds_file(self):
        """Download ``self._mds_filename`` from ``self._bucket`` and pass its bytes to ``_read_file``."""
        boto3_session = aioboto3.Session(
            aws_access_key_id=self._mds_cfg['access_key_id'],
            aws_secret_access_key=self._mds_cfg['secret_access_key']
        )

        async with boto3_session.resource('s3', endpoint_url=self._mds_cfg['endpoint_url'],
                                          config=AioConfig(s3={'addressing_style': 'virtual'})) as s3:
            s3_object = await s3.Object(self._bucket, self._mds_filename)
            response = await s3_object.get()
            body = await response['Body'].read()

        # The body is fully read, so the S3 resource is released before the (possibly long)
        # DB transaction in _read_file starts.
        await self._read_file(body)

    async def _read_file(self, body):
        """Parse the workbook, validate its headers and load every data row in one transaction.

        Rows 1-2 are headers; data starts at row 3. ``after_load`` runs inside the same transaction.

        :raises ValidateReportHeadersException: if the header rows do not match ``HEADERS``.
        """
        workbook = openpyxl.load_workbook(filename=BytesIO(body))
        worksheet = workbook.active

        # Empty cells are dropped, so row 1 yields one value per header group
        # (the remaining cells of a group, presumably merged, are empty).
        row_headers = [
            [cell.value for cell in row if cell.value is not None]
            for row in worksheet.iter_rows(min_row=1, max_row=2)
        ]

        if not self._validate_headers(row_headers):
            raise ValidateReportHeadersException()

        async with db.transaction(read_only=False, reuse=False):
            for row in worksheet.iter_rows(min_row=3):
                await self._load_row([cell.value for cell in row])

            await self.after_load()

    def _validate_headers(self, row_headers):
        """Return True if the two header rows match ``HEADERS``.

        ``row_headers[0]`` must equal the ``HEADERS`` keys and ``row_headers[1]`` the flattened
        column headers of all groups, both in order. Fewer than two rows never match.
        """
        if len(row_headers) < 2:
            return False

        if row_headers[0] != list(self.HEADERS.keys()):
            return False

        expected_columns = [
            header
            for group_headers in self.HEADERS.values()
            for headers in group_headers
            for header in headers
        ]
        return row_headers[1] == expected_columns

    def _get_sliced_data(self, row):
        """Split a data row into one slice per column group, using the boundaries in ``self.indexes``."""
        return [row[start:end] for start, end in zip(self.indexes, self.indexes[1:])]

    async def _load_row(self, row):
        """Process one data row (list of cell values). Must be implemented by subclasses."""
        raise NotImplementedError

    async def after_load(self):
        """Hook run after all rows are loaded, inside the import transaction. Must be implemented by subclasses."""
        raise NotImplementedError

    async def init(self):
        """Preload the report's existing client rows into in-memory lookup caches.

        Organization caches are keyed by INN, falling back to the mobile phone; contract caches
        by contract_eid; ``partner_acts`` by act_eid. These are the same keys the
        ``_update_or_create_*`` helpers use. ``client_rows`` groups row references by
        ``(campaign_eid, ad_distributor_act_eid)``.
        """
        query = build_get_client_rows_query(self._report_id)
        rows = await query.gino.all()

        # TODO: move to another place
        self.new_client_rows = defaultdict(list)

        self.campaigns = {}
        self.clients = {}
        self.ad_distributor_contracts = {}
        self.ad_distributor_acts = {}
        self.ad_distributor_partner_orgs = {}
        self.partner_contracts = {}
        self.partner_acts = {}
        self.partner_client_orgs = {}
        self.advertiser_contracts = {}
        self.advertiser_orgs = {}
        self.advertiser_contractor_orgs = {}
        self.client_rows = defaultdict(list)

        # (row attribute prefix, target cache) for every organization role present in a row.
        org_caches = (
            ('ad_distributor_partner_org', self.ad_distributor_partner_orgs),
            ('partner_client_org', self.partner_client_orgs),
            ('advertiser_org', self.advertiser_orgs),
            ('advertiser_contractor_org', self.advertiser_contractor_orgs),
        )

        for row in rows:
            self.campaigns[row.campaign_eid] = {
                'id': row.campaign_id,
                'suggested_amount': row.suggested_amount
            }

            self.clients[row.client_eid] = {
                'id': row.client_id
            }

            self.ad_distributor_acts[(row.campaign_eid, row.ad_distributor_act_eid)] = {
                'id': row.ad_distributor_act_id
            }

            self.ad_distributor_contracts[row.ad_distributor_act_id] = {
                'id': row.ad_distributor_contract_id
            }

            for prefix, cache in org_caches:
                org = self._org_params_from_row(row, prefix)
                # Same key as _update_or_create_organization: INN, falling back to mobile phone.
                key = org['inn'] or org['mobile_phone']
                if key:
                    cache[key] = org

            if row.partner_contract_eid:
                self.partner_contracts[row.partner_contract_eid] = {
                    'id': row.partner_contract_id,
                    'client_id': row.partner_client_org_id,
                    'contractor_id': row.ad_distributor_partner_org_id,
                    'contract_eid': row.partner_contract_eid,
                    'type': row.partner_contract_type,
                    'action_type': row.partner_contract_action_type,
                    'subject_type': row.partner_contract_subject_type,
                    'date': row.partner_contract_date,
                    'amount': row.partner_contract_amount,
                    'is_vat': row.partner_contract_is_vat
                }

            if row.partner_act_eid:
                self.partner_acts[row.partner_act_eid] = {
                    'id': row.partner_act_id,
                    'act_eid': row.partner_act_eid,
                    'amount': row.partner_act_amount,
                    'is_vat': row.partner_act_is_vat
                }

            if row.advertiser_contract_eid:
                self.advertiser_contracts[row.advertiser_contract_eid] = {
                    'id': row.advertiser_contract_id,
                    'client_id': row.advertiser_org_id,
                    'contractor_id': row.advertiser_contractor_org_id,
                    'contract_eid': row.advertiser_contract_eid,
                    'type': row.advertiser_contract_type,
                    'action_type': row.advertiser_contract_action_type,
                    'subject_type': row.advertiser_contract_subject_type,
                    'date': row.advertiser_contract_date,
                    'amount': row.advertiser_contract_amount,
                    'is_vat': row.advertiser_contract_is_vat
                }

            self.client_rows[(row.campaign_eid, row.ad_distributor_act_eid)].append({
                'id': row.id,
                'ad_distributor_act_id': row.ad_distributor_act_id,
                'ad_distributor_contract_id': row.ad_distributor_contract_id,
                'partner_contract_id': row.partner_contract_id,
                'advertiser_contract_id': row.advertiser_contract_id,
                'partner_act_id': row.partner_act_id
            })

    @staticmethod
    def _org_params_from_row(row, prefix: str) -> dict:
        """Collect the cached organization fields stored in the ``{prefix}_*`` attributes of a query row."""
        return {
            field: getattr(row, f'{prefix}_{field}')
            for field in ('id', 'inn', 'name', 'mobile_phone', 'epay_number',
                          'reg_number', 'alter_inn', 'oksm_number', 'type')
        }

    async def _update_or_create_organization(self, row_data: list, orgs: dict):
        """Return the id of the organization described by ``row_data``, creating or updating it.

        :param row_data: 8 cell values in the order inn, type, name, mobile_phone, epay_number,
            reg_number, alter_inn, oksm_number. Falsy values become None, other values are
            stringified, and a type that is not an ``OrganizationType`` value becomes None.
        :param orgs: cache keyed by INN (or mobile phone when INN is empty) mapping to the
            organization's fields plus ``id``; updated in place.
        :returns: the organization id, or None when every cell is empty.
        """
        if all(value is None for value in row_data):
            return None

        inn, org_type, name, mobile_phone, epay_number, reg_number, alter_inn, oksm_number = \
            (str(value) if value else None for value in row_data)

        if org_type not in [i.value for i in OrganizationType]:
            org_type = None

        new_params = {
            'inn': inn,
            'name': name,
            'mobile_phone': mobile_phone,
            'epay_number': epay_number,
            'reg_number': reg_number,
            'alter_inn': alter_inn,
            'oksm_number': oksm_number,
            'type': org_type
        }

        key = inn or mobile_phone
        if key is None:
            # Nothing to identify the organization by: always create a new record and keep it out
            # of the cache, otherwise every such organization would collapse into one DB row.
            org = await models.Organization.create(partner_id=self._partner_id, **new_params)
            return org.id

        if key not in orgs:
            # Only compare the identifiers we actually have: `column == None` renders as
            # `IS NULL` and would match unrelated organizations.
            lookup = []
            if inn is not None:
                lookup.append(models.Organization.inn == inn)
            if mobile_phone is not None:
                lookup.append(models.Organization.mobile_phone == mobile_phone)

            org = await models.Organization.query.where(
                and_(
                    or_(*lookup),
                    models.Organization.partner_id == self._partner_id
                )
            ).gino.first()

            if not org:
                org = await models.Organization.create(partner_id=self._partner_id, **new_params)
            else:
                updated_params = {k: v for k, v in new_params.items() if getattr(org, k) != v}
                if updated_params:
                    await org.update(**updated_params).apply()

            new_params['id'] = org.id
            orgs[key] = new_params
        else:
            params = orgs[key]
            updated_params = {k: v for k, v in new_params.items() if params[k] != v}
            if updated_params:
                await models.Organization.update.values(**updated_params).where(
                    models.Organization.id == params['id']
                ).gino.status(read_only=False, reuse=False)
                new_params['id'] = params['id']
                orgs[key] = new_params

        return orgs[key]['id']

    async def _update_or_create_contract(self, client_id, contractor_id, row_data: list, contracts: dict):
        """Return the id of the contract described by ``row_data``, creating or updating it.

        :param client_id: organization id stored as the contract's client.
        :param contractor_id: organization id stored as the contract's contractor.
        :param row_data: 7 cell values in the order contract_eid, type, action_type, subject_type,
            date, amount, is_vat. Enum fields with unknown values become None; the date is tagged
            as UTC; is_vat is converted with ``IsVatConverter.from_str``.
        :param contracts: cache keyed by contract_eid mapping to the contract's fields plus ``id``;
            updated in place.
        :returns: the contract id, or None when every cell and both org ids are empty.
        """
        if all(value is None for value in row_data) and client_id is None and contractor_id is None:
            return None

        contract_eid, contract_type, action_type, subject_type, date, amount, is_vat = row_data
        contract_eid = str(contract_eid) if contract_eid else None

        if contract_type not in [i.value for i in ContractType]:
            contract_type = None

        if action_type not in [i.value for i in ContractActionType]:
            action_type = None

        if subject_type not in [i.value for i in ContractSubjectType]:
            subject_type = None

        date = date.replace(tzinfo=timezone.utc) if date else None

        new_params = {
            'client_id': client_id,
            'contractor_id': contractor_id,
            'contract_eid': contract_eid,
            'type': contract_type,
            'action_type': action_type,
            'subject_type': subject_type,
            'date': date,
            'amount': amount,
            'is_vat': IsVatConverter.from_str(is_vat)
        }

        if contract_eid is None:
            # No external id to match on: always create, and keep it out of the cache so that
            # unrelated contracts without an id are not merged into one record.
            contract = await models.Contract.create(partner_id=self._partner_id, **new_params)
            return contract.id

        if contract_eid not in contracts:
            contract = await models.Contract.query.where(
                and_(
                    models.Contract.contract_eid == contract_eid,
                    models.Contract.partner_id == self._partner_id
                )
            ).gino.first()

            if not contract:
                contract = await models.Contract.create(partner_id=self._partner_id, **new_params)
            else:
                updated_params = {k: v for k, v in new_params.items() if getattr(contract, k) != v}
                if updated_params:
                    await contract.update(**updated_params).apply()

            new_params['id'] = contract.id
            contracts[contract_eid] = new_params
        else:
            params = contracts[contract_eid]
            updated_params = {k: v for k, v in new_params.items() if params[k] != v}
            if updated_params:
                await models.Contract.update.values(**updated_params).where(
                    models.Contract.id == params['id']
                ).gino.status(read_only=False, reuse=False)
                new_params['id'] = params['id']
                contracts[contract_eid] = new_params

        return contracts[contract_eid]['id']

    async def _update_or_create_act(self, row_data: list, acts: dict):
        """Return the id of the act described by ``row_data`` within this report, creating or updating it.

        :param row_data: 3 cell values in the order act_eid, amount, is_vat.
        :param acts: cache keyed by act_eid mapping to the act's fields plus ``id``; updated in place.
        :returns: the act id, or None when every cell is empty.
        """
        if all(value is None for value in row_data):
            return None

        act_eid, amount, is_vat = row_data
        act_eid = str(act_eid) if act_eid else None

        new_params = {
            'act_eid': act_eid,
            'amount': amount,
            'is_vat': IsVatConverter.from_str(is_vat)
        }

        if act_eid is None:
            # No external id to match on: always create, and keep it out of the cache so that
            # unrelated acts without an id are not merged into one record.
            act = await models.Act.create(report_id=self._report_id, **new_params)
            return act.id

        if act_eid not in acts:
            act = await models.Act.query.where(
                and_(
                    models.Act.act_eid == act_eid,
                    models.Act.report_id == self._report_id
                )
            ).gino.first()

            if not act:
                act = await models.Act.create(report_id=self._report_id, **new_params)
            else:
                updated_params = {k: v for k, v in new_params.items() if getattr(act, k) != v}
                if updated_params:
                    await act.update(**updated_params).apply()

            new_params['id'] = act.id
            acts[act_eid] = new_params
        else:
            params = acts[act_eid]
            updated_params = {k: v for k, v in new_params.items() if params[k] != v}
            if updated_params:
                await models.Act.update.values(**updated_params).where(
                    models.Act.id == params['id']
                ).gino.status(read_only=False, reuse=False)
                new_params['id'] = params['id']
                acts[act_eid] = new_params

        return acts[act_eid]['id']

    async def _delete_client_rows(self, campaign_id, ad_distributor_act_id):
        """Delete every client row of the given campaign and ad distributor act."""
        await models.ClientRow.delete.where(
            and_(
                models.ClientRow.campaign_id == campaign_id,
                models.ClientRow.ad_distributor_act_id == ad_distributor_act_id
            )
        ).gino.status()

    async def _create_client_row(self,
                                 client_id,
                                 campaign_id,
                                 suggested_amount,
                                 ad_distributor_act_id,
                                 ad_distributor_contract_id,
                                 partner_contract_id,
                                 advertiser_contract_id,
                                 partner_act_id):
        """Insert a ``ClientRow`` linking a client and campaign to the given act and contract ids."""
        await models.ClientRow.create(
            client_id=client_id,
            campaign_id=campaign_id,
            suggested_amount=suggested_amount,
            ad_distributor_act_id=ad_distributor_act_id,
            ad_distributor_contract_id=ad_distributor_contract_id,
            partner_contract_id=partner_contract_id,
            advertiser_contract_id=advertiser_contract_id,
            partner_act_id=partner_act_id
        )

    async def _update_task_status(self, status: str):
        """Set the status of this import task's ``ReportImportInfo`` record."""
        await models.ReportImportInfo.update.values(
            status=status
        ).where(
            models.ReportImportInfo.id == self._task_id
        ).gino.status()
