import os
import pathlib
import sqlite3
from typing import Iterable, Dict

import click
from infra.dctl.src import consts
from prompt_toolkit import prompt
from prompt_toolkit import PromptSession
from prompt_toolkit.history import FileHistory
from prompt_toolkit.completion import CompleteEvent, Completion
from prompt_toolkit.completion import NestedCompleter
from prompt_toolkit.completion.word_completer import WordCompleter
from prompt_toolkit.document import Document

# Box ids that are never chosen as the ssh target inside a pod; they are
# filtered out of the pod's box list before the first remaining box is used.
BOXES_IGNORE = {'logbroker_tools_box', 'tvm_box'}
# Directory (relative to the user's home) holding the completion cache
# (temp_ssh.db), the prompt history (.history) and the ssh error log.
FOLDER = '.sshator'

# Banner echoed by main() on start-up, followed by a short usage hint.
motd = """╔═╗╔═╗╦ ╦╔═╗╔╦╗╔═╗╦═╗
╚═╗╚═╗╠═╣╠═╣ ║ ║ ║╠╦╝
╚═╝╚═╝╩ ╩╩ ╩ ╩ ╚═╝╩╚═

Use: <project> <stage>.<deploy_unit>"""


class SQLClient:
    """Cache of the completion tree in a local SQLite database.

    The tree ``{project: {stage: {deploy_unit: {cluster, ...}}}}`` is stored
    flattened in a single ``completion`` table, one row per
    ``(project, stage, deploy_unit, cluster)`` combination.
    """

    # SQL shared by create_table()/drop_table() and store_completion().
    _CREATE_SQL = '''CREATE TABLE IF NOT EXISTS completion(project text, stage text, deploy_unit text, cluster text)'''
    _DROP_SQL = '''DROP TABLE IF EXISTS completion'''
    _INSERT_SQL = '''INSERT INTO completion(project, stage, deploy_unit, cluster) VALUES(?, ?, ?, ?)'''

    def __init__(self, path: str):
        self.path = path
        self.con = sqlite3.connect(path)

    def close_connection(self):
        """Close the underlying SQLite connection."""
        self.con.close()

    def check_table_exists(self) -> bool:
        """Return True if the ``completion`` table exists in the database."""
        cur = self.con.cursor()
        try:
            cur.execute("""SELECT name FROM sqlite_master WHERE type='table' AND name='completion'""")
            return len(cur.fetchall()) == 1
        finally:
            cur.close()

    def create_table(self):
        """Create the ``completion`` table if it is missing, then commit."""
        cur = self.con.cursor()
        try:
            cur.execute(self._CREATE_SQL)
            self.con.commit()
        finally:
            cur.close()

    def drop_table(self):
        """Drop the ``completion`` table if it exists, then commit."""
        cur = self.con.cursor()
        try:
            cur.execute(self._DROP_SQL)
            self.con.commit()
        finally:
            cur.close()

    def store_completion(self, completion: Dict):
        """Replace the cached tree with ``completion``.

        Args:
            completion: nested mapping project -> stage -> deploy_unit ->
                iterable of cluster names. Deploy units with no clusters
                produce no row and therefore do not survive a round trip.

        The drop, re-create and inserts run in one transaction. If anything
        fails part-way, the previous cache is restored by a rollback. Otherwise
        an empty or partial table would be left behind and reused on the next
        start, because the table would still "exist".
        """
        # Flatten first so a malformed tree fails before the DB is touched.
        rows = [
            (project, stage, deploy_unit, cluster)
            for project, stages in completion.items()
            for stage, deploy_units in stages.items()
            for deploy_unit, clusters in deploy_units.items()
            for cluster in clusters
        ]

        cur = self.con.cursor()
        try:
            if not self.con.in_transaction:
                # sqlite3 does not implicitly open a transaction before DDL, so
                # begin one explicitly to make the whole replacement atomic.
                cur.execute('BEGIN')
            cur.execute(self._DROP_SQL)
            cur.execute(self._CREATE_SQL)
            cur.executemany(self._INSERT_SQL, rows)
            self.con.commit()
        except BaseException:
            self.con.rollback()
            raise
        finally:
            cur.close()

    def get_completion(self):
        """Rebuild ``{project: {stage: {deploy_unit: set(clusters)}}}`` from the table."""
        cur = self.con.cursor()
        try:
            cur.execute('''SELECT project, stage, deploy_unit, cluster FROM completion''')
            rows = cur.fetchall()
        finally:
            cur.close()

        tree = {}
        for project, stage, deploy_unit, cluster in rows:
            tree.setdefault(project, {}).setdefault(stage, {}).setdefault(deploy_unit, set()).add(cluster)
        return tree


class ConsequentCompleter(NestedCompleter):
    """Nested completer for ``<project> <stage>.<deploy_unit> <cluster>`` input.

    Behaves like ``NestedCompleter``, except that the separator between terms
    depends on the nesting depth. At depth 2 (stage -> deploy unit) it is
    ``.``; at every other depth it is a space.
    """

    def get_completions(self, document: Document, complete_event: CompleteEvent, word: int = 1) -> Iterable[Completion]:
        """Yield completions for the term under the cursor.

        ``word`` is the 1-based depth of this completer in the nested tree. It
        is incremented when the rest of the input is handed to the child
        completer selected by the first term.
        """
        before_cursor = document.text_before_cursor
        text = before_cursor.lstrip()
        leading_ws = len(before_cursor) - len(text)
        sep = '.' if word == 2 else ' '

        if sep not in text:
            # Still typing this level's term: offer this level's keys.
            keys = WordCompleter(list(self.options.keys()), ignore_case=self.ignore_case, WORD=True)
            yield from keys.get_completions(document, complete_event)
            return

        head = text.partition(sep)[0]
        child = self.options.get(head)
        if child is None:
            return

        # Re-anchor the remainder so the child sees only its own part of the input.
        tail = text[len(head) :].lstrip(sep)
        shift = len(text) - len(tail) + leading_ws
        sub_document = Document(tail, cursor_position=document.cursor_position - shift)
        yield from child.get_completions(sub_document, complete_event, word + 1)


class DctlClient:
    """Lazily wires up the dctl/YP objects used to look up stages and pods.

    The dctl context, the current user and the ``yp.data_model`` module are
    created or imported on first use and cached on the instance.
    """

    def __init__(self, dctl_token_path):
        self._ctx = None
        self._user = None
        self._data_model = None
        self.dctl_token_path = dctl_token_path

    def get_ctx(self):
        """Return the cached ``DctlContext``, building it on the first call."""
        if self._ctx is None:
            from infra.dctl.src.cmd.dctl_cli import DctlContext

            self._ctx = DctlContext(
                dctl_token_path=self.dctl_token_path,
                yp_urls={c.name.lower(): c.address for c in consts.CLUSTER_CONFIGS.values()},
                vault_host=consts.VAULT_HOST,
            )
        return self._ctx

    def get_user(self):
        """Return the cached result of ``cliutil.get_user()``."""
        if self._user is None:
            from infra.dctl.src.cmd.dctl_cli import cliutil

            self._user = cliutil.get_user()
        return self._user

    def get_data_model(self):
        """Return the ``yp.data_model`` module, importing it on the first call."""
        if self._data_model is None:
            import yp.data_model as data_model

            self._data_model = data_model
        return self._data_model

    @classmethod
    def get_clusters(cls, stage, deploy_unit):
        """Return the set of cluster names that ``deploy_unit`` of ``stage`` uses.

        Clusters are collected both from ``multi_cluster_replica_set`` and from
        the keys of ``replica_set.per_cluster_settings`` (non-multi-cluster
        deploy units).
        """
        du_spec = stage.spec.deploy_units[deploy_unit]
        clusters = {x.cluster for x in du_spec.multi_cluster_replica_set.replica_set.clusters}
        clusters.update(du_spec.replica_set.per_cluster_settings.keys())
        return clusters

    def get_available_pod(self, project_name, stage_name, deploy_unit, completion_dict, cluster_name=None):
        """Return an ``root@<box>.<pod>.<cluster>.yp-c.yandex.net`` ssh target.

        Clusters come from ``completion_dict[project][stage][deploy_unit]``
        unless ``cluster_name`` is given. Each cluster is queried for up to 100
        pods of the ``<stage>.<deploy_unit>`` pod set. The first pod whose agent
        state is 200 and that has at least one box not in ``BOXES_IGNORE`` is
        used.

        Returns:
            The ssh target string, or None when no such pod was found.

        Raises:
            KeyError: the project/stage/deploy unit is not in ``completion_dict``.
        """
        if cluster_name:
            cluster_names = [cluster_name]
        else:
            cluster_names = completion_dict[project_name][stage_name][deploy_unit]

        query = f'[/meta/pod_set_id] = "{stage_name}.{deploy_unit}"'

        for cluster in cluster_names:
            client = self.get_ctx().get_client(cluster)
            pods = client.list(self.get_data_model().OT_POD, user=self.get_user(), query=query, limit=100)

            for pod in pods:
                # NOTE(review): 200 presumably is the numeric value of the
                # "started" agent state (formerly checked by name) - confirm.
                if pod.status.agent.state != 200:
                    continue

                boxes = [b.id for b in pod.status.agent.pod_agent_payload.status.boxes if b.id not in BOXES_IGNORE]
                if not boxes:
                    # Only ignored boxes (or none at all): nothing to ssh into here.
                    continue

                return f'root@{boxes[0]}.{pod.meta.id}.{cluster}.yp-c.yandex.net'

        return None

    def get_completion(self):
        """Return ``{project_id: {stage_id: {deploy_unit: set(clusters)}}}``.

        Stages are listed on ``consts.XDC_PRODUCTION_CLUSTER``.
        """
        client = self.get_ctx().get_client(consts.XDC_PRODUCTION_CLUSTER)
        # TODO: limit=1000 presumably truncates the stage list silently - confirm / paginate.
        available_stages = client.list(self.get_data_model().OT_STAGE, user=self.get_user(), limit=1000)

        # Single pass, so this also works if client.list returns a one-shot iterator.
        ans = {}
        for stage in available_stages:
            deploy_units = {key: self.get_clusters(stage, key) for key in stage.spec.deploy_units.keys()}
            ans.setdefault(stage.meta.project_id, {})[stage.meta.id] = deploy_units
        return ans

    def get_projects(self):
        """Return the set of project ids owning stages on the xdc cluster."""
        client = self.get_ctx().get_client(consts.XDC_PRODUCTION_CLUSTER)
        # NOTE(review): uses limit=100 while get_completion uses 1000 - confirm intended.
        available_stages = client.list(self.get_data_model().OT_STAGE, user=self.get_user(), limit=100)
        return {el.meta.project_id for el in available_stages}


def get_target(session, dctl, completion_dict):
    """Prompt until the user names a deploy unit with a reachable pod.

    Expected input: ``<project> <stage>.<deploy_unit> [cluster]``. Invalid
    input, unknown targets and targets without a usable pod print an error
    and re-prompt.

    Returns:
        ``(project, stage, deploy_unit, ssh_target)`` where ``ssh_target`` is the
        string returned by ``dctl.get_available_pod``.
    """
    while True:
        parts = session.prompt('> ').split()
        if len(parts) == 3:
            project, st_du, cluster = parts
        elif len(parts) == 2:
            project, st_du = parts
            cluster = None
        else:
            click.echo(
                click.style(
                    'ERROR: incorrect input, "project stage.deploy_unit cluster(optional)" expected', fg='red'
                )
            )
            continue

        try:
            st, du = st_du.split('.')
        except ValueError:
            click.echo(click.style('ERROR: incorrect stage and deploy unit combination', fg='red'))
            continue

        try:
            ssh_root = dctl.get_available_pod(project, st, du, completion_dict, cluster)
        except KeyError:
            click.echo(click.style('ERROR: incorrect pod configuration was given', fg='red'))
            continue

        if not ssh_root:
            # Previously fell through and produced `ssh None`.
            click.echo(click.style('ERROR: no running pod with an available box was found', fg='red'))
            continue

        return project, st, du, ssh_root


def _confirm_host_key_removal():
    """Ask whether the stale known_hosts entry should be removed.

    Returns True for "yes" or an empty answer (the default) and False for "no".
    Any other answer repeats the question instead of silently cancelling.
    """
    while True:
        click.echo(click.style('WARNING: host key changed, remove it automatically ([yes]/no)?', fg='red'))
        decision = prompt('> ').strip().lower()
        if decision in {'yes', ''}:
            return True
        if decision == 'no':
            return False


@click.group(invoke_without_command=True)
@click.option('--update', required=False, is_flag=True, help='update list of available pods')
def main(update):
    """Pick ``<project> <stage>.<deploy_unit> [cluster]`` interactively and ssh into it.

    The completion tree is cached in ``~/<FOLDER>/temp_ssh.db``. It is fetched
    from dctl when the cache table is missing or ``--update`` is passed. If ssh
    reports a changed host key, the user may drop the known_hosts entry and
    reconnect.
    """
    import shlex

    click.echo(click.style(motd, fg='cyan'))
    folder = f'{pathlib.Path.home()}/{FOLDER}'
    pathlib.Path(folder).mkdir(parents=True, exist_ok=True)

    sql = SQLClient(path=f'{folder}/temp_ssh.db')
    history = FileHistory(f'{folder}/.history')

    dctl = DctlClient(dctl_token_path=f'{pathlib.Path.home()}/.dctl/token')

    try:
        if sql.check_table_exists() and not update:
            completion_dict = sql.get_completion()
        else:
            completion_dict = dctl.get_completion()
            sql.store_completion(completion_dict)
            click.echo(click.style('list of available pods updated', fg='green'))
    finally:
        # The cache is only needed at start-up; don't keep it open during ssh.
        sql.close_connection()

    completer = ConsequentCompleter.from_nested_dict(completion_dict)

    session = PromptSession(history=history, completer=completer)
    project, st, du, ssh_root = get_target(session, dctl, completion_dict)

    if not ssh_root:
        click.echo(click.style('ERROR: no running pod with an available box was found', fg='red'))
        return

    if 'production' in st:
        click.echo(click.style('!' * 80, fg='red'))
        click.echo(click.style('WARNING: this is production stage', fg='red'))
        click.echo(click.style('!' * 80, fg='red'))

    error_path = pathlib.Path(folder) / 'errors.tmp'
    known_hosts = pathlib.Path.home() / '.ssh' / 'known_hosts'
    host = ssh_root.split('@', 1)[-1]  # drop the "root@" user part for ssh-keygen -R

    while True:
        # ssh -E appends to its log file, so clear errors from previous attempts.
        error_path.write_text('')
        os.system(f'ssh {shlex.quote(ssh_root)} -E {shlex.quote(str(error_path))}')

        # NOTE(review): confirm this (case-sensitive) phrase matches what the
        # installed ssh actually logs on a host key mismatch.
        if 'host key have changed' not in error_path.read_text(errors='replace'):
            return

        if not _confirm_host_key_removal():
            click.echo(click.style('cancelled!', fg='green'))
            return

        os.system(f'ssh-keygen -f {shlex.quote(str(known_hosts))} -R {shlex.quote(host)}')
        click.echo(click.style('reconnecting!', fg='green'))
