# -*- encoding: utf-8 -*-

import travel.avia.admin.init_project  # noqa

import logging
import os
import sys
from datetime import datetime, timedelta
from optparse import OptionParser
from threading import Thread
from Queue import Queue

import yt.wrapper as yt
import yt.logger_config as yt_logger_config
import yt.logger as yt_logger

from django.conf import settings

from travel.avia.admin.lib import yt_helpers as yth
from travel.avia.admin.lib.logs import add_stdout_handler, create_current_file_run_log


# Module-level logger used by every function in this script.
log = logging.getLogger(__name__)


# Cypress directory that is searched for tables.
ROOT_PATH = '//home/rasp'
# Values of settings.ENVIRONMENT in which main() proceeds past its guard.
ALLOWED_ENVS = ['production', 'dev']
# Absolute path prefixes excluded from remerging; get_tables() matches them
# with str.startswith, so this has to be a tuple.
SKIP_DIRS = tuple(
    os.path.join(ROOT_PATH, relative_path)
    for relative_path in (
        'import',
        'staging',
        'price_prediction',
        'dapavlov',
        'logs/avia-redir-balance-by-day-log',
        'logs/avia-redir-balance-log',
        'logs/avia-redir-enriched-balance-by-day-log',
        'logs/rasp-ufs-response-log',
    )
)

MEGABYTE = 1024 ** 2
# Column names that create_schema() appends (typed 'any') to a schema taken
# from `_read_schema` when they are not already listed there.
TECHICAL_COLUMNS = {
    '_logfeller_index_bucket',
    '_logfeller_timestamp',
    '_rest',
    '_stbx',
    'source_uri',
    'timestamp',
    'iso_eventtime',
}


def get_user_attributes(table, yt):
    """Collect all user-defined attributes of a node.

    :param table: path of the node whose attributes are read.
    :param yt: client object providing
        ``get_attribute(path, attribute, default)``.
    :return: dict mapping every user attribute name to its value; empty when
        the node reports no user attribute keys.
    """
    # The list of user attribute names is stored in `user_attribute_keys`
    # (singular "attribute"), the same attribute copy_user_attributes reads.
    # The former spelling `user_attributes_keys` always fell back to the
    # default [], so nothing was ever collected.
    attribute_names = yt.get_attribute(table, 'user_attribute_keys', [])
    return {
        name: yt.get_attribute(table, name)
        for name in attribute_names
    }


def set_attributes(table, attributes, yt):
    """Write every attribute from a mapping onto a node.

    :param table: path of the node that receives the attributes.
    :param attributes: dict of attribute name -> value.
    :param yt: client object providing ``set(path, value)``.
    """
    # items() is available on both Python 2 and Python 3 dicts, whereas
    # iteritems() exists only on Python 2.
    for attribute, value in attributes.items():
        yt.set(table + '/@{}'.format(attribute), value)


def copy_user_attributes(src_table, dst_table, yt):
    """Copy each user attribute of ``src_table`` onto ``dst_table``.

    Does nothing when ``src_table`` has no ``user_attribute_keys`` attribute.

    :param src_table: path of the node the attributes are read from.
    :param dst_table: path of the node the attributes are written to.
    :param yt: client object providing ``exists``, ``get`` and ``set``.
    """
    keys_path = src_table + '/@user_attribute_keys'
    if not yt.exists(keys_path):
        return

    for name in yt.get(keys_path):
        attribute_suffix = '/@{}'.format(name)
        # The value is read from the source first, then written to the target.
        yt.set(dst_table + attribute_suffix, yt.get(src_table + attribute_suffix))


def create_schema(table, yt):
    """Pick the schema to use for the remerged copy of ``table``.

    :param table: path of the table.
    :param yt: client object providing ``get`` and ``exists``.
    :return: the table's own ``schema`` attribute when it is non-empty;
        otherwise its ``_read_schema`` attribute extended with every name
        from TECHICAL_COLUMNS it does not list (each typed ``any``);
        ``None`` when the table has neither.
    """
    own_schema = yt.get(table + '/@schema')
    if len(own_schema) > 0:
        return own_schema

    read_schema_path = table + '/@_read_schema'
    if not yt.exists(read_schema_path):
        return None

    read_schema = yt.get(read_schema_path)
    listed_names = set(column['name'] for column in read_schema)
    # Append the technical columns that the read schema does not mention.
    for missing_name in TECHICAL_COLUMNS - listed_names:
        read_schema.append({
            'name': missing_name,
            'type': 'any',
        })

    return read_schema


def get_tables():
    """Find tables under ROOT_PATH that should be remerged.

    A table qualifies when its erasure codec is not ``lrc_12_2_2``, or when it
    has more than one chunk and the average compressed chunk size is below
    256 MB. Tables under ``//home/rasp/logs/`` whose last path component is
    not a ``YYYY-MM-DD`` date, or is a date within the last three days
    (today included), are skipped, as is anything whose path starts with one
    of SKIP_DIRS.

    :return: list of ``[table, chunk_count]`` pairs, largest chunk count first.
    """
    requested_attributes = ["resource_usage", "row_count", "erasure_codec", "compressed_data_size"]
    candidates = []

    for table in yt.search(ROOT_PATH, node_type=['table'], attributes=requested_attributes):
        node_attributes = table.attributes
        resource_usage = node_attributes['resource_usage']
        erasure_codec = node_attributes['erasure_codec']
        compressed_data_size = node_attributes['compressed_data_size']
        chunk_count = resource_usage['chunk_count']

        if table.startswith('//home/rasp/logs/'):
            last_component = table.split('/')[-1]
            try:
                table_date = datetime.strptime(last_component, '%Y-%m-%d').date()
            except ValueError:
                # Log tables that are not named after a date are left alone.
                continue

            # Skip the three most recent days of logs.
            oldest_skipped_date = datetime.now().date() - timedelta(days=2)
            if table_date >= oldest_skipped_date:
                continue

        # Skip explicitly excluded directories.
        if table.startswith(SKIP_DIRS):
            continue

        wrong_codec = erasure_codec != 'lrc_12_2_2'
        # The `chunk_count > 1` test also guards the division against zero.
        small_chunks = chunk_count > 1 and (compressed_data_size / chunk_count) < 256 * MEGABYTE

        if wrong_codec or small_chunks:
            candidates.append([table, chunk_count])

    candidates.sort(key=lambda entry: entry[1], reverse=True)
    return candidates


def remerge_table(yt, table, chunk_count):
    """Rewrite ``table`` through a temporary table and move the result back.

    The temporary table is created with ``zlib_9`` compression and the
    ``lrc_12_2_2`` erasure codec (plus ``optimize_for=scan`` and a schema when
    create_schema() finds one). The source is merged into it, the temporary
    table is then moved over the original path, and the user attributes read
    from the original are set on the result.

    :param yt: client object used for all YT calls.
    :param table: path of the table to remerge.
    :param chunk_count: not used by this function.
    """
    desired_chunk_size = 2 * 1024 ** 3

    compression_ratio = yt.get(table + "/@compression_ratio")
    if compression_ratio > 0:
        scaled_size = int(desired_chunk_size / compression_ratio)
    else:
        scaled_size = 0
    # Never hand a job size below 1 to the merge operation.
    data_size_per_job = max(1, scaled_size)

    schema = create_schema(table, yt)
    tmp_attributes = {
        'compression_codec': 'zlib_9',
        'erasure_codec': 'lrc_12_2_2',
    }
    if schema is not None:
        tmp_attributes['optimize_for'] = 'scan'
        tmp_attributes['schema'] = schema

    tmp_table = yth.create_safe_temp_table(
        __file__, attributes=tmp_attributes, ytc=yt,
    )

    # Read the attributes before the move below replaces the original node.
    user_attributes = get_user_attributes(table, yt)

    table_writer = {
        'max_row_weight': 128 * 1024 * 1024,
        'desired_chunk_size': desired_chunk_size
    }
    merge_spec = {
        'title': 'Convert to erasure',
        'combine_chunks': True,
        'force_transform': True,
        'data_size_per_job': data_size_per_job,
        'schema_inference_mode': 'from_output',
        'job_io': {'table_writer': table_writer},
    }

    log.info('Merge %s to %s', table, tmp_table)
    yt.run_merge(
        source_table=table,
        destination_table=tmp_table,
        mode='unordered',
        spec=merge_spec,
    )

    log.info('Move %s to %s', tmp_table, table)
    yt.move(tmp_table, table, force=True)
    set_attributes(table, user_attributes, yt)


class Worker(Thread):
    """Thread that takes ``(table, chunk_count)`` items off a queue forever
    and remerges each table inside its own YT transaction.
    """

    def __init__(self, queue):
        """
        :param queue: queue of ``(table, chunk_count)`` items; ``task_done()``
            is called on it once per item taken.
        """
        Thread.__init__(self)
        self.queue = queue

    def _get_yt(self):
        """Build a YT client for the 'hahn' proxy using settings.YT_TOKEN."""
        return yt.YtClient(
            token=settings.YT_TOKEN,
            proxy='hahn',
        )

    def run(self):
        """Process queue items until the thread is torn down.

        A failure on one table is logged and does not stop the worker.
        """
        client = self._get_yt()
        while True:
            # Take the item BEFORE opening the transaction: an idle worker
            # must not hold an open transaction while it blocks on the queue,
            # and task_done() below must only ever follow a successful get().
            table, chunk_count = self.queue.get()
            try:
                with client.Transaction():
                    client.lock(table, mode="exclusive")
                    log.info('Remerge chunks of %s. Chunks: %s', table, chunk_count)
                    remerge_table(client, table, chunk_count)
            except Exception:
                # Top-level boundary of the worker: log with traceback and
                # carry on with the next table.
                log.exception('Error: failed to remerge %s', table)
            finally:
                # Exactly one task_done() per get(), so Queue.join() neither
                # returns early nor raises ValueError.
                self.queue.task_done()


def main():
    """Script entry point: find tables to remerge and process them in threads.

    Command-line options:
        -v / --verbose  attach a stdout handler to the module logger;
                        otherwise the YT logger level is set to WARNING.
        --workers N     number of worker threads (default 5).

    Exits via ``sys.exit()`` when settings.ENVIRONMENT is not in ALLOWED_ENVS.
    """
    create_current_file_run_log()

    parser = OptionParser()
    parser.add_option('-v', '--verbose', action='store_true')
    parser.add_option('--workers', dest='workers', default=5, type=int)
    options, _args = parser.parse_args()

    if options.verbose:
        add_stdout_handler(log)
    else:
        # Reload the YT logger module after changing the configured level.
        yt_logger_config.LOG_LEVEL = 'WARNING'
        reload(yt_logger)

    yth.configure_wrapper(yt)

    log.info('Start')

    environment = settings.ENVIRONMENT
    if environment not in ALLOWED_ENVS:
        log.info(
            'Current ENVIRONMENT %s. Run only %s allowed.',
            environment, ', '.join(ALLOWED_ENVS),
        )
        sys.exit()

    tables = get_tables()

    task_queue = Queue()
    for _ in range(options.workers):
        worker = Worker(task_queue)
        # Daemon threads: the workers loop forever and must not block exit.
        worker.daemon = True
        worker.start()

    log.info('Tables to remerge: %d',  len(tables))
    for table, chunk_count in tables:
        task_queue.put((table, chunk_count))

    # Wait until every queued table has been marked done by a worker.
    task_queue.join()
    log.info('Done')
