# -*- coding: utf-8 -*-
import os
import datetime
import json
import itertools
import time
import logging

from sandbox import sdk2
from sandbox.common.errors import TaskFailure, TaskError
from sandbox.common.rest import Client
from sandbox.common.config import Registry

from sandbox.sandboxsdk.task import SandboxTask
import sandbox.sandboxsdk.paths as sdk_paths


from sandbox.projects.resource_types import LM_DUMPS_LIST

from sandbox.projects.common.yabs.server.parameters import (
    BinDbList,
    FilterTable,
    BackupDate,
    Option,
    OptionsParameter,
    PrevArchiveContents,
    BsReleaseYT,
    AdditionalRestoreDescription,
    AdditionalRestoreDescriptionMD5,
)

from sandbox.projects.common.utils import get_or_default, check_if_tasks_are_ok
from sandbox.projects.common.yabs.server.db import yt_bases
import sandbox.projects.common.yabs.server.db.utils as dbutils

from sandbox.projects.YabsServerB2BPrepareSQL import __Task__ as YabsServerB2BPrepareSQL

from sandbox.projects.yabs.bases.table_backups import get_all_table_sources

from sandbox.projects.yabs.qa.resource_types import (
    YABS_MYSQL_ARCHIVE_CONTENTS,
    YABS_MYSQL_RESTORE_DESCRIPTION,
)

# Directory that contains this module file.
TASK_PATH = os.path.dirname(__file__)


class TestenvSwitchTriggerValue(sdk2.parameters.String):
    # Optional override for the 'testenv_switch_trigger' attribute of the
    # output resource; when left empty the backup date is used instead.
    name = 'testenv_switch_trigger_value'
    default_value = ''
    description = 'Use following string as testenv switch trigger value (backup date will be used if empty)'


class WaitForBackups(sdk2.parameters.Bool):
    # When enabled, missing table backups make the task sleep and retry
    # instead of failing right away.
    name = 'wait_for_backups'
    default_value = False
    description = 'Wait for table backups for specified date to appear'


class WaitForBackupsPeriod(sdk2.parameters.Integer):
    # Pause between two consecutive checks for backups, in seconds.
    name = 'wait_for_backups_period'
    default_value = 15 * 60  # 15 minutes
    description = 'Sleep between retries in seconds'


class WaitForBackupsTimeout(sdk2.parameters.Integer):
    # Overall time budget for waiting, counted from the first failed attempt,
    # in seconds.
    name = 'wait_for_backups_timeout'
    default_value = 4 * 60 * 60  # 4 hours
    description = "Fail task if backups aren't present after following period in seconds"


# Context key under which the id of the resulting archive-contents resource
# is stored.
OUT_RES_ID_KEY = 'out_res_id'

# Context keys used by the task for its own bookkeeping: ids of the spawned
# subtasks, id of the table-links resource and id of the table-sources resource.
_CHILD_TASKS_KEY = 'child_tasks_ids'
_LINKS_RES_KEY = 'links_res_id'
_SOURCES_RES_KEY = 'sources_res_id'


class YabsServerGetSQLArchive(SandboxTask):
    """Build a joint MySQL archive-contents resource for a backup date.

    The task works out which tables are needed, splits them into chunks,
    spawns a YabsServerB2BPrepareSQL subtask per chunk and finally merges the
    YABS_MYSQL_ARCHIVE_CONTENTS resources produced by the subtasks into a
    single resource of the same type.
    """

    # Disk space requested for this task (presumably in MiB -- TODO confirm).
    execution_space = 20 * 1024

    # Task type identifier.
    type = 'YABS_SERVER_GET_SQL_ARCHIVE'

    # Parameters accepted by the task.
    input_parameters = (
        PrevArchiveContents,
        AdditionalRestoreDescriptionMD5,
        AdditionalRestoreDescription,
        BsReleaseYT,
        BinDbList,
        FilterTable,
        BackupDate,
        OptionsParameter,
        TestenvSwitchTriggerValue,
        WaitForBackups,
        WaitForBackupsPeriod,
        WaitForBackupsTimeout,
    )

    def has_option(self, opt):
        """Tell whether ``opt`` is listed in the task options.

        Options are read from the context entry ``OptionsParameter.name`` as a
        whitespace-separated string; a missing entry counts as empty.
        """
        options_line = self.ctx.get(OptionsParameter.name, '')
        enabled_options = options_line.split()
        return opt in enabled_options

    def on_execute(self):
        """Spawn subtasks (once), wait until they finish and merge their output."""
        backup_date = self._get_backup_date()

        # Subtask ids are remembered in the context, so subtasks are spawned
        # only when the context has no record of them yet.
        if _CHILD_TASKS_KEY in self.ctx:
            subtask_ids = self.ctx[_CHILD_TASKS_KEY]
        else:
            subtask_ids = self._start(backup_date)
            self.ctx[_CHILD_TASKS_KEY] = subtask_ids

        rest_client = Client()
        status_groups = self.Status.Group
        unfinished_statuses = status_groups.EXECUTE + status_groups.QUEUE + status_groups.WAIT
        for subtask_id in subtask_ids:
            if rest_client.task[subtask_id].read()['status'] in unfinished_statuses:
                # One unfinished subtask is enough to go waiting for all of them.
                final_statuses = tuple(status_groups.FINISH) + tuple(status_groups.BREAK)
                self.wait_tasks(subtask_ids, final_statuses, wait_all=True)
                break

        check_if_tasks_are_ok(subtask_ids)
        self._aggregate_results(subtask_ids, backup_date)

    def _start(self, date):
        """Publish restore descriptions and spawn one subtask per table chunk.

        Stores the ids of the links resource and (when sources are non-empty)
        the sources resource in the context.

        :param date: backup date string.
        :return: list of ids of the created subtasks.
        """
        descrs_by_backup, table_links, table_sources = self._get_restore_descrs(date)

        self.ctx[_LINKS_RES_KEY] = self._create_rd(table_links, 'table_links', date)
        if table_sources:
            self.ctx[_SOURCES_RES_KEY] = self._create_rd(table_sources, 'table_sources', date)

        if Registry().client.sandbox_user:
            space_per_subtask = 200 * 1024
        else:
            space_per_subtask = 1024

        spawned_ids = []
        for backup, tables in descrs_by_backup.iteritems():
            # At most 40 tables are handed to a single subtask.
            for part_index, part in enumerate(_chunkify_dict(tables, 40)):
                part_res_id = self._create_rd(part, '{}_{}'.format(backup, part_index), date)
                subtask_params = {
                    YabsServerB2BPrepareSQL.restore_descr_key: part_res_id,
                    BinDbList.name: '',
                    FilterTable.name: '.*',
                    BackupDate.name: date,
                    'options_list': 'packages_setup archive_tables',
                    'kill_timeout': 18000,  # default 3 hours is not enough sometimes!
                }
                child = self.create_subtask(
                    task_type=YabsServerB2BPrepareSQL.type,
                    description='Tables from {}, {}, part {}'.format(backup, date, part_index + 1),
                    input_parameters=subtask_params,
                    execution_space=space_per_subtask,
                )
                spawned_ids.append(child.id)

        return spawned_ids

    def _create_rd(self, data, what, date, mark_ready=True):
        """Dump ``data`` as JSON and register it as a restore-description resource.

        :param data: JSON-serializable payload.
        :param what: label used in the folder name and resource description.
        :param date: backup date; stored as the ``date`` resource attribute.
        :param mark_ready: mark the resource ready right after creation.
        :return: id of the created resource.
        """
        folder = sdk_paths.make_folder('restore_descr_{}_{}'.format(what, date))
        rd_path = os.path.join(folder, 'rd.json')
        with open(rd_path, 'w') as rd_file:
            json.dump(data, rd_file, indent=4)

        resource = self.create_resource(
            description="Restore description for {}, {}".format(date, what),
            resource_path=rd_path,
            resource_type=YABS_MYSQL_RESTORE_DESCRIPTION,
            attributes=dict(date=date),
            arch='any'
        )
        if not mark_ready:
            return resource.id
        self.mark_resource_ready(resource.id)
        return resource.id

    def _get_backup_date(self):
        """Return the backup date as a ``%Y%m%d`` string after a sanity check.

        The date is read from the context entry ``BackupDate.name``; when it
        is empty, today's date is used. The resulting value is written back
        to the context.

        :raises TaskFailure: if the value does not parse as ``%Y%m%d`` or the
            year is outside the 2016..2100 range.
        """
        date = self.ctx.get(BackupDate.name)
        if not date:
            date = datetime.datetime.today().strftime("%Y%m%d")
        try:
            year = datetime.datetime.strptime(date, "%Y%m%d").year
        except (ValueError, TypeError):
            # ValueError: the string does not match the format;
            # TypeError: the value is not a string at all.
            # Anything else is an unrelated error and must not be reported
            # as a bad date.
            raise TaskFailure("Bad backup date: {}, required format is %Y%m%d".format(date))
        if not 2016 <= year <= 2100:
            raise TaskFailure("Bad backup year ({}), I hope this is a typo.".format(date))
        self.ctx[BackupDate.name] = date
        return date

    def _get_restore_descrs(self, date):
        """Work out which backup every missing table has to be restored from.

        Tables already present in the previous archive contents (when given)
        are skipped. For a table that has no source of its own and does not
        belong to the ``yabs`` instance, the same ``db.table`` of the ``yabs``
        instance is tried and, if found, recorded as a link.

        :param date: backup date string.
        :return: tuple ``(restore_descrs, links, new_sources)`` where
            ``restore_descrs`` maps backup -> {table key: backup},
            ``links`` maps table key -> ``'=' + yabs table key`` and
            ``new_sources`` is the freshly fetched table sources (empty dict
            when a previous archive was used).
        :raises TaskError: if tables are missing and waiting is disabled or
            the waiting timeout has expired.
        """
        new_sources = {}

        missing_tables = self._get_needed_tables()

        prev_archive_res_id = get_or_default(self.ctx, PrevArchiveContents)
        if prev_archive_res_id:
            contents = dbutils.GluedMySQLArchiveContents(self.sync_resource, prev_archive_res_id)
            missing_tables -= contents.tables.viewkeys()
            sources = contents.sources
        else:
            rsync_results_dir = str(self.log_path('rsync'))
            # This method runs again on every retry after wait_time(), so the
            # directory may already be there if the log directory is kept
            # between executions; only a genuine creation failure is an error.
            try:
                os.makedirs(rsync_results_dir)
            except OSError:
                if not os.path.isdir(rsync_results_dir):
                    raise
            new_sources = get_all_table_sources(date, tmp_dir=rsync_results_dir)
            sources = new_sources
            # FIXME Store this!

        restore_descrs = dict()
        links = dict()

        not_found_tables = set()

        for key in missing_tables:
            inst, db, table = key.split('.')
            try:
                backup = sources[key]
            except KeyError:
                if inst == 'yabs':
                    not_found_tables.add(key)
                    continue

                # Try replica
                bsdb_key = '.'.join(('yabs', db, table))
                try:
                    backup = sources[bsdb_key]
                except KeyError:
                    not_found_tables.add(bsdb_key)
                    continue
                links[key] = '=' + bsdb_key
                restore_descrs.setdefault(backup, dict())[bsdb_key] = backup
            else:
                restore_descrs.setdefault(backup, dict())[key] = backup

        if not not_found_tables:
            return restore_descrs, links, new_sources

        logging.info('Not found tables for date %s: %s', date, not_found_tables)
        # Sorted so that the report is stable between attempts.
        missing_report = 'Not found tables for {}:\n{}'.format(date, '\n'.join(sorted(not_found_tables)))
        failure_text = "{} tables are missing from backups for {}".format(len(not_found_tables), date)

        if not get_or_default(self.ctx, WaitForBackups):
            self.set_info(missing_report)
            raise TaskError(failure_text)

        first_try_attempt_timestamp = self.ctx.get('first_try_attempt_timestamp')
        now = time.time()
        if first_try_attempt_timestamp is None:
            # First failed attempt: report the tables once and start the clock.
            self.set_info(missing_report)
            first_try_attempt_timestamp = now
            self.ctx['first_try_attempt_timestamp'] = first_try_attempt_timestamp
        if first_try_attempt_timestamp + get_or_default(self.ctx, WaitForBackupsTimeout) < now:
            raise TaskError(failure_text)
        raise self.wait_time(get_or_default(self.ctx, WaitForBackupsPeriod))

    def _get_needed_tables(self):
        """Return the set of ``inst.db.table`` keys the archive must contain.

        The set is built from the tables reported for the configured db tags,
        extended with the keys of the additional restore description resource
        when one is given.
        """
        db_tags = get_or_default(self.ctx, BinDbList).split()
        wanted = set()
        for inst, db, table in self._iter_cs_tables(db_tags):
            wanted.add('.'.join((inst, db, table)))

        extra_rd_res = get_or_default(self.ctx, AdditionalRestoreDescription)
        if not extra_rd_res:
            return wanted

        extra_rd_path = self.sync_resource(extra_rd_res)
        with open(extra_rd_path) as extra_rd_file:
            extra_rd = json.load(extra_rd_file)
        wanted.update(extra_rd.viewkeys())
        return wanted

    def _iter_cs_tables(self, db_tags):
        """Return the tables reported by ``yt_bases.iter_db_tables`` for ``db_tags``.

        The yabscs path is obtained through ``dbutils.get_yabscs`` from the
        resource given by the ``BsReleaseYT`` parameter.
        """
        release_res_id = get_or_default(self.ctx, BsReleaseYT)
        return yt_bases.iter_db_tables(
            dbutils.get_yabscs(self, release_res_id),
            db_tags,
            cs_import=True,
        )

    def _get_lm_dumps_resources(self, lm_shards=18):
        """Find the newest READY LM_DUMPS_LIST resource for meta and every shard.

        Shard number 0 is looked up with location ``meta`` (without a
        ``shard`` attribute); shards 1..lm_shards use location ``stat``.

        :param lm_shards: total number of stat shards.
        :return: dict mapping ``'lm_dumps_<shard>'`` to a resource id.
        :raises TaskError: if no resource is found for some shard.
        """
        rest_client = Client()
        release_status = 'stable' if lm_shards == 18 else 'testing'

        found = {}
        for shard in xrange(lm_shards + 1):
            query_attrs = {
                'released': release_status,
                'total_shards': str(lm_shards),
            }
            if shard == 0:
                location = 'meta'
            else:
                location = 'stat'
                query_attrs['shard'] = '%s' % shard
            query_attrs['location'] = location

            response = rest_client.resource.read(
                type=str(LM_DUMPS_LIST),
                attrs=query_attrs,
                state='READY',
                limit=25,
                order='-id',
            )
            try:
                # Results are ordered by descending id, so the first one is the newest.
                newest_id = response['items'][0]['id']
            except IndexError:
                raise TaskError(
                    "Failed to find LM_DUMPS_LIST resource with location '{}' and shard '{}'".format(location, shard)
                )
            found['lm_dumps_%s' % shard] = newest_id

        return found

    def _aggregate_results(self, child_tasks_ids, date):
        """Merge archive-contents fragments of the subtasks into one resource.

        :param child_tasks_ids: ids of the finished subtasks.
        :param date: backup date string.
        :raises TaskError: on duplicate keys between fragments or when a link
            key is already present among the tables.
        """
        out_res_dir = sdk_paths.make_folder('archive_contents')

        merged = {'tables': {}, 'lm_dumps': {}, 'sizes': {}}
        for part_res_id in self._get_archive_content_parts(child_tasks_ids):
            part = self._get_json_res(part_res_id)
            for section_name in ('tables', 'sizes'):
                accumulated = merged[section_name]
                incoming = part.get(section_name, {})
                clashing = accumulated.viewkeys() & incoming.viewkeys()
                if clashing:
                    raise TaskError("Duplicate keys in section {}: {}".format(section_name, ', '.join(clashing)))
                accumulated.update(incoming)

        tables = merged['tables']
        link_map = self._get_json_res(self.ctx[_LINKS_RES_KEY])
        for alias, target_ref in link_map.iteritems():
            if alias in tables:
                raise TaskError("Link key already exists: {}".format(alias))
            # The first character of the stored link value is dropped to get
            # the key of the table the alias points to.
            tables[alias] = tables[target_ref[1:]]

        if not self.has_option(Option.NO_LM_DUMPS):
            merged['lm_dumps_for_shard_count'] = {
                "18": self._get_lm_dumps_resources(18),
                "12": self._get_lm_dumps_resources(12),
            }

        if _SOURCES_RES_KEY in self.ctx:
            merged['sources'] = self._get_json_res(self.ctx[_SOURCES_RES_KEY])

        self._create_output_resource(merged, date, out_res_dir)

    def _get_json_res(self, res_id):
        """Sync resource ``res_id`` and return its content parsed as JSON."""
        local_path = self.sync_resource(res_id)
        with open(local_path) as source:
            parsed = json.load(source)
        return parsed

    def _get_archive_content_parts(self, child_tasks_ids):
        """Collect ids of READY archive-contents resources made by the subtasks.

        Subtasks whose context has no restore description are skipped.

        :param child_tasks_ids: ids of the subtasks to inspect.
        :return: list of resource ids, one per inspected subtask.
        :raises TaskError: if a subtask has no suitable resource.
        """
        rest_client = Client()
        part_ids = []
        for child_id in child_tasks_ids:
            child_ctx = rest_client.task[child_id].context.read()
            if not child_ctx.get(YabsServerB2BPrepareSQL.restore_descr_key):
                # We didn't give table list to this child (LMs, for example)
                continue

            resources = rest_client.task[child_id].resources.read()
            # Take the first matching resource in the order returned.
            part_id = next(
                (
                    item['id']
                    for item in resources['items']
                    if item['state'] == 'READY' and item['type'] == str(YABS_MYSQL_ARCHIVE_CONTENTS)
                ),
                None,
            )
            if not part_id:
                raise TaskError("Task {} has no resources of type {} in state READY".format(child_id, YABS_MYSQL_ARCHIVE_CONTENTS))
            part_ids.append(part_id)
        return part_ids

    def _create_output_resource(self, out_data, date, out_res_dir):
        """Write ``out_data`` to disk and register it as the output resource.

        The resource id is stored in the context under ``OUT_RES_ID_KEY``.
        When a previous archive-contents resource is configured, the new
        resource is marked ready and appended to it via ``dbutils``.

        :param out_data: JSON-serializable archive contents.
        :param date: backup date; stored as the ``date`` resource attribute.
        :param out_res_dir: directory to write ``archive_contents.json`` into.
        :return: the created resource object.
        """
        contents_path = os.path.join(out_res_dir, 'archive_contents.json')
        with open(contents_path, 'w') as contents_file:
            json.dump(out_data, contents_file, indent=4)

        for_testenv = self.has_option(Option.SWITCH_TESTENV)
        resource_attrs = {'date': date}
        if for_testenv:
            # An empty trigger value falls back to the backup date.
            resource_attrs['testenv_switch_trigger'] = get_or_default(self.ctx, TestenvSwitchTriggerValue) or date

        if for_testenv:
            purpose = ' for Testenv bases'
        else:
            purpose = ''
        resource = self.create_resource(
            description="MySQL archive contents{}, {} ".format(purpose, date),
            resource_path=contents_path,
            resource_type=YABS_MYSQL_ARCHIVE_CONTENTS,
            attributes=resource_attrs,
            arch='any'
        )
        self.ctx[OUT_RES_ID_KEY] = resource.id

        previous_res_id = get_or_default(self.ctx, PrevArchiveContents)
        if previous_res_id:
            self.mark_resource_ready(resource.id)
            dbutils.append_mysql_archive_contents(self, previous_res_id, resource.id)

        return resource


def _chunkify_dict(data, max_chunk_length):
    for items_chunk in itertools.izip_longest(*([data.iteritems()] * max_chunk_length)):
        yield dict(itertools.ifilter(None, items_chunk))


# Module-level alias for the task class; this file imports another task module's
# class through the same ``__Task__`` name (see the YabsServerB2BPrepareSQL import).
__Task__ = YabsServerGetSQLArchive
