import calendar
import datetime
import logging
import requests
import time

from dateutil.parser import parse as parse_time
from dateutil.tz import gettz as get_timezone

import sandbox.common.types.task as ctt
from sandbox.sdk2 import (
    Task,
    parameters,
)

from sandbox.projects.yabs.dropstat.base import BaseDropStatTask
from sandbox.projects.yabs.dropstat.base.config import DEFAULT_RETRY_COUNT, MAX_TIME_FROM_START_OF_MONTH
from sandbox.projects.yabs.dropstat.prepare import YabsDropStatPrepareRequest
from sandbox.projects.yabs.dropstat.send import YabsDropStatSendToLogbroker

# Sandbox statuses of a child task after which the owning request is retried
# (or marked failed once its retry budget is exhausted).
RETRY_SB_STATUSES = set([
    ctt.Status.DELETED,
    ctt.Status.EXCEPTION,
    ctt.Status.FAILURE,
    ctt.Status.STOPPED,
    ctt.Status.TIMEOUT,
])

# Timezone used to compute month boundaries when skipping requests by partner reports.
MOSCOW_TZ = get_timezone('Europe/Moscow')


class YabsDropStatCoordinator(BaseDropStatTask):
    '''Dropstat coordinator.

    Reads dropstat requests from the meta table and advances each of them through
    its state machine by launching and monitoring child Sandbox tasks:
    new -> selecting -> selected -> merged -> sending -> done
    (plus 'failed' when retries run out and 'skipped' by partner reports).
    Changed requests are written back to the meta table and queue metrics are
    pushed to Solomon.
    '''
    class Parameters(BaseDropStatTask.Parameters):
        description = 'Coordinate dropstat requests processing'

        with parameters.Group('Logbroker') as lb_params:
            tvm_src_id = parameters.Integer(
                'TVM source id',
                required=True,
            )
            tvm_dst_id = parameters.Integer(
                'Logbroker TVM id',
                default=2001059,
                required=True,
            )
            tvm_secret = parameters.YavSecret(
                'TVM secret',
                description='Required key: client_secret',
                required=True,
            )

        with parameters.Group('Solomon') as solomon_params:
            solomon_token = parameters.YavSecret(
                'Solomon token secret',
                description='Required key: solomon_token',
                required=True,
            )
            solomon_service = parameters.String(
                'Solomon service',
                default='dropstat_pre',
            )

        with parameters.Group('Scheduling') as sched_params:
            max_selects = parameters.Integer('Max number of concurrent selects', default=10)
            ignore_new = parameters.Bool('Ignore new requests')

        with parameters.Group('Combining') as combine_params:
            combine_max_requests = parameters.Integer('Max number of requests to combine in one YQL', default=10)
            combine_enabled = parameters.Bool('Enable requests combining into single YQL', default=False)
            combine_wait_timeout = parameters.String(
                'Array of combine wait times. A group of n combined requests will wait for array[n-1] minutes for another appropriate request to come',
                default='120, 60, 30, 10'
            )

    def get_combine_wait_timeout(self):
        '''Parse the ``combine_wait_timeout`` parameter into a list of minutes.

        Items are comma-separated integers; empty items (e.g. produced by a
        trailing comma) are ignored.
        '''
        return [int(item) for item in self.Parameters.combine_wait_timeout.split(',') if item.strip()]

    def init_metrics(self):
        '''Fix the "current time" of this run and reset all metric counters.'''
        self.now = int(time.time())
        self.metrics = {
            'counters': {
                'not_approved': 0,
                'new': 0,
                'selecting': 0,
                'selecting_sb': 0,
                'selected': 0,
                'merged': 0,
                'sending': 0,
                'done': 0,
                'failed': 0,
                'waiting': 0,
                'bad_row': 0,
                'skipped': 0,
            },
            'update_times': []
        }

    def send_metrics(self):
        '''Push time-in-queue (max/avg) and per-state counters to Solomon.

        Raises:
            requests.HTTPError: if Solomon rejects the push.
        '''
        times_in_queue = [self.now - t for t in self.metrics['update_times']] or [0]
        sensors = [
            {
                'labels': {'type': 'time_in_queue', 'sensor': 'max'},
                'value': max(times_in_queue),
                'ts': self.now,
            },
            {
                'labels': {'type': 'time_in_queue', 'sensor': 'avg'},
                'value': sum(times_in_queue) / len(times_in_queue),
                'ts': self.now,
            },
        ]
        sensors.extend(
            {
                'labels': {'type': 'counter', 'sensor': key},
                'value': number,
                'ts': self.now,
            }
            for key, number in self.metrics['counters'].items()
        )

        json_data = {
            'commonLabels': {'logtype': self.Parameters.log_type},
            'sensors': sensors,
        }

        solomon_token = self.Parameters.solomon_token.data()['solomon_token']
        url = 'http://solomon.yandex.net/api/v2/push?project=yabs&cluster=yabs&service={}'.format(self.Parameters.solomon_service)
        resp = requests.post(url, json=json_data, headers={'Authorization': 'OAuth ' + solomon_token})
        resp.raise_for_status()

    def request_ready_time(self, request):
        '''Return the UTC timestamp at which the request's input became ready, or 0 if it is not ready.

        For a bounded request (``date_to != 'MAX'``) this is the creation time of the
        ``<path_prefix>/<time_resolution>/<date_to>`` table (0 if the table is missing);
        for an unbounded request it is the request creation time. When ``wait_node``
        is set, the node must exist and ``wait_node_delay`` hours must have passed
        since its modification; the result is then at least that moment.
        '''
        if request.params.date_to != 'MAX':
            path = '{}/{}/{}'.format(
                request.log_description.path_prefix,
                request.params.time_resolution,
                request.params.date_to,
            )
            if not self.yt_client.exists(path):
                return 0
            dt = parse_time(self.yt_client.get_attribute(path, 'creation_time'))
            date = calendar.timegm(dt.utctimetuple())
        else:
            date = request.creation_time

        if request.params.wait_node:
            if not self.yt_client.exists(request.params.wait_node):
                return 0
            wait_date = parse_time(self.yt_client.get_attribute(request.params.wait_node, 'modification_time'))
            wait_date_utc = calendar.timegm(wait_date.utctimetuple()) + int(request.params.wait_node_delay) * 60 * 60
            if wait_date_utc > self.now:
                logging.info('Skipping request %s by wait time settings untill %d', request.request_id, wait_date_utc)
                return 0
            date = max(date, wait_date_utc)

        return date

    def skip_by_partner_reports(self, request):
        '''Return True if the request must be skipped because it touches an already closed month.

        Only applies to requests with ``skip_prev_month`` set: a request whose
        ``date_from`` lies before the current (Moscow) month is skipped once more
        than MAX_TIME_FROM_START_OF_MONTH seconds have passed since the month start.
        '''
        if not request.params.skip_prev_month:
            return False

        dt_now = datetime.datetime.fromtimestamp(self.now, MOSCOW_TZ)
        month = dt_now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        # 'YYYY-MM' compares below any date string of the same month, so only
        # requests starting in earlier months pass this check.
        if month.strftime('%Y-%m') > request.params.date_from:
            if (dt_now - month).total_seconds() > MAX_TIME_FROM_START_OF_MONTH:
                logging.info('Skipping request %s by partner reports', request.request_id)
                return True

        return False

    def select_requests(self):
        '''Read all requests from the meta table, collect metrics and return the ones to process.

        Returns:
            Requests in an active state (selecting, new, sending, merged, selected)
            whose ``Cluster`` is empty or equals ``work_yt_proxy``.
        '''
        from yabs.stat.dropstat2.pylibs.common.request import DropStatRequest

        rows = self.yt_meta_client.select_rows("* from [{}]".format(self.meta_path))
        counters = self.metrics['counters']
        active = []
        running_sb_tasks = set()
        for row in rows:
            try:
                request = DropStatRequest.from_yt_row(row, log_type=self.Parameters.log_type)
            except Exception:
                logging.warning('Failed to create request object from row %s', row, exc_info=True)
                counters['bad_row'] += 1
                continue

            if not request.update_time:
                request.update_time = request.creation_time

            table_time = self.request_ready_time(request)
            if table_time > request.update_time:
                request.update_time = table_time

            if request.state == 'new' and not table_time > 0:
                # New request whose input data is not ready yet
                counters['waiting'] += 1
            else:
                counters[request.state] += 1
                if request.state not in ('failed', 'not_approved'):
                    self.metrics['update_times'].append(request.update_time)
                if request.state == 'selecting':
                    running_sb_tasks.add(request.sb_task_id)

            if row['Cluster'] and row['Cluster'] != self.Parameters.work_yt_proxy:
                logging.warning('Skipping request %s because it has different cluster', request.request_id)
            elif request.state in ('selecting', 'new', 'sending', 'merged', 'selected'):
                active.append(request)

        # Combined requests share one prepare task, so count distinct task ids
        counters['selecting_sb'] = len(running_sb_tasks)
        return active

    def create_task(self, task_class, copy_args, new_args):
        '''Build (but do not enqueue) a child task of ``task_class``.

        The common parameters plus ``copy_args`` are copied from this task's
        parameters, then overridden/extended by ``new_args``.
        '''
        default_args = ['log_type', 'meta_yt_proxy', 'work_yt_proxy', 'production']
        args = {arg: getattr(self.Parameters, arg) for arg in default_args + copy_args}
        args['yt_token_secret'] = str(self.Parameters.yt_token_secret)
        args.update(new_args)
        return task_class(None, owner=self.owner, **args)

    def check_running_tasks(self, requests, prev_state, next_state):
        '''Advance requests according to the Sandbox status of their child tasks.

        A task in RETRY_SB_STATUSES (or missing in Sandbox) consumes one retry:
        the request goes back to ``prev_state`` while retries remain, otherwise to
        'failed'. A successful task moves the request to ``next_state``. Requests
        whose task is still running are left untouched.

        Returns:
            The list of requests whose state changed.
        '''
        if not requests:
            return []

        tasks = Task.find(id=[req.sb_task_id for req in requests]).limit(len(requests))
        statuses = {task.id: task.status for task in tasks}
        logging.info('Sandbox statuses: %s', statuses)

        updated = []
        for request in requests:
            task_id = request.sb_task_id
            # A task Sandbox does not return any more is treated as deleted
            sb_status = statuses.get(task_id, ctt.Status.DELETED)
            if sb_status in RETRY_SB_STATUSES:
                request.retry_count -= 1
                if request.retry_count > 0:
                    logging.info('Retrying request %s because task %s has status %s',
                                 request.request_id, task_id, sb_status)
                    request.state = prev_state
                else:
                    logging.info('No retries left for request %s: task %s has status %s',
                                 request.request_id, task_id, sb_status)
                    request.state = 'failed'
                updated.append(request)
            elif sb_status == ctt.Status.SUCCESS:
                logging.info('Task %s finished, changing request %s state to %s', task_id, request.request_id, next_state)
                request.state = next_state
                request.retry_count = 0
                updated.append(request)

        return updated

    def create_new_tasks(self, requests, task_class, copy_args, new_args, next_state, combine_requests=False):
        '''Enqueue child task(s) for ``requests`` and move them to ``next_state``.

        With ``combine_requests`` a single task is created for all requests (its
        ``request_id`` parameter is the comma-separated list of ids); otherwise one
        task is created per request. Requests without a retry budget get
        DEFAULT_RETRY_COUNT.
        '''
        if not requests:
            return

        def _run_sb_task(id_str):
            args = {'request_id': id_str, 'description': 'Request {}'.format(id_str)}
            args.update(new_args)
            sb_task = self.create_task(task_class, copy_args, args)
            sb_task.save().enqueue()
            logging.info('Created task %s for request %s', sb_task.id, id_str)
            return sb_task.id

        sb_task_id = None
        if combine_requests:
            sb_task_id = _run_sb_task(','.join([str(r.request_id) for r in requests]))

        for request in requests:
            if not combine_requests:
                sb_task_id = _run_sb_task(str(request.request_id))

            request.state = next_state
            request.sb_task_id = sb_task_id

            if not request.retry_count:
                request.retry_count = DEFAULT_RETRY_COUNT

    @staticmethod
    def _date_to(request):
        '''Return the effective upper date bound of a request ('MAX' is replaced by date_to_limit).'''
        if request.params.date_to != 'MAX':
            return request.params.date_to
        return request.params.date_to_limit

    def _dates_overlap(self, first, second):
        '''Return True if the date ranges of two requests intersect.'''
        return not (first.params.date_from > self._date_to(second) or
                    self._date_to(first) < second.params.date_from)

    @staticmethod
    def _format_groups(groups):
        '''Render request groups as "[id,id],[id]" for logging.'''
        return ','.join('[{}]'.format(','.join([str(r.request_id) for r in group])) for group in groups)

    def _find_combine_group(self, request, open_groups):
        '''Return the first open group ``request`` can join, or None.

        A group qualifies when it is below ``combine_max_requests``, at least one of
        its members has an overlapping date range, and the log description accepts
        the combined set (``check_can_be_combined`` returns a falsy value).
        '''
        for group in open_groups:
            if len(group) >= self.Parameters.combine_max_requests:
                continue
            if not any(self._dates_overlap(request, member) for member in group):
                continue
            if self.log_description.check_can_be_combined(group + [request]):
                continue
            return group
        return None

    def _group_requests(self, requests):
        '''Split requests into groups, each of which is executed as a single YQL request.

        Without combining, every request gets its own group. With combining, a
        request joins the first compatible open group or opens a new one. Requests
        on their last retry (retry_count == 1) and requests rejected by
        ``check_can_be_combined`` on their own get a dedicated group that other
        requests may not join.
        '''
        groups = []
        open_groups = []
        for request in requests:
            combinable = (
                self.Parameters.combine_enabled and
                not self.log_description.check_can_be_combined([request]) and
                (request.retry_count > 1 or request.retry_count == 0)
            )
            if combinable:
                group = self._find_combine_group(request, open_groups)
                if group is not None:
                    group.append(request)
                    continue

            group = [request]
            groups.append(group)
            if combinable:
                open_groups.append(group)
        return groups

    def _filter_groups_by_wait_timeout(self, groups, ready_times):
        '''Keep only the groups that should not wait for more requests to combine with.

        A group of n requests passes when n >= len(wait timeouts) or when more than
        ``wait_timeouts[n-1]`` minutes passed since its most recent request was
        created or became ready. ``ready_times`` maps ``id(request)`` to its ready time.
        '''
        def last_event_time(request):
            return max(request.creation_time, ready_times[id(request)])

        # Newest request first, so group[0] is the one that arrived / became ready last
        groups = [sorted(group, key=lambda r: -last_event_time(r)) for group in groups]

        logging.info('Requests to run before timeout filtration: %s', self._format_groups(groups))
        for group in groups:
            for r in group:
                logging.info('%s: %s %s', r.request_id, r.creation_time, ready_times[id(r)])

        wait_timeouts = self.get_combine_wait_timeout()
        return [
            group for group in groups
            if len(group) >= len(wait_timeouts) or
            self.now - last_event_time(group[0]) > wait_timeouts[len(group) - 1] * 60
        ]

    def process_new(self, requests):
        '''Launch prepare (select) tasks for new requests whose input data is ready.

        Does nothing when ``ignore_new`` is set, when the logfeller tables audit
        failed, or when ``max_selects`` prepare tasks are already running. Ready
        requests are grouped (see ``_group_requests``), optionally held back to
        wait for combining partners, and at most ``max_selects - running`` groups
        are started, highest priority / oldest first. Requests skipped by partner
        reports are moved to 'skipped'.

        Returns:
            All requests of the started groups (including skipped ones).
        '''
        from yabs.stat.yabs_yt_audit.api import get_audit_result

        if self.Parameters.ignore_new:
            return []

        # Stop processing new requests if logfeller tables audit detected a problem
        audit_class = self.log_description.logfeller_tables_audit_class
        if audit_class and not get_audit_result(audit_class, self.yt_token):
            logging.info('Stopped to process New requests because audit %s failed or was not preformed in time', audit_class.__name__)
            return []

        max_new = self.Parameters.max_selects - self.select_tasks
        if max_new <= 0:
            logging.info('Max amount of running prepare tasks reached')
            return []

        # request_ready_time() queries YT, so evaluate it once per request and reuse it
        ready_times = {}
        ready_requests = []
        for request in requests:
            ready_time = self.request_ready_time(request)
            if ready_time:
                ready_times[id(request)] = ready_time
                ready_requests.append(request)

        ready_requests.sort(key=lambda x: (-x.priority, x.params.date_from))
        logging.info('Requests waiting to be ran: %s', ','.join([str(r.request_id) for r in ready_requests]))

        groups = self._group_requests(ready_requests)

        # TODO: DO NOT COMBINE FAILED REQUESTS !!
        # Do not run not combined requests for a while, wait for a proper requests to come to be combined
        if self.Parameters.combine_enabled:
            groups = self._filter_groups_by_wait_timeout(groups, ready_times)

        # For execution we need to take the batches with the oldest requests
        groups = [sorted(group, key=lambda r: (-r.priority, r.creation_time)) for group in groups]
        groups.sort(key=lambda g: (-g[0].priority, g[0].creation_time))
        groups = groups[:max_new]
        logging.info('Final requests to run: %s', self._format_groups(groups))

        # Just a cosmetic sort to make it easier to read in YQL GUI
        groups = [sorted(group, key=lambda r: r.params.date_from) for group in groups]

        for group in groups:
            process_requests = []
            for req in group:
                if self.skip_by_partner_reports(req):
                    req.state = 'skipped'
                else:
                    req.cluster = self.Parameters.work_yt_proxy
                    process_requests.append(req)

            if process_requests:
                self.create_new_tasks(process_requests, YabsDropStatPrepareRequest, [], {}, 'selecting', True)

        return [req for group in groups for req in group]

    def process_selected(self, requests):
        '''Move requests to 'merged' once their id appears under ``<work_dir>/merged``.'''
        all_merged_ids = set(self.yt_client.list(self.work_dir + '/merged'))
        logging.info('Merged requests: [%s]', all_merged_ids)

        merged = []
        for request in requests:
            # Check that merge scheduler finished processing the request
            if str(request.request_id) in all_merged_ids:
                logging.info('Changing request %s state to merged', request.request_id)
                request.state = 'merged'
                merged.append(request)

        return merged

    def process_selecting(self, requests):
        '''Check prepare tasks; on success go to 'selected', on failure retry from 'new'.

        Also frees one running-prepare slot per distinct finished task.
        '''
        reqs = self.check_running_tasks(requests, 'new', 'selected')
        self.select_tasks -= len(set([r.sb_task_id for r in reqs]))
        return reqs

    def process_merged(self, requests):
        '''Start one send-to-Logbroker task per merged request and move it to 'sending'.'''
        copy_params = ['tvm_src_id', 'tvm_dst_id']
        args = {'tvm_secret': str(self.Parameters.tvm_secret)}
        self.create_new_tasks(requests, YabsDropStatSendToLogbroker, copy_params, args, 'sending')
        return requests

    def process_sending(self, requests):
        '''Check send tasks; on success go to 'done', on failure retry from 'merged'.'''
        return self.check_running_tasks(requests, 'merged', 'done')

    def _finalize_done_request(self, req):
        '''Create the configured result link and remove the configured delete_node of a finished request.'''
        table_path = '{}/{}/{}'.format(self.work_dir, 'merged', req.request_id)
        if req.params.link_node:
            link_path = '{}/{}'.format(req.params.link_node, req.request_id)
            if not self.yt_client.exists(link_path):
                logging.info('Linking table %s to %s because request successfully processed', table_path, link_path)
                self.yt_client.link(table_path, link_path)

        if req.params.delete_node and self.yt_client.exists(req.params.delete_node):
            logging.info('Deleting node %s because RequestID %s was set to done', req.params.delete_node, req.request_id)
            self.yt_client.remove(req.params.delete_node)

    def on_execute(self):
        '''Run one coordination pass over all active requests, then push metrics.'''
        self.init_metrics()
        active_requests = self.select_requests()
        # Number of distinct running prepare tasks; read by process_new, decremented by process_selecting
        self.select_tasks = self.metrics['counters']['selecting_sb']

        # 'selecting' goes first so finished prepare tasks free their slots (and retried
        # requests return to 'new') before new requests are scheduled in the same pass.
        for state in ('selecting', 'new', 'selected', 'sending', 'merged'):
            reqs = [req for req in active_requests if req.state == state]
            if not reqs:
                continue

            logging.info('Processing requests in state %s: %s', state, [r.request_id for r in reqs])
            to_update = getattr(self, 'process_' + state)(reqs)
            if not to_update:
                continue

            # TODO:  Potentially weak place: if status change fails after merged data is sent, repeated send operation may appear
            #        causing the same events to be undone twice (or more)
            for req in to_update:
                if req.state == 'done':
                    self._finalize_done_request(req)

            rows = [req.to_yt_row() for req in to_update]
            logging.info('Inserting rows: %s', rows)
            self.yt_meta_client.insert_rows(self.meta_path, rows)

        self.send_metrics()
