from __future__ import print_function
import logging
import time
import os
import re

from collections import defaultdict
from datetime import datetime, timedelta
from multiprocessing.pool import ThreadPool

from sandbox import sdk2
from sandbox.sdk2 import parameters

from sandbox.projects.yabs.base_bin_task import BaseBinTask

# Number of seconds in one day; multiplied by the `period` parameter (days)
# and subtracted from time.time() to get the merge cut-off timestamp.
DAY = 24 * 60 * 60
# Codecs applied to merged output tables at creation time and passed to the
# YQL "Published*Codec" pragmas when old tables are compressed.
COMPRESSION_CODEC = 'brotli_8'
ERASURE_CODEC = 'lrc_12_2_2'


class GenerationMapper:
    '''Map callable that stamps every row with the generation of its source table.

    `generations` is indexed by the row's '@table_index' control field, so it
    must be ordered exactly like the list of input tables of the map operation.
    '''

    def __init__(self, generations):
        self.generations = generations

    def __call__(self, row):
        # A missing/zero table index (falsy value) is treated as the first input table.
        table_index = row['@table_index']
        if not table_index:
            table_index = 0
        row['_Generation'] = self.generations[table_index]
        yield row


def merge_yt_schemas(first, second):
    '''Return the columns that two YT schemas have in common.

    Both arguments are iterables of column descriptions (dicts with at least
    'name' and 'type' keys).  The result contains one {'name', 'type'} dict
    per column present in BOTH schemas, in the order the columns appear in
    `first`; any other column attributes (e.g. 'sort_order') are dropped.
    If either schema is empty the result is empty.

    Raises AssertionError when a shared column has different types.
    '''
    first = list(first)
    first_types = {col['name']: col['type'] for col in first}
    second_types = {col['name']: col['type'] for col in second}

    result = []
    seen = set()
    # Walk the original column list (not a set) so the output order is
    # deterministic, and avoid dict.iterkeys() so this runs on Python 2 and 3.
    for col in first:
        name = col['name']
        if name in seen or name not in second_types:
            continue
        seen.add(name)
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if first_types[name] != second_types[name]:
            raise AssertionError('Incompatible schemas, column {}'.format(name))
        result.append({'name': name, 'type': first_types[name]})

    return result


class YabsCpmMultiplierMerge(BaseBinTask):
    '''Merges CPMMultiplier intermediate tables
    '''

    # Resources requested for the task.
    # NOTE(review): units of `ram` / `disk_space` are defined by the sandbox
    # SDK (presumably MiB) -- confirm against the SDK docs.
    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 4096
        disk_space = 4096

        # Empty Caches subclass: no caches are declared for this task.
        class Caches(sdk2.Requirements.Caches):
            pass

    # User-visible task parameters, grouped into YT connection settings and
    # execution settings.
    class Parameters(BaseBinTask.Parameters):
        description = 'Merge CPMMultiplier intermediate tables'

        # Selects the binary task resource by its attributes.
        with BaseBinTask.Parameters.version_and_task_resource() as version_and_task_resource:
            resource_attrs = parameters.Dict('Filter resource by', default={'name': 'YabsCpmMultiplier'})

        with parameters.Group('YT parameters') as yt_params:
            # Every listed cluster is processed in turn by on_execute.
            clusters = parameters.List(
                'YT clusters',
                required=True,
                default=['hahn', 'arnold'],
            )

            # The YT_TOKEN key of this secret is used for both the YT and YQL clients.
            yt_token_secret_id = parameters.YavSecret(
                label="YT token secret id",
                required=True,
                description='secret should contain keys: YT_TOKEN',
                default="sec-01d6dyn0qa3xds1mp820ssgbez",
            )

        with parameters.Group('Execution parameters') as exec_params:
            # Directory whose all-digit children are treated as generation directories.
            results_prefix = parameters.String(
                'Intermediate results prefix',
                required=True,
                default='//home/yabs/stat/cpm_multiplier',
            )
            # Directory receiving per-day (YYYY-MM-DD) sub-directories with merged tables.
            out_prefix = parameters.String(
                'Merged output prefix',
                required=True,
                default='//home/yabs/stat/cpm_multiplier/merged',
            )
            # Generations younger than this many days are left untouched.
            period = parameters.Integer(
                'Period to store intermediate tables before merging (days)',
                required=True,
                default=10,
            )
            # Size of the thread pool running per-day merges.
            threads = parameters.Integer(
                'Number of concurrent merges',
                required=True,
                default=5,
            )
            # Day directories older than this many days are passed to compress_tables.
            compression_delay = parameters.Integer(
                'Period to compress merged tables (days)',
                required=True,
                default=60,
            )

    def get_tables_to_merge(self, ytc, limit):
        '''Collect intermediate tables that are old enough to be merged.

        Lists `results_prefix`; every child whose name consists only of digits
        and whose integer value is below `limit` is treated as a generation
        directory.  Returns a defaultdict mapping table name -> list of
        generations (ints) in which a node with that name exists.
        '''
        prefix = self.Parameters.results_prefix
        generations_by_table = defaultdict(list)

        for node_name in ytc.list(prefix):
            # Non-numeric children are not generation directories.
            if not node_name.isdigit():
                continue
            generation = int(node_name)
            if generation >= limit:
                continue
            for table_name in ytc.list(os.path.join(prefix, node_name)):
                generations_by_table[table_name].append(generation)

        return generations_by_table

    def get_paths_and_schema(self, ytc, generations, table, target_dir):
        '''Group the input tables of `table` by calendar day and build schemas.

        For each generation the table at <results_prefix>/<generation>/<table>
        is inspected.  Tables with an empty schema are removed and skipped.
        For the rest, 'sort_order' is stripped from the columns and a leading
        ascending int64 `_Generation` key column is added to form the schema
        of the mapped (intermediate) table <target_dir>/<day>/<generation>.

        Returns a tuple (paths, result_schema):
          * paths -- defaultdict: 'YYYY-MM-DD' -> {'input': [...], 'mapped': [...]},
            both lists in the same order; the day is derived from the
            generation via datetime.fromtimestamp (local time zone);
          * result_schema -- non-strict YsonList schema for the merged table.

        NOTE(review): the accumulated schema starts as an empty list and
        merge_yt_schemas keeps only columns present in both arguments, so it
        looks like result_schema always ends up with just `_Generation` --
        confirm this is intended.
        '''
        from yt.wrapper import YPath
        from yt.yson import YsonList

        def generation_column():
            # Fresh dict per call so the schemas never share a column object.
            return {'sort_order': 'ascending', 'type': 'int64', 'name': '_Generation'}

        paths = defaultdict(lambda: {'input': [], 'mapped': []})
        common_schema = []

        for gen in generations:
            gen_name = str(gen)
            source_table = os.path.join(self.Parameters.results_prefix, gen_name, table)

            schema = ytc.get_attribute(source_table, 'schema')
            if not len(schema):
                # Nothing usable in a schemaless table: drop it right away.
                ytc.remove(source_table)
                continue

            for column in schema:
                column.pop('sort_order', None)

            common_schema = merge_yt_schemas(common_schema, schema)
            schema.insert(0, generation_column())
            schema.attributes['strict'] = False

            day = datetime.fromtimestamp(gen).strftime('%Y-%m-%d')
            day_paths = paths[day]
            day_paths['input'].append(source_table)
            mapped_table = YPath(
                '{}/{}/{}'.format(target_dir, day, gen_name),
                attributes={'schema': schema},
            )
            day_paths['mapped'].append(mapped_table)

        common_schema.insert(0, generation_column())
        result_schema = YsonList(common_schema)
        result_schema.attributes['strict'] = False

        return paths, result_schema

    def merge_single(self, ytc, generations, input_paths, mapped_paths, out_path, schema):
        '''Tag input tables with their generation and merge them into `out_path`.

        Steps:
          1. run a map over `input_paths` into `mapped_paths` with
             GenerationMapper(generations), which adds `_Generation` to each
             row based on its '@table_index';
          2. if `out_path` already exists it is added to the merge inputs,
             otherwise it is created with `schema` and the module codecs;
          3. run a sorted merge by `_Generation` into `out_path`;
          4. remove every table in `input_paths`.

        `generations` must be ordered like `input_paths`.  The caller's
        `mapped_paths` list is not modified.
        '''
        from yt.wrapper import YsonFormat

        ytc.run_map(GenerationMapper(generations), input_paths, mapped_paths, format=YsonFormat(control_attributes_mode='row_fields'))

        # Work on a copy: appending out_path to the argument itself would leak
        # into the caller's list.
        merge_sources = list(mapped_paths)
        if ytc.exists(out_path):
            # Keep previously merged rows by feeding the output back in.
            merge_sources.append(out_path)
        else:
            ytc.create('table', out_path, attributes={
                'schema': schema,
                'optimize_for': 'scan',
                'compression_codec': COMPRESSION_CODEC,
                'erasure_codec': ERASURE_CODEC
            })

        ytc.run_merge(merge_sources, out_path, mode='sorted', spec={'merge_by': ['_Generation'], 'combine_chunks': True})

        for table in input_paths:
            ytc.remove(table)

    def merge_wrapper(self, cluster, table, day, paths, mapped_dir, generations, schema):
        '''Merge one day's intermediate tables of `table` on `cluster`.

        Creates its own YtClient (the method is run from pool threads),
        ensures <out_prefix>/<day> exists, then inside a single transaction
        creates the temporary <mapped_dir>/<day> directory, runs merge_single
        into <out_prefix>/<day>/<table> and removes the temporary directory.

        paths -- {'input': [...], 'mapped': [...]} for this day only.
        generations -- generations of the table across ALL days; used only as
            a fallback, see the comment below.
        '''
        from yt.wrapper import YtClient

        logging.info('Processing day %s with input: %s', day, paths['input'])

        # GenerationMapper picks a generation by '@table_index', i.e. by the
        # position of a table in paths['input'].  `generations` lists every
        # generation of the table (all days, including ones skipped for an
        # empty schema), so its positions do not line up with this day's
        # inputs.  Input paths are built as <results_prefix>/<generation>/<table>,
        # so the aligned list is recovered from the paths themselves.
        try:
            day_generations = [
                int(os.path.basename(os.path.dirname(str(path))))
                for path in paths['input']
            ]
        except ValueError:
            # Unexpected path layout: fall back to what the caller supplied.
            logging.warning('Cannot derive generations from input paths %s, using %s', paths['input'], generations)
            day_generations = generations

        ytc = YtClient(proxy=cluster, token=self.yt_token)
        out_dir = '{}/{}'.format(self.Parameters.out_prefix, day)
        ytc.create('map_node', out_dir, ignore_existing=True)

        with ytc.Transaction():
            mapped_dir = '{}/{}'.format(mapped_dir, day)
            ytc.create('map_node', mapped_dir, ignore_existing=True)

            self.merge_single(ytc, day_generations, paths['input'], paths['mapped'], '{}/{}'.format(out_dir, table), schema)

            ytc.remove(mapped_dir, recursive=True)

    def compress_tables(self, ytc, yql, directory_path):
        '''Rewrite every table of a day directory through YQL, once per directory.

        Does nothing when the directory already has the `@compressed`
        attribute.  For each table the first non-zero `_Generation` value met
        while reading is taken; tables with no such row are skipped.  The rows
        with that generation are written by a YQL query (with the module's
        compression/erasure codecs set via pragmas) into a temporary table,
        which then replaces the original -- all within one YT transaction.
        Finally `@compressed` is set on the directory.
        '''
        compressed_attr = directory_path + '/@compressed'
        if ytc.exists(compressed_attr):
            return

        for table_path in ytc.list(directory_path, absolute=True):
            for row in ytc.read_table(table_path):
                kept_generation = row["_Generation"]
                if kept_generation != 0:
                    break
            else:
                # Empty table, or every row has generation 0: leave it alone.
                continue

            with ytc.Transaction() as tx:
                with ytc.TempTable(directory_path) as tmp_table:
                    query_text = '''
                        PRAGMA yt.ForceInferSchema;
                        PRAGMA yt.ExternalTx = "{tx_id}";
                        PRAGMA yt.PublishedCompressionCodec = "{compression}";
                        PRAGMA yt.PublishedErasureCodec = "{erasure}";
                        INSERT INTO `{output}` WITH TRUNCATE
                        SELECT * FROM `{input}`
                        WHERE _Generation = {generation};
                    '''.format(
                        tx_id=tx.transaction_id,
                        compression=COMPRESSION_CODEC,
                        erasure=ERASURE_CODEC,
                        output=tmp_table,
                        input=table_path,
                        generation=kept_generation,
                    )
                    request = yql.query(query_text, syntax_version=1)
                    request.run()

                    table_name = os.path.basename(table_path)
                    logging.info('Compress {table_name} table: {yql_url}'.format(
                        table_name=table_name,
                        yql_url=request.share_url))

                    request.wait_progress()
                    # Swap the rewritten table into the original's place.
                    ytc.remove(table_path)
                    ytc.move(tmp_table, table_path)

        ytc.set(compressed_attr, True)

    def on_execute(self):
        '''Task entry point.

        For every configured cluster:
          1. find intermediate tables in generations older than `period` days;
          2. merge them per table and per day in a thread pool (merge_wrapper);
          3. remove the processed generation directories (non-recursive
             remove, so a directory that still has children makes it fail);
          4. run compress_tables on day directories under `out_prefix` that
             are older than `compression_delay` days.
        '''
        from yt.wrapper import YtClient
        from yql.api.v1.client import YqlClient

        yt_logger = logging.getLogger('Yt')
        yt_logger.setLevel(logging.INFO)
        self.yt_token = self.Parameters.yt_token_secret_id.data()["YT_TOKEN"]
        limit = int(time.time()) - self.Parameters.period * DAY
        logging.info('Merging tables created before %s', limit)

        # Anchored at the end: a name that merely starts with a date
        # (e.g. '2020-01-01_x') must be skipped, not fed to strptime where it
        # would raise ValueError.
        day_dir_re = re.compile(r'[0-9]{4}-[0-9]{2}-[0-9]{2}$')
        compression_delay = timedelta(days=self.Parameters.compression_delay)

        pool = ThreadPool(self.Parameters.threads)
        try:
            for cluster in self.Parameters.clusters:
                ytc = YtClient(proxy=cluster, token=self.yt_token)
                yql = YqlClient(db=cluster, token=self.yt_token)

                tables_to_merge = self.get_tables_to_merge(ytc, limit)
                generations_to_delete = []
                mapped_dirs = []
                tasks = []

                # .items() instead of .iteritems(): works on Python 2 and 3.
                for table, generations in tables_to_merge.items():
                    generations_to_delete += generations
                    logging.info('Merging %s from generations %s', table, generations)

                    mapped_dir = os.path.join(self.Parameters.out_prefix, table + '_tmp')
                    mapped_dirs.append(mapped_dir)
                    paths_by_day, schema = self.get_paths_and_schema(ytc, generations, table, mapped_dir)

                    for day, paths in paths_by_day.items():
                        tasks.append((table, day, paths, mapped_dir, generations, schema))

                # Temporary parent directories are created up front, outside
                # of the per-task transactions.
                for mapped_dir in mapped_dirs:
                    ytc.create('map_node', mapped_dir, ignore_existing=True)

                def run_merge_task(task, cluster=cluster):
                    return self.merge_wrapper(cluster, *task)

                # Blocks until every task finished; re-raises a task's exception.
                pool.map(run_merge_task, tasks)

                logging.info('Deleting empty dirs %s', generations_to_delete)
                for gen in set(generations_to_delete):
                    ytc.remove(os.path.join(self.Parameters.results_prefix, str(gen)))

                for dir_name in ytc.list(self.Parameters.out_prefix):
                    if not day_dir_re.match(dir_name):
                        continue
                    if datetime.strptime(dir_name, '%Y-%m-%d') + compression_delay < datetime.now():
                        self.compress_tables(ytc, yql, os.path.join(self.Parameters.out_prefix, dir_name))
        finally:
            # Release the worker threads whether or not the merge succeeded.
            pool.close()
            pool.join()
