import asyncio
import functools
import inspect
import logging
import re
import typing as tp

import attr
import uvloop
import yappi
from asyncio_pool import AioPool
from bs4 import Tag, BeautifulSoup as BS4
from json2html import json2html
from traitlets.config import Application, Config, get_config

from sandbox import sdk2
from sandbox.common.collections import AttrDict
from sandbox.sdk2.helpers import ProgressMeter
from sandbox.common.errors import TaskFailure

from jupytercloud.backend.lib.util.metrics.tornado_client import CleanArgsPatcher
from jupytercloud.backend.lib.clients.qyp import QYPClient
from jupytercloud.backend.lib.clients.salt import SaltClient
from jupytercloud.backend.lib.clients.salt.minion import SaltMinion
from jupytercloud.backend.lib.qyp.vm import QYPVirtualMachine
from jupytercloud.backend.lib.util.report import NullReport, Report, ReportBase
from jupytercloud.backend.lib.util.format import pretty_json

# Per-environment settings, looked up by the task's ``environment`` parameter in
# SaltJobLib.setup(): VM name prefixes are passed to QYPClient, ``oauth_id`` becomes
# JupyterCloudOAuth.client_id.
ENVIRONMENTS = {
    'testing': {
        'vm_name_prefix': 'testing-jupyter-cloud-',
        'vm_short_name_prefix': 'tjc-',
        'oauth_id': '82ede8f30a9347379538bc0877730928',
    },
    'production': {
        'vm_name_prefix': 'jupyter-cloud-',
        'vm_short_name_prefix': 'jc-',
        'oauth_id': '0bdc337f290748b982b1e8bc0c345cca',
    },
}

# CSS placed at the top of the HTML report built by SaltJobHTML (see _add_header).
TABLE_STYLE = '''<style>
    summary {cursor: pointer}
    tbody>tr:hover {background-color: #f5f5f5}
    table {width: 100%; border-collapse: collapse}
    table, td, th, li, ul {margin-top: 0; padding: 5px;
        vertical-align: top; list-style: circle inside}
    h2, h3, h4 {display: inline; font-weight: 400}
    h3 > b {font-size: 120%}
</style>'''

# Matches JupyterCloud VM hostnames, e.g. ``jupyter-cloud-<login>.<dc>.yp-c.yandex.net``
# or ``tjc-<login>.<dc>.yp-c.yandex.net``. Group ``testing`` is set when a testing
# prefix ("testing-" or "t") is present; ``login`` is the part after the prefix and
# ``dc`` the 3-8 character label before ``.yp-c.yandex.net``. SaltJobHTML uses it to
# shorten hostnames in the report.
VM = re.compile(
    r'(?P<testing>testing-|t)?'
    r'(?P<prefix>jupyter-cloud-|jc-)'
    r'(?P<login>[a-z0-9-]+)'
    r'\.(?P<dc>\w{3,8})'
    r'\.yp-c\.yandex\.net'
)

# VM proxy base URL per data center; assigned to QYPClient.clusters in SaltJobLib.setup().
CLUSTERS = {
    'sas': 'https://vmproxy.sas-swat.yandex-team.ru/',
    'vla': 'https://vmproxy.vla-swat.yandex-team.ru/',
    'iva': 'https://vmproxy.iva-swat.yandex-team.ru/',
    'myt': 'https://vmproxy.myt-swat.yandex-team.ru/',
}

# Total number of progress steps; incremented by ``global_progress`` each time it
# decorates a function (i.e. when the decorated class body is executed).
STEPS = 0
# Task-wide ProgressMeter; assigned in SaltJobLib.execute_async() and updated by
# the wrappers that ``global_progress`` creates.
GLOBAL_METER = None


def global_progress(title, repetitions=1):
    """Decorate a job step so it is tracked on the task-wide progress meter.

    Decorating a function adds ``repetitions`` to the module-level ``STEPS`` total.
    Each call of the decorated function sets the meter message to ``title``, logs
    the step number, runs the step and then advances ``GLOBAL_METER`` by
    ``repetitions``. Both plain and ``async`` functions are supported: for
    coroutine functions the meter is advanced only after the coroutine has
    finished, not when it is created. If ``GLOBAL_METER`` has not been set yet,
    the step still runs but no progress is recorded.

    Example::

        class SandboxJob:
            @global_progress('Step 1')
            def f(self):
                ...

            @global_progress('Step 2')
            async def g(self):
                ...

    Args:
        title: Human-readable step name shown on the meter and in the log.
        repetitions: How many meter units this step accounts for.
    """

    def decorator(func):
        global STEPS
        STEPS += repetitions

        def start():
            # Read the module global at call time: it is assigned only once the job runs.
            meter = GLOBAL_METER
            if meter is None:
                return
            meter.message = title
            logging.info('Sandbox Job step %d/%d: %s', meter.value + 1, STEPS, title)

        def finish():
            meter = GLOBAL_METER
            if meter is not None:
                meter.add(repetitions)

        if inspect.iscoroutinefunction(func):
            # A sync wrapper would only create the coroutine and advance the meter
            # before the step actually ran; await it so progress reflects completion.
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                start()
                value = await func(*args, **kwargs)
                finish()
                return value

            return async_wrapper

        @functools.wraps(func)
        def wrapper_decorator(*args, **kwargs):
            start()
            value = func(*args, **kwargs)
            finish()
            return value

        return wrapper_decorator

    return decorator


@attr.define()
class SaltJobLib:
    """Apply a Salt state to JupyterCloud minions as a Sandbox job.

    ``execute_async`` runs the steps in order: set up clients, optionally restart
    minions, refresh pillars of alive minions, apply the state and publish an HTML
    report. The task fails if the apply result lists any minion under ``failure``
    or ``bad-state``.

    Attribute lookups that fail on this object are delegated to ``job`` (the
    Sandbox task); this is how ``self.Parameters``, ``self.Context``,
    ``self.secrets``, ``self.log_path`` and ``self.set_info`` are resolved.
    """

    job: sdk2.Task = None
    app: Application = None
    c: Config = None
    log: logging.Logger = logging.getLogger('JCSalt')
    clients: AttrDict = attr.Factory(AttrDict)
    meters: AttrDict = attr.Factory(AttrDict)
    report: ReportBase = NullReport
    connect_pool: AioPool = None

    def __getattr__(self, item):
        # Never delegate 'job' itself: if that attribute is unset (e.g. an instance
        # created without __init__), looking it up here would recurse forever.
        if item == 'job':
            raise AttributeError(item)
        job = self.job
        if job is None:
            raise AttributeError(item)
        return getattr(job, item)

    def execute(self):
        """Run the whole job to completion on a fresh uvloop event loop."""
        # The loop policy must be installed before asyncio.run() creates the loop;
        # installing it from inside the already running loop does not affect this run.
        uvloop.install()
        asyncio.run(self.execute_async())

    async def execute_async(self):
        """Run all job steps under the task-wide progress meter.

        Raises:
            TaskFailure: if any minion ends up under ``failure`` or ``bad-state``.
        """
        global GLOBAL_METER
        with ProgressMeter('Total task progress', minval=0, maxval=STEPS) as total:
            GLOBAL_METER = total

            self.setup()
            await self.restart_minions()
            minions = await self.refresh_minions()
            await self.apply_state(minions)
            await self.print_report()

            result = self.Context.result
            if result.get('failure', []) or result.get('bad-state', []):
                raise TaskFailure('Some minions failed to apply state')

    @global_progress('Setting up the environment')
    def setup(self):
        """Configure logging, profiling, traitlets config, report and Salt/QYP clients."""
        CleanArgsPatcher.patch()

        if self.Parameters.profile:
            yappi.set_clock_type('wall')
            yappi.start()
            self.log.info('Started profiler')

        logging.getLogger('tornado.curl_httpclient').setLevel(logging.INFO)
        self.log.warning('Hiding DEBUG curl logs as they are spammy')

        c = get_config()
        env_settings = ENVIRONMENTS[self.Parameters.environment]

        c.QYPClient.vm_name_prefix = env_settings['vm_name_prefix']
        c.QYPClient.vm_short_name_prefix = env_settings['vm_short_name_prefix']
        c.QYPClient.oauth_token = self.secrets['qyp_oauth']
        c.QYPClient.clusters = CLUSTERS
        c.QYPClient.use_pycurl = False

        c.JupyterCloudOAuth.client_id = env_settings['oauth_id']

        c.SaltClient.urls = self.Parameters.salt_masters
        c.SaltClient.username = 'JupyterCloudSalt'
        c.SaltClient.password = self.secrets['salt_secret']
        c.SaltClient.eauth = 'sharedsecret'
        c.SaltClient.log_dir = self.log_path()
        c.SaltConnection.retry_stop_timeout = 600

        c.AsyncHTTPClientMixin.use_pycurl = False

        c.SSHClient.id_rsa = self.secrets['minion_ssh_key']
        self.c = c

        # setting log_level to DEBUG from the beginning dumps secrets to the log
        self.app = Application(config=self.c, log=self.log, log_level=logging.INFO)
        # NOTE(review): in traitlets launch_instance() is a classmethod acting on the
        # Application singleton, not necessarily on self.app — confirm this is intended.
        self.app.launch_instance()
        self.app.log_level = logging.DEBUG

        self.report = Report.instance(parent=self.app)
        self.clients.salt = SaltClient.instance(parent=self.app, config=c)
        self.clients.salt_minion = SaltMinion(parent=self.app, config=c)
        self.clients.qyp = QYPClient.instance(parent=self.app, config=c)

        self.connect_pool = AioPool(self.Parameters.concurrency)

        self.log.info('Setup successful')

    @global_progress('Restarting minions')
    async def restart_minions(self):
        """Restart Salt minions if ``Parameters.do_restart_minions`` is set.

        Without ``Parameters.users``: restart minions of running VMs whose minion is
        not currently alive. With ``Parameters.users``: restart the minions of all
        VMs whose login is in that list. Each outcome is recorded in the report.
        """
        if not self.Parameters.do_restart_minions:
            self.log.info('Not restarting minions')
            return

        users = self.Parameters.users

        alive_minions = await self.clients.salt_minion.get_all_alive(users=users or None)
        alive_ids = [m.minion_id for m in alive_minions]
        self.log.info('%d minions initially alive', len(alive_minions))
        self.report.debug('salt.alive', 'Minions initially alive', vm=alive_ids)

        vms_raw_info = await self.clients.qyp.get_vms_raw_info()
        vms = [QYPVirtualMachine.from_raw_info(self.clients.qyp, raw_vm) for raw_vm in vms_raw_info]
        self.log.info('%d VMs exist', len(vms))

        running = await self.connect_pool.map(lambda vm: vm.is_running(), vms)
        running_vms = [vm for vm, is_running in zip(vms, running) if is_running]
        self.log.info('%d VMs running', len(running_vms))

        if not users:
            alive_set = set(alive_ids)  # O(1) membership instead of scanning the list per VM
            minions_ids_to_restart = [vm.host for vm in running_vms if vm.host not in alive_set]
        else:
            minions_ids_to_restart = [vm.host for vm in vms if vm.login in users]

        self.log.info('Restarting %d minions', len(minions_ids_to_restart))
        self.report.debug('salt.restart', 'Restarting minions', vm=minions_ids_to_restart)

        minions = [SaltMinion(minion_id=id_, parent=self.app) for id_ in minions_ids_to_restart]
        restarted = await self.clients.salt.restart_minions(minions=minions)

        restarted_ids = []
        for minion, ok in zip(minions, restarted):
            if ok:
                restarted_ids.append(minion.minion_id)
                self.report.debug('salt.restart.success', 'Success minion restart', vm=minion.minion_id)
            else:
                self.report.error('salt.restart.failure', 'Failed minion restart', vm=minion.minion_id)

        # List only the minions that were actually restarted, matching the count.
        self.log.info(
            'Restarted %s minions: %s',
            len(restarted_ids),
            ", ".join(restarted_ids),
        )

    @global_progress('Getting fresh data about minions')
    async def refresh_minions(self) -> list[str]:
        """Refresh pillars on all alive minions and return their ids.

        An empty ``Parameters.users`` is passed as ``None``, the same way
        ``restart_minions`` does.
        """
        alive_minions = await self.clients.salt_minion.get_all_alive(
            users=self.Parameters.users or None
        )
        minions = [m.minion_id for m in alive_minions]
        await self.clients.salt.local(minions, 'saltutil.refresh_pillar', request_all=True)

        self.log.info('Refreshed %d minions', len(minions))
        return minions

    @global_progress('Applying state')
    async def apply_state(self, minions: list[str]):
        """Apply ``Parameters.state_name`` to ``minions``; store the result in ``Context.result``."""
        with ProgressMeter('Minions with applied states', minval=0, maxval=len(minions)) as total:
            progress_callback = lambda batch: total.add(len(batch))

            result = await self.clients.salt.apply_state(
                minions,
                state=self.Parameters.state_name,
                interval=min(self.Parameters.concurrency * 5, 60),  # speeding up small jobs
                max_poll_interval=600,  # don't be mean to master
                timeout=self.Parameters.state_apply_timeout,
                batch_size=self.Parameters.batchsize,
                concurrency=self.Parameters.concurrency,
                batch_callback=progress_callback,
            )

        self.job.Context.result = result

    @global_progress('Making report')
    async def print_report(self):
        """Render the HTML report into the task info and dump report/profiler files."""
        html_processor = SaltJobHTML(job=self)

        report = html_processor.preprocess_data()
        soup = html_processor.print_success_failure(report)

        self.set_info(str(soup), do_escape=False)
        self.dump_report()
        self.dump_profiler()

    def dump_profiler(self):
        """Write yappi function stats to ``yappi_profile.txt`` if the profiler is running."""
        if not yappi.is_running():
            return
        self.log.info('Saving profiler data')
        with open(self.log_path('yappi_profile.txt'), mode='w', encoding='utf-8') as f:
            yappi.get_func_stats().print_all(
                f,
                columns={
                    0: ('name', 120),
                    1: ('ncall', 12),
                    2: ('tsub', 8),
                    3: ('ttot', 8),
                    4: ('tavg', 8),
                },
            )

    def dump_report(self):
        """Write the full structured report as pretty JSON to ``report.log``."""
        self.log.info('Saving entire report')
        with open(self.log_path('report.log'), mode='w', encoding='utf-8') as f:
            f.write(pretty_json(self.report.dump()))


@attr.define()
class SaltJobHTML:
    """Render the structured report of a Salt job as HTML.

    Attribute lookups that fail on this object are delegated to ``job``
    (``SaltJobLib.print_report`` passes the ``SaltJobLib`` itself), which is where
    ``self.report``, ``self.log`` and ``self.Parameters`` come from.
    """

    job: sdk2.Task = None
    soup: BS4 = None

    def __getattr__(self, item):
        # Never delegate 'job' itself: if that attribute is unset (e.g. an instance
        # created without __init__), looking it up here would recurse forever.
        if item == 'job':
            raise AttributeError(item)
        job = self.job
        if job is None:
            raise AttributeError(item)
        return getattr(job, item)

    def preprocess_data(self):
        """Turn report records into rows for the HTML tables.

        Keeps records whose numeric ``level`` is >= 0, replaces the level with its
        name (``42`` -> ``SUCCESS``, unknown values -> ``UNKNOWN``) and shortens
        JupyterCloud hostnames in ``vm`` (e.g.
        ``jupyter-cloud-<login>.<dc>.yp-c.yandex.net`` -> ``jc.<login>.<dc>``).
        Rows are copies, so the records returned by ``self.report.dump()`` are left
        untouched (in case dump() hands out the report's live records).

        Raises:
            AssertionError: if ``vm`` is present but is not a string, list or SaltMinion.
        """
        levels = {
            i * 10: name
            for i, name in enumerate(['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'])
        } | {42: 'SUCCESS'}

        def replace(match):
            prefix = 'tjc' if match['testing'] is not None else 'jc'
            return '.'.join((prefix, match['login'], match['dc']))

        rows = [dict(record) for record in self.report.dump() if record['level'] >= 0]

        for i, row in enumerate(rows):
            if isinstance(row.get('vm'), SaltMinion):
                self.log.warning('SaltMinion at row:%d:%s! Fix that!', i, row)
                row['vm'] = row['vm'].minion_id

            if 'vm' not in row:
                pass
            elif isinstance(row['vm'], str):
                row['vm'] = VM.sub(replace, row['vm'])
            elif isinstance(row['vm'], list):
                # New list, so the report's own list object is not modified.
                row['vm'] = [VM.sub(replace, vm) for vm in row['vm']]
            else:
                raise AssertionError('Write only strings and lists to vm!')

            row['level'] = levels.get(row['level'], 'UNKNOWN')
        return rows

    def print_success_failure(self, report: list[dict[str, str]]):
        """Build the report: a header plus collapsible failure and success tables.

        Rows whose ``event`` contains ``salt.state.success`` form the success table
        (all but the first get ``'—'`` as message); ``ERROR`` rows form the failure
        table. INFO and remaining rows are only counted in the log.

        Returns:
            The root ``<p>`` tag of the generated soup.
        """
        self.soup = BS4('<p></p>')
        self._add_header()

        success, failure, info, debug = [], [], [], []
        for line in report:
            if 'salt.state.success' in line['event']:
                if success:
                    line['message'] = '—'
                success.append(line)
            elif line.get('level') == 'ERROR':
                failure.append(line)
            elif line.get('level') == 'INFO':
                info.append(line)
            else:
                debug.append(line)

        self.log.info(
            'List lengths:\n --SUCCESS: %s\n --FAILURE: %s\n --INFO: %s\n --DEBUG: %s',
            *(len(x) for x in (success, failure, info, debug)),
        )

        tables = []

        if failure:
            failure_summary = f'Failed to apply: {len(self._distinct_vms(failure))} minions'
            tables.append(self._print_table(failure, failure_summary, True))

        if success:
            success_summary = f'Successfully applied: {len(success)} minions'
            tables.append(self._print_table(success, success_summary, False))

        if tables:
            self.soup.p.extend(filter(None, tables))
        return self.soup.p

    @staticmethod
    def _distinct_vms(lines) -> set:
        """Collect distinct ``vm`` values of rows; list values are flattened, missing ones skipped."""
        vms = set()
        for line in lines:
            vm = line.get('vm')
            if isinstance(vm, list):
                vms.update(vm)
            elif vm is not None:
                vms.add(vm)
        return vms

    def _wrap_with_details(self, obj: Tag, summary: str, show: bool) -> Tag:
        """Return a ``<div>`` with an ``<h4>`` title and ``obj`` inside ``<details>`` (open if ``show``)."""
        div = self.soup.new_tag('div')

        title = self.soup.new_tag('h4')
        # html.parser keeps bare text at the top level, so plain-text and marked-up
        # summaries both work without relying on a parser that wraps text in <p>.
        # list() because moving nodes out of the temporary soup shrinks its contents.
        title.extend(list(BS4(summary, 'html.parser').contents))
        div.append(title)

        details = self.soup.new_tag('details')
        if show:
            details['open'] = '1'
        div.append(details)
        details.append(obj)

        div.append(self.soup.new_tag('br'))

        return div

    def _hide_lists(self, table: Tag) -> Tag:
        """Collapse ``<ul>`` lists inside ``table`` in place and return ``table``.

        One-element lists are unwrapped to their bare content. Longer lists are moved
        into a ``<details>`` element, a ``[N items] first line…`` span is inserted at
        the start of the enclosing element, and the first item is removed so it is
        not shown twice. Lists without an ``<li>`` are left as they are.
        """
        wrapped = 0
        for lst in table('ul'):
            first_item = lst.li
            if first_item is None:
                continue

            length = len(lst.contents)
            if length == 1:
                first_item.unwrap()
                lst.unwrap()
                continue

            first_line = next(first_item.stripped_strings, '')
            summary = BS4(f'<span><b>[{length} items]</b> {first_line}…</span>').span

            details = self.soup.new_tag('details')

            lst = lst.replace_with(details)
            details.append(lst)
            details.parent.insert(0, summary)

            first_item.decompose()  # don't double-print first line
            wrapped += 1

        self.log.info('Wrapped %s lists', wrapped)
        return table

    def _print_table(self, table_data: tp.Iterable, summary: str, show: bool) -> tp.Optional[Tag]:
        """Render ``table_data`` as a collapsible HTML table, or return None if nothing renders."""
        table = BS4(json2html.convert(table_data, escape=False)).table
        if table is None:
            return None
        # The format string has two placeholders: summary and the table preview.
        self.log.debug('Starting from summary+table: %s\n\n%s…', summary, str(table)[:1000])
        self._hide_lists(table)

        details = self._wrap_with_details(table, summary, show)
        self.log.debug('Compiled table: %s…', str(details)[:1000])
        return details

    def _add_header(self):
        """Append the report CSS and the ``Results of applying state …`` heading to the root ``<p>``."""
        styles = BS4(TABLE_STYLE).style
        header = BS4(
            f'<h2>Results of applying state <i>{self.Parameters.state_name}</i></h2><br/>'
        ).h2

        self.soup.p.extend([styles, header])
