# -*- coding: utf-8 -*-

import os
import time
import random
import tempfile
import threading
import itertools
from datetime import datetime, timedelta

from sandbox.sandboxsdk import paths, process
from sandbox.sandboxsdk.errors import SandboxSubprocessTimeoutError

from sandbox.projects.resource_types import YABS_YTSTAT_LB_READER
from sandbox.projects.common.yabs.app_options import AppOptions, APP_OPTIONS_YT_PROXY
from sandbox.projects.common.yabs.ytstat import get_and_sync_autoreleased_resource, get_version_from_yt

import sandbox.projects.common.yabs.lbyt_reader.utils as lbyt_utils
from sandbox.projects.common.yabs.lbyt_reader.sandbox_task import YabsLbYtSandboxTask
from sandbox.projects.common.yabs.lbyt_reader.job_yt_client import (
    READER_STATE, JobYtClient, DataProxy,
    get_task_offsets, get_table_name, parse_table_name,
)
from sandbox.projects.common.yabs.lbyt_reader.config import (
    LOGTYPES, MANAGER_JOBS, READER_JOBS,
    DEFAULT_TASK_FREQUENCY,
    RECENT_ERRORS_INTERVAL, KEEP_INPUT_TABLES,
)


class YabsLbYtReader(YabsLbYtSandboxTask):
    """Sandbox task that copies LogBroker data into YT tables via ``lbyt_reader``.

    One iteration of :meth:`do_job_work` reads the task list from the YT
    cluster, commits LogBroker offsets of the latest finished task, starts
    ``lbyt_reader`` for ready tasks (output tables land in the cluster cache
    directory), marks tasks whose tables are all cached as done, moves the
    tables of done tasks to :attr:`destination` and deletes old task records.
    """
    type = 'YABS_LBYT_READER'
    _lbyt_jobs = READER_JOBS

    @property
    def destination(self):
        """YT directory that tables of done tasks are moved to from the cache."""
        return self.yt_cluster.path_prefix + MANAGER_JOBS[self.config['job']]['dest_path']

    def job_init_before_lock(self):
        """Prepare per-job state; return False if initialization failed.

        Creates and switches into a fresh working directory, downloads the
        reader executable, writes the YT token into a temporary file (passed
        to the reader via ``--token-file``) and configures the YT client used
        to fetch application options.
        """
        job_cwd = tempfile.mkdtemp(
            dir=os.getcwd(),
            prefix='{}_'.format(self.job),
        )
        os.chdir(job_cwd)

        self.reader_version = None
        self.reader_executable = None
        if not self.download_reader_resource():
            return False

        # Kept as an attribute on purpose: NamedTemporaryFile deletes the file
        # as soon as the object is closed or garbage collected.
        self.yt_token_file = tempfile.NamedTemporaryFile()
        self.yt_token_file.write(str(self.yt_cluster.token))
        self.yt_token_file.flush()

        self.app_options_file = AppOptions.RESULT_FILE
        self.app_options_yt_client = JobYtClient.configure_yt_client(
            APP_OPTIONS_YT_PROXY,
            self.yt_cluster.token,
        )

        return True

    def job_init_after_lock(self):
        """Sleep a random time of up to ``max_smoothing_sleep`` seconds; return True.

        Presumably this spreads the first iterations of concurrently started
        jobs over time.
        """
        smoothing_sleep = random.random() * self.config['max_smoothing_sleep']
        self.logger.info('Waiting %.1f seconds before first iteration', smoothing_sleep)
        time.sleep(smoothing_sleep)
        return True

    def need_work(self):
        """Re-read the cluster control state and return its READER_STATE value.

        The freshly read ``control_state`` is reused by :meth:`do_job_work`
        without reading it again.
        """
        self.yt_cluster.read_control_state()
        return self.yt_cluster.control_state.get(READER_STATE)

    def do_job_work(self):
        """Run one reader iteration on ``self.yt_cluster``.

        Returns a ``(success, value)`` tuple. ``value`` is 0, except in skip
        mode, where it is ``start_timestamp + NOTHING_TO_DO_SLEEP_TIME``
        (presumably the time of the next iteration -- confirm in the base
        class).
        """
        # NOTE(review): this rebinds ``base_metric`` to the result of calling
        # it, so a repeated call relies on that result being callable as well.
        self.base_metric = self.base_metric(hostname=self.yt_cluster.name)

        tasks_to_process = self.get_tasks_to_process()
        if tasks_to_process is None:
            return False, 0

        ready_tasks, old_done_task_nums, ldbr_task_info = tasks_to_process
        self.commit_offsets(ldbr_task_info)
        if ready_tasks:
            oldest_ready_time_mark = min(
                task_info.time_mark
                for task_info in ready_tasks.values()
            )
            # time_mark is in milliseconds, time.time() in seconds
            self.add_metric('lag', time.time() - oldest_ready_time_mark / 1000.0)
        else:
            self.add_metric('lag', 0)

        # no explicit read_control_state - use value cached in need_work
        skip_by_control = self.yt_cluster.control_state.get('skip_ready')
        if self.config['only_for_quorum'] or skip_by_control:
            if not skip_by_control:
                self.logger.warn('Skip mode is forced by only_for_quorum')

            self.logger.info('Marking %d tasks as skipped', len(ready_tasks))
            self.yt_cluster.mark_tasks(ready_tasks.keys(), 'skipped')
            self.yt_cluster.flush_tasks()

            self.logger.info('Deleting %d old tasks', len(old_done_task_nums))
            self.yt_cluster.delete_tasks(old_done_task_nums)
            return True, self.start_timestamp + self.NOTHING_TO_DO_SLEEP_TIME

        started_tasks, success = self.download_tasks(ready_tasks)
        self.add_metric('tasks.started', started_tasks)

        cached_tables = self.get_tables_info(self.yt_cluster.cache)
        if cached_tables is not None:
            cached_tables = lbyt_utils.forget_attributes(cached_tables)
            self.mark_done(ready_tasks, cached_tables)
            # Tasks that may need a move: the current ready tasks plus the
            # tasks that already have tables in the cache (done, but a
            # previous move failed). Check them all.
            probably_need_move = set(ready_tasks).union(
                parse_table_name(table_name)['num']
                for table_name in cached_tables
            )
            moved = self.move_done(probably_need_move, cached_tables)
            self.add_metric('tasks.moved', moved)

        error_tables = self.get_tables_info(self.yt_cluster.errors)
        if error_tables is not None:
            self.add_metric('parsing_errors', sum(
                table.attributes['row_count']
                for table in error_tables
            ))
            recent_since = datetime.utcnow() - timedelta(minutes=RECENT_ERRORS_INTERVAL)
            self.add_metric('recent_parsing_errors', sum(
                table.attributes['row_count']
                for table in error_tables
                if lbyt_utils.parse_yt_datetime(table.attributes['creation_time']) > recent_since
            ))

        self.logger.info('Deleting %d old tasks', len(old_done_task_nums))
        self.yt_cluster.delete_tasks(old_done_task_nums)

        return success, 0

    def get_tasks_to_process(self):
        """Classify the cluster's tasks; return None if task info is unavailable.

        Returns ``(ready_tasks, old_done_task_nums, ldbr_task_info)``:
        ``ready_tasks`` maps task num to the info of every 'ready' task;
        ``old_done_task_nums`` lists 'done'/'skipped' tasks preceding the first
        ready one that are older than ``keep_completed_tasks`` hours, minus the
        newest of them; ``ldbr_task_info`` is the last 'done'/'skipped' task
        before the first ready one (or None). Data of the ready tasks and of
        the LDBR task is loaded through ``DataProxy``.
        """
        self.yt_cluster.read_tasks()
        if self.yt_cluster.tasks is None:
            self.logger.error('No tasks info from yt, can not proceed')
            return None

        ready_tasks = {}
        # LDBR means "Latest Done Before all Ready"
        ldbr_task_info = None
        old_done_task_nums = []
        delete_since = datetime.utcnow() - timedelta(hours=self.config['keep_completed_tasks'])
        for task_info in sorted(self.yt_cluster.tasks, key=lambda task_info: task_info.num):
            if not ready_tasks and task_info.state in ('done', 'skipped'):
                ldbr_task_info = task_info
                timestamp = task_info.time_mark / 1000
                if datetime.utcfromtimestamp(timestamp) < delete_since:
                    old_done_task_nums.append(task_info.num)
            if task_info.state == 'ready':
                task_info.data_proxy.request()
                ready_tasks[task_info.num] = task_info
        # We always want to keep at least one task in the table
        old_done_task_nums = old_done_task_nums[:-1]
        if ldbr_task_info is not None:
            ldbr_task_info.data_proxy.request()
        if not DataProxy.load_requested():
            return None
        self.add_metric('tasks.ready', len(ready_tasks))
        return ready_tasks, old_done_task_nums, ldbr_task_info

    def commit_offsets(self, task_info):
        """Commit LogBroker offsets following ``task_info``'s data.

        Does nothing when ``task_info`` is falsy or ``only_for_quorum`` is set.
        """
        if not task_info or self.config['only_for_quorum']:
            return
        offsets = get_task_offsets(task_info.data_proxy.get(), for_next=True)
        self.lb_cluster.commit_offsets(self.config['client'], offsets)

    # Returns (started_tasks, success)
    #     started_tasks - number of tasks that were started
    #     success - tasks were not started because of valid external conditions or
    #               all tasks successfully finished
    def download_tasks(self, tasks):
        """Start ``lbyt_reader`` threads for uncached tables of ``tasks`` and wait.

        Respects the table-count limit, the disk space quota and
        ``max_simultaneous_downloads``. Returns ``(started_tasks, success)``
        as described in the comment above.
        """
        if not tasks:
            self.logger.warn('No ready tasks, nothing to download')
            return 0, True
        if not self.download_reader_resource() or not self.download_hostoptions_resource():
            return 0, False

        cached_tables = self.get_tables_info(self.yt_cluster.cache)
        done_tables = self.get_tables_info(self.destination)
        if cached_tables is None or done_tables is None:
            self.logger.error('No resources usage info from yt, can not proceed')
            return 0, False

        tables_count = len(cached_tables) + len(done_tables)
        self.add_metric('tables_count.limit', self.config['count'])
        self.add_metric('tables_count.cached', len(cached_tables))
        self.add_metric('tables_count.done', len(done_tables))
        self.add_metric('tables_count.total', tables_count)

        tables_size = sum(
            table.attributes['resource_usage']['disk_space']
            for table in cached_tables + done_tables
        )
        # config 'quota' is multiplied by 1024*1024 (presumably MiB -> bytes)
        tables_size_quota = 1024*1024 * self.config['quota']
        self.add_metric('disk_space.quota', tables_size_quota)
        self.add_metric('disk_space.usage', tables_size)
        if tables_count >= self.config['count']:
            self.logger.warn('Can not start download: count limit reached')
            return 0, True
        if tables_size >= tables_size_quota:
            self.logger.warn(
                'Can not start download: disk space quota for the topic exceeded. Quota size %d',
                tables_size_quota,
            )
            return 0, True

        threads = []
        limit_reached = False
        cached_tables = lbyt_utils.forget_attributes(cached_tables)
        start_reader_tracker = lbyt_utils.ThreadingSuccessTracker(self.start_reader)
        for task_num in sorted(tasks):
            task_info = tasks[task_num]
            for logtype in task_info.data_proxy.get():
                table_name = get_table_name(self.config['job'], task_info, logtype)
                if table_name in cached_tables:
                    continue
                thread = threading.Thread(
                    target=start_reader_tracker,
                    args=(task_info, logtype),
                )
                threads.append(thread)
                thread.start()
                tables_count += 1
                limit_reached = (
                    tables_count >= self.config['count'] or
                    len(threads) >= self.config['max_simultaneous_downloads']
                )
                if limit_reached:
                    break
            if limit_reached:
                break
        started_tasks = len(threads)
        self.logger.info(
            'Download of %d tasks is requested, started downloading of %d tables',
            len(tasks), started_tasks,
        )
        if limit_reached:
            self.logger.warn('Count limit or max simultaneous downloads reached')

        for thread in threads:
            thread.join()

        if start_reader_tracker.fail:
            self.logger.error('All read operations failed')
        elif not start_reader_tracker.success:
            self.logger.error('Not all read operations succeeded')
        return started_tasks, start_reader_tracker.success

    def mark_done(self, tasks, cached_tables):
        """Mark as 'done' every task whose tables for all logtypes are cached."""
        done_task_nums = []
        for task_num, task_info in tasks.iteritems():
            in_cache = [
                (get_table_name(self.config['job'], task_info, logtype) in cached_tables)
                for logtype in task_info.data_proxy.get()
            ]
            if all(in_cache):
                done_task_nums.append(task_num)
        self.add_metric('tasks.mark_done', len(done_task_nums))
        if not done_task_nums:
            return
        self.logger.info(
            'Marking tasks #[%s] as done on yt cluster %s',
            ', '.join(map(str, done_task_nums)), self.yt_cluster.name,
        )
        self.yt_cluster.mark_tasks(done_task_nums, 'done')
        self.yt_cluster.flush_tasks()

    def move_done(self, task_nums, cached_tables):
        """Move cached tables of 'done' tasks to :attr:`destination`.

        Tasks are processed in increasing num order. With ``keep_sequential``
        the loop stops at the first task that was not moved. Returns the
        number of tasks whose tables were moved.
        """
        if not task_nums:
            return 0

        self.yt_cluster.read_tasks()  # update states after mark_done
        if self.yt_cluster.tasks is None:
            return 0
        task_states = {
            task_info.num: task_info.state
            for task_info in self.yt_cluster.tasks
        }

        moved_count = 0
        for task_num in sorted(task_nums):
            moved = False
            if task_states.get(task_num) == 'done':
                move_pairs = []
                for table_name in cached_tables:
                    if parse_table_name(table_name)['num'] == task_num:
                        move_pairs.append((
                            '{}/{}'.format(self.yt_cluster.cache, table_name),
                            '{}/{}'.format(self.destination, table_name),
                        ))
                self.logger.info(
                    'Moving tables for task #%s to done on yt cluster %s',
                    task_num, self.yt_cluster.name,
                )
                moved = self.yt_cluster.move_tables(move_pairs)
                moved_count += (1 if moved else 0)
            if self.config['keep_sequential'] and not moved:
                break
        return moved_count

    def get_tables_info(self, target_path):
        """List this job's tables in ``target_path``; return None on any error.

        The directory is created if missing. Only entries whose names start
        with ``'<job>.'`` are returned, with the 'creation_time',
        'resource_usage' and 'row_count' attributes requested.
        """
        try:
            if not self.yt_cluster.yt.exists(target_path):
                self.yt_cluster.yt.mkdir(target_path, recursive=True)
            tables_info = [
                table
                for table in self.yt_cluster.yt.list(
                    target_path, attributes=['creation_time', 'resource_usage', 'row_count'],
                )
                if table.startswith(self.config['job'] + '.')
            ]
        except Exception as err:
            self.logger.warn(
                'Failed to get tables in %s on yt cluster %s: %s',
                target_path, self.yt_cluster.name, err,
            )
            return None
        return tables_info

    def download_reader_resource(self):
        """(Re)download the reader executable if its released version changed.

        Returns False if the resource could not be obtained, True otherwise.
        """
        # NOTE(review): get_version_from_yt is called outside the try block, so
        # a failure there propagates to the caller instead of returning False.
        if self.reader_version is None or get_version_from_yt(YABS_YTSTAT_LB_READER, self.yt_cluster.yt) != self.reader_version:
            self.logger.info('Downloading reader resource for %s', self.yt_cluster.name)
            try:
                resource = get_and_sync_autoreleased_resource(
                    YABS_YTSTAT_LB_READER, self.yt_cluster.yt, self,
                )
                self.reader_version = resource.attributes['version']
                self.reader_executable = resource.path
                self.logger.info('Using resource %s for cluster %s, release version %s', resource.id, self.yt_cluster.name, resource.attributes['version'])
            except Exception as err:
                self.logger.error(
                    'Reader release for %s was not found, can not proceed: %s',
                    self.yt_cluster.name, err,
                )
                return False
            self.logger.info('Reader resource for %s downloaded', self.yt_cluster.name)
        else:
            self.logger.info('Reader for %s already downloaded', self.yt_cluster.name)

        return True

    def download_hostoptions_resource(self):
        """Fetch the app options resource and build ``self.app_options_file``.

        Returns False (after logging) on any error, True otherwise.
        """
        self.logger.info('Downloading hostoptions resource from %s', APP_OPTIONS_YT_PROXY)
        try:
            AppOptions.get_app_options_resource_from_yt(self.app_options_yt_client)
            AppOptions.make_app_options_file(
                self.yt_cluster.proxy,
                result_name=self.app_options_file,
            )
        except Exception as err:
            self.logger.error('Hostoptions resource download failed, can not proceed: %s', err)
            return False
        self.logger.info('Hostoptions resource downloaded from %s', APP_OPTIONS_YT_PROXY)

        return True

    def start_reader(self, task_info, logtype):
        """Run ``lbyt_reader`` for one (task, logtype) pair; return True on success.

        Writes the task's read ranges (split by ``max_offsets_per_pqclient``
        and shuffled) into a temporary input table, then runs the reader
        executable with a timeout, logging into a per-table file. Called from
        worker threads (see :meth:`download_tasks`). Any exception is logged
        and reported as False.
        """
        # Bound before ``try`` so the error handler can always log it, even if
        # the failure happens before the table path is computed.
        output_table = None
        try:
            table_name = get_table_name(self.config['job'], task_info, logtype)
            input_table = '{}/{}'.format(self.yt_cluster.input, table_name)
            output_table = '{}/{}'.format(self.yt_cluster.cache, table_name)
            errors_table = '{}/{}'.format(self.yt_cluster.errors, table_name)

            logtype_config = LOGTYPES.get(logtype, {})
            offsets_per_pqclient = logtype_config.get('max_offsets_per_pqclient')
            data_format = logtype_config.get('data_format')

            frequency = logtype_config.get('frequency', DEFAULT_TASK_FREQUENCY)
            job_timeout = frequency * self.config['timeout_coefficient']

            input_rows = list(itertools.chain.from_iterable(
                self.split_task(read_task, offsets_per_pqclient)
                for read_task in task_info.data_proxy.get()[logtype]
            ))
            random.shuffle(input_rows)  # read of one partition from different nodes is better

            input_expiration_time = datetime.utcnow() + timedelta(hours=KEEP_INPUT_TABLES)
            input_attributes = {'expiration_time': input_expiration_time.isoformat() + 'Z'}
            input_attributes.update(self.config['tmp_input_table_attributes'])

            cmd = [
                self.reader_executable,
                '--yt', self.yt_cluster.proxy,
                '--token-file', self.yt_token_file.name,
                '--jobs', str(max(len(input_rows), 1)),
                '--job-timeout', str(job_timeout),
                '--job-data-interval', str(frequency),
                '--logbroker', self.lb_cluster.proxy,
                '--client', self.config['client'],
                '--logtype', logtype,
                '--input', input_table,
                '--output', output_table,
                '--errors', errors_table,
                '--errors-expire', str(self.config['err_expiration_time']),
                '--attribute', 'batch_id:{}'.format(int(task_info.num)),
                '--attribute', 'batch_timestamp:{}'.format(int(task_info.time_mark / 1000)),
                '--options-file', self.app_options_file,
            ]
            if data_format:
                cmd.extend(['--format', data_format])
            if self.config['pool']:
                cmd.extend(['--pool', self.config['pool']])
            if self.config['pool_tree']:
                cmd.extend(['--pool-tree', self.config['pool_tree']])
            if self.config['file_storage']:
                cmd.extend(['--file-storage', self.config['file_storage']])
            if self.config['stats_path']:
                stats_path = self.config['stats_path']
                if not stats_path.startswith('//'):
                    stats_path = self.yt_cluster.path_prefix + stats_path
                stats_table = '{}/{}'.format(stats_path, table_name)
                cmd.extend(['--stats', stats_table])
            if self.config['fat_keys_path']:
                fat_keys_path = self.config['fat_keys_path']
                if not fat_keys_path.startswith('//'):
                    fat_keys_path = self.yt_cluster.path_prefix + fat_keys_path
                cmd.extend(['--fat-keys', fat_keys_path])

            environment = os.environ.copy()
            if self.config['yt_log_level']:
                environment['YT_LOG_LEVEL'] = self.config['yt_log_level']

            # We need our own connection because Yt is not thread-safe
            self.logger.info('Creating new yt client for cluster %s', self.yt_cluster.name)
            yt = self.yt_cluster.get_yt_connection()

            self.logger.info(
                'Preparing reader input table %s on cluster %s',
                input_table, self.yt_cluster.name,
            )
            with yt.Transaction():
                yt.create('table', input_table, attributes=input_attributes, recursive=True, force=True)
                yt.write_table(input_table, input_rows)

            log_path = paths.get_unique_file_name(
                paths.get_logs_folder(),
                get_table_name(self.job, task_info, logtype) + '.log',
            )
            self.logger.info(
                'Starting lbyt_reader for table %s on yt cluster %s, see log at:\n  %s',
                output_table, self.yt_cluster.name, '{}/{}'.format(
                    self.logs_url, os.path.basename(log_path),
                ),
            )
            with open(log_path, 'w') as log_file:
                process.run_process(
                    cmd,
                    check=True,
                    stdout=log_file,
                    stderr=log_file,
                    environment=environment,
                    # give 2 tries to every job, +15 for client and scheduling time
                    timeout=(2 * job_timeout + 15),
                )
            self.logger.info('lbyt_reader finished for table %s on yt cluster %s', output_table, self.yt_cluster.name)
        except Exception as err:
            self.logger.warn(
                'Failed to download table %s on yt cluster %s: %s',
                output_table, self.yt_cluster.name, err,
            )
            if isinstance(err, SandboxSubprocessTimeoutError):
                # Let process write transaction expire
                time.sleep(11)
            return False
        return True

    @staticmethod
    def split_task(read_task, offsets_limit):
        """Split a read task into chunks of at most ``offsets_limit`` offsets.

        Yields copies of ``read_task`` whose ``offset``/``limit`` cover the
        original ``[offset, offset + limit)`` range in consecutive chunks whose
        sizes differ by at most one (larger chunks first). A task with
        ``limit < 1`` yields nothing; when ``offsets_limit`` is None or < 1 the
        task itself is yielded unchanged. ``read_task`` is never modified, so
        the caller's (cached) task data stays intact.
        """
        total = read_task['limit']
        if total < 1:
            return
        if offsets_limit is None or offsets_limit < 1:
            yield read_task
            return
        required_count = (total + offsets_limit - 1) // offsets_limit  # ceil division
        greater_count = total % required_count
        less_limit = total // required_count
        chunks = [less_limit + 1] * greater_count + [less_limit] * (required_count - greater_count)
        offset = read_task['offset']
        remaining = total
        for chunk in chunks:
            new_task = read_task.copy()
            new_task['offset'] = offset
            new_task['limit'] = chunk
            yield new_task
            offset += chunk
            remaining -= chunk
        assert remaining == 0, 'something went really wrong in split_task'

# Module-level alias for the task class; presumably the Sandbox framework
# looks up the task class of this module by this name -- confirm.
__Task__ = YabsLbYtReader
