import asyncio
import aiohttp
import os
import psycopg2
import ssl
from collections import defaultdict
from ylog.context import log_context

from aiopg.sa import create_engine

from intranet.vconf.src.cms_sync.settings import CMS_URL_TEMPLATE, DB_CONFIG
from intranet.vconf.src.cms_sync.log import logger
from intranet.vconf.src.cms_sync.events import CallsEvent, CallRosterEvent
from intranet.vconf.src.cms_sync.db import Gateway


# TLS context for CMS API requests (used by get_token): client-side context
# that trusts the CA certificates from the bundle at _CA_BUNDLE.
_CA_BUNDLE = '/etc/allCAs.pem'
ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_verify_locations(cafile=_CA_BUNDLE)


async def get_token(base_url, session):
    """Request a CMS API auth token for the node at ``base_url``.

    Authenticates with the ``CMS_AUTH_HASH`` environment variable as a Basic
    credential (an empty string when the variable is unset) and verifies TLS
    with the module-level ``ssl_context``.

    :param base_url: node base URL (``Handler.main`` builds it from
        ``CMS_URL_TEMPLATE``).
    :param session: aiohttp client session used to send the request.
    :returns: the value of the ``X-Cisco-CMS-Auth-Token`` response header.
    :raises aiohttp.ClientResponseError: when the node answers with an error
        status (e.g. rejected credentials), instead of a bare ``KeyError`` on
        the missing token header.
    """
    headers = {
        'authorization': 'Basic ' + os.getenv('CMS_AUTH_HASH', ''),
        'content-type': 'application/x-www-form-urlencoded',
    }
    params = {
        'method': 'POST',
        'url': base_url + '/api/v1/authTokens',
        'headers': headers,
        # NOTE(review): aiohttp 3.x deprecates ``ssl_context`` in favour of
        # ``ssl``; switch once the minimum supported aiohttp version is confirmed.
        'ssl_context': ssl_context,
    }
    async with session.request(**params) as response:
        # Surface the HTTP status (401/403/5xx) rather than failing later with
        # a KeyError because the token header is absent on error responses.
        response.raise_for_status()
        return response.headers['X-Cisco-CMS-Auth-Token']


class Handler:
    """Mirror the call state of one CMS node into the database.

    The handler opens the node's events websocket and subscribes to the call
    list plus the roster of every call it tracks. It acknowledges each
    incoming message and forwards call/participant changes to the DB gateway.
    """

    def __init__(self, session, db_engine, node):
        """
        :param session: aiohttp client session for the token request and the
            events websocket.
        :param db_engine: aiopg.sa engine; ``main`` holds one connection from
            it for the lifetime of the websocket.
        :param node: CMS node identifier, substituted into ``CMS_URL_TEMPLATE``.
        """
        self.session = session
        self.db_engine = db_engine
        self.db_gateway = None  # set in main() once a DB connection is acquired
        self.node = node
        self.ws = None  # set in main() once the websocket is open
        self.call_ids = []  # tracked call ids, in roster-subscription order
        self.message_id = 1  # messageId of the next outgoing subscribeRequest

    def build_events(self):
        """Build the subscription list: the call list plus one roster per tracked call.

        The call-list subscription gets index 1; rosters follow from index 2
        in the order of ``self.call_ids``.
        """
        events = [CallsEvent(index=1).as_dict()]
        for index, call_id in enumerate(self.call_ids, start=2):
            events.append(CallRosterEvent(index=index, call_id=call_id).as_dict())
        return events

    async def subscribe(self):
        """Send a subscribeRequest for the current subscriptions and bump ``message_id``."""
        data = {
            'type': 'message',
            'message': {
                'messageId': self.message_id,
                'type': 'subscribeRequest',
                'subscriptions': self.build_events(),
            }
        }
        self.message_id += 1
        await self.ws.send_json(data)

    async def send_ack(self, message_id):
        """Acknowledge the incoming message ``message_id`` with status 'success'."""
        data = {
            'type': 'messageAck',
            'messageAck': {
                'messageId': message_id,
                'status': 'success',
            }
        }
        await self.ws.send_json(data)

    def parse_response(self, data):
        """Extract the inner message from a websocket frame.

        :param data: aiohttp websocket message.
        :returns: ``body['message']`` for TEXT frames whose JSON body has
            type 'message', otherwise ``None``. Subscription states are
            logged for 'subscriptionUpdate' messages.
        """
        if data.type is not aiohttp.WSMsgType.TEXT:
            return

        body = data.json()

        if body['type'] == 'message':
            if body['message']['type'] == 'subscriptionUpdate':
                counts = defaultdict(int)
                for subscription in body['message']['subscriptions']:
                    counts[subscription['state']] += 1
                with log_context(cms_node=self.node):
                    logger.info(
                        '%d pending and %d active subscriptions',
                        counts['pending'],
                        counts['active'],
                    )
            return body['message']

    async def process_message(self, message):
        """Acknowledge one websocket frame and dispatch it by message type."""
        if message.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
            # Per the aiohttp docs, CLOSED frames carry no payload and ERROR
            # frames carry the exception in ``data``. Neither is JSON, so
            # message.json() would raise here and mask the real cause.
            with log_context(cms_node=self.node):
                logger.error('Connection was unexpectedly closed: %s', message.data)
            return

        data = self.parse_response(message)
        if not data:
            return

        await self.send_ack(data['messageId'])

        if data['type'] == 'callListUpdate':
            await self._process_call_list_update(data)

        elif data['type'] == 'rosterUpdate':
            await self._process_roster_update(data)

    def _update_call_ids(self, update_type, call_id):
        """Apply one callListUpdate entry to ``self.call_ids`` without raising.

        A repeated 'add' is ignored, so no roster is subscribed twice. A
        'remove' for an untracked call is ignored, because list.remove would
        raise ValueError. Other update types leave the list unchanged.
        """
        if update_type == 'add' and call_id not in self.call_ids:
            self.call_ids.append(call_id)
        elif update_type == 'remove' and call_id in self.call_ids:
            self.call_ids.remove(call_id)

    async def _process_call_list_update(self, data):
        """Track added/removed calls, end removed calls in the DB, then resubscribe."""
        for update in data['updates']:
            with log_context(cms_node=self.node):
                logger.info('callListUpdate: %s', update)

            call_id = update['call']
            self._update_call_ids(update['updateType'], call_id)
            if update['updateType'] == 'remove':
                # End the call in the DB even when it was never tracked here.
                await self.db_gateway.end_call(call_id)

        # Refresh subscriptions so rosters match the updated call list.
        await self.subscribe()

    async def _process_roster_update(self, data):
        """Forward participant 'remove'/'update' entries to the DB gateway."""
        for update in data['updates']:
            with log_context(cms_node=self.node):
                logger.info('rosterUpdate: %s', update)

            participant_id = update['participant']
            if update['updateType'] == 'remove':
                await self.db_gateway.disconnect_participant(participant_id)
            elif update['updateType'] == 'update':
                await self.db_gateway.update_participant(participant_id, update)

    async def main(self):
        """Listen to this node's event stream until the websocket stops yielding messages.

        Obtains an auth token, opens the events websocket and subscribes. It
        then processes incoming messages while holding one DB connection. A
        background task periodically logs DB-active calls missing from
        ``self.call_ids``; it is cancelled before that connection is released.

        :raises psycopg2.Error: re-raised from message processing. Any other
            per-message exception is logged and listening continues.
        """
        base_url = CMS_URL_TEMPLATE.format(node=self.node)
        token = await get_token(base_url, self.session)
        url = base_url + '/events/v1?authToken=' + token

        # NOTE(review): unlike get_token, no ssl_context is passed here, so the
        # session's default certificate verification applies — confirm intended.
        async with self.session.ws_connect(url) as ws:
            self.ws = ws
            await self.subscribe()
            async with self.db_engine.acquire() as conn:
                self.db_gateway = Gateway(conn)
                # NOTE(review): this task and message processing share ``conn``;
                # if their Gateway queries can overlap, aiopg may reject the
                # second one — consider serialising DB access.
                sync_monitor = asyncio.get_event_loop().create_task(
                    self.monitor_not_synced_calls()
                )
                try:
                    with log_context(cms_node=self.node):
                        logger.info('message listening started on node %s', self.node)
                    async for message in self.ws:
                        try:
                            await self.process_message(message)
                        except psycopg2.Error:
                            with log_context(cms_node=self.node, message=message):
                                logger.exception('DB error during process message.')
                            raise
                        except Exception:
                            with log_context(cms_node=self.node, message=message):
                                logger.exception('Unknown error during process message.')
                finally:
                    # Stop the monitor before ``conn`` goes back to the pool.
                    # Otherwise it keeps querying a released connection and its
                    # failures go unobserved.
                    sync_monitor.cancel()

    async def monitor(self):
        """Log the number of tracked calls every 60 seconds, forever."""
        while True:
            with log_context(cms_node=self.node):
                logger.info('Active calls: %d', len(self.call_ids))
            await asyncio.sleep(60)

    async def monitor_not_synced_calls(self):
        """Every 60 seconds, log DB-active calls for this node that are not tracked here."""
        while True:
            active_calls = await self.db_gateway.get_active_calls(node=self.node)
            with log_context(cms_node=self.node):
                logger.info(
                    'not synced calls: %s',
                    self._not_synced_call_ids(active_calls, self.call_ids),
                )
            await asyncio.sleep(60)

    @staticmethod
    def _not_synced_call_ids(active_calls, known_call_ids):
        """Return the first column of each row in ``active_calls`` not in ``known_call_ids``.

        Row order is preserved.
        """
        known = set(known_call_ids)  # build once instead of once per row
        return [call[0] for call in active_calls if call[0] not in known]


async def get_node_ids():
    """Return the ids of all enabled rows in the ``call_node`` table.

    Opens a dedicated engine with ``DB_CONFIG`` for this single query and
    disposes of it (via ``async with``) before returning.
    """
    query = 'SELECT id from call_node where enabled IS TRUE'
    node_ids = []
    async with create_engine(**DB_CONFIG) as engine:
        async with engine.acquire() as connection:
            rows = connection.execute(query)
            async for node_row in rows:
                node_ids.append(node_row[0])
    return node_ids


async def _run_handlers(nodes):
    """Run a Handler (event listener + call-count monitor) for each node.

    This coroutine owns the HTTP session and the DB engine, so both are
    closed on exit. If any handler task fails, the remaining tasks are
    cancelled and drained first. Their DB connections are then released
    before the engine closes.
    """
    async with create_engine(maxsize=len(nodes), **DB_CONFIG) as db_engine:
        # Create the session inside a coroutine (aiohttp warns otherwise) and
        # let ``async with`` close it, which the original never did.
        async with aiohttp.ClientSession() as session:
            tasks = []
            for node in nodes:
                handler = Handler(session=session, db_engine=db_engine, node=node)
                tasks.append(asyncio.ensure_future(handler.main()))
                tasks.append(asyncio.ensure_future(handler.monitor()))
            try:
                await asyncio.gather(*tasks)
            finally:
                for task in tasks:
                    task.cancel()
                # Let cancelled tasks unwind so they release their connections.
                await asyncio.gather(*tasks, return_exceptions=True)


def main():
    """Entry point: sync every enabled CMS node until a handler fails."""
    logger.info('CMS sync started')
    loop = asyncio.get_event_loop()
    try:
        nodes = loop.run_until_complete(get_node_ids())
        if not nodes:
            # Nothing to sync; don't open a session or a connection pool.
            logger.warning('No enabled CMS nodes found, nothing to sync')
            return
        loop.run_until_complete(_run_handlers(nodes))
    finally:
        loop.close()


# Start the sync service only when this module is executed as a script.
if __name__ == '__main__':
    main()
