from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment

import sys
import time
import datetime
import json
import os
import requests
import logging
import urllib
import kazoo
from kazoo.client import KazooClient, KazooRetry
from random import randrange

import re
from collections import OrderedDict


# Upper bound passed to YT list(); reaching it exactly is treated as an error by the updater.
MAX_TABLES_COUNT = 10000

# Lifetime, in seconds, of the shard list cached in ZooKeeper.
SHARDS_CACHE_TTL = 3600  # 1 hour

# Pseudo batch id used when a batch contained no non-empty tables.
NO_TABLES_TO_PROCESS = '__NO_TABLES_TO_PROCESS__'

# Outcomes of checking the previously sent batch.
COMPLETED, IN_PROGRESS, FAILED = '__COMPLETED__', '__IN_PROGRESS__', '__FAILED__'


def hr_json(data):
    """Render a mapping (or iterable of pairs) as indented JSON for human-readable logs.

    ``None`` is rendered as the plain string 'None'.
    """
    if data is not None:
        ordered = OrderedDict(data)
        return json.dumps(ordered, indent=4, separators=(',', ': '))
    return str(None)


def http_request(url, post_data=None, timeout=5, attempts=10, headers=None, exception_on_fail=True, wait_rerty=5, get_json=True):
    """Send an HTTP request with retries and return the response body.

    A GET is performed when ``post_data`` is None; otherwise a POST with ``post_data``
    JSON-encoded as the body. A URL that does not start with 'http://' or 'https://'
    gets 'http://' prepended.

    :param url: target URL, with or without scheme.
    :param post_data: JSON-serializable payload; switches the request to POST.
    :param timeout: per-attempt timeout in seconds, passed to requests.
    :param attempts: maximum number of attempts; 4xx responses stop retrying immediately.
    :param headers: optional request headers.
    :param exception_on_fail: raise the last error when no attempt succeeded; otherwise return None.
    :param wait_rerty: delay in seconds before every attempt except the first
        (misspelled name kept for backward compatibility).
    :param get_json: decode the body as JSON instead of returning the raw content.
    :returns: the decoded JSON (or raw content) of a 200/202 response, or None on failure
        when ``exception_on_fail`` is False.
    :raises Exception: the last error seen, or a descriptive error if ``attempts <= 0``.
    """
    # Check the prefix, not "contains": a scheme inside a query string must not count.
    if not url.startswith(('http://', 'https://')):
        url = 'http://{}'.format(url)
    headers = {} if headers is None else headers

    logging.debug('http_request: execute url={url} post_data={data}...'.format(
        url=url,
        data=post_data
    ))
    last_exception = None
    for attempt in range(attempts):
        if attempt and wait_rerty:
            logging.debug('http_request: waiting retry {time} seconds...'.format(time=wait_rerty))
            time.sleep(wait_rerty)
        try:
            # NOTE(review): TLS certificate verification is disabled (verify=False). Kept for
            # compatibility with existing endpoints, but it permits MITM on https URLs.
            if post_data is None:
                r = requests.get(url, timeout=timeout, headers=headers, verify=False)
            else:
                r = requests.post(url, timeout=timeout, headers=headers, data=json.dumps(post_data), verify=False)
            description = 'code={code} responce={responce}'.format(
                code=r.status_code,
                responce=r.content
            )
            logging.debug('http_request: done {descr}'.format(descr=description))
            if 400 <= r.status_code <= 499:
                # Client errors are not retried: repeating the same request will not help.
                last_exception = Exception('user error with {descr}'.format(descr=description))
                break
            if r.status_code not in (200, 202):
                raise Exception('error with {descr}'.format(descr=description))
            if get_json:
                return json.loads(r.content)
            return r.content
        except Exception as e:
            last_exception = e
            logging.warning('http_request: attempt {attempt}/{attempts_count} exception={exception}'.format(
                attempt=attempt + 1,
                attempts_count=attempts,
                exception=e
            ))
    if exception_on_fail:
        if last_exception is None:
            # Only reachable with attempts <= 0: nothing was tried, so there is nothing to re-raise.
            raise Exception('http_request: no attempts were made for url={} (attempts={})'.format(url, attempts))
        raise last_exception
    return None

class ZKClient(object):
    """Context manager yielding a started KazooClient connected to the SaaS ZooKeeper hosts.

    Usage: ``with ZKClient() as zk: ...``. On exit the client is stopped and, if stopping
    succeeded, closed. Exceptions raised inside the ``with`` body are propagated.
    """
    ZK_HOSTS = ('saas-zookeeper1.search.yandex.net:14880,'
                'saas-zookeeper2.search.yandex.net:14880,'
                'saas-zookeeper3.search.yandex.net:14880,'
                'saas-zookeeper4.search.yandex.net:14880,'
                'saas-zookeeper5.search.yandex.net:14880')

    # How many times KazooClient.start() is attempted before giving up.
    START_ATTEMPTS = 3

    def __enter__(self):
        """Create the client and start it, retrying a few times.

        :returns: the connected KazooClient.
        :raises Exception: if every start attempt failed.
        """
        # Kazoo is chatty at INFO/DEBUG level; keep only warnings and errors.
        kazoo.client.log.setLevel(logging.WARNING)
        kazoo.protocol.connection.log.setLevel(logging.WARNING)
        self.retry_policy = KazooRetry(max_tries=10, delay=0.5, backoff=2)
        self.kz = KazooClient(hosts=self.ZK_HOSTS,
                              read_only=False,
                              timeout=30,
                              connection_retry=self.retry_policy,
                              command_retry=self.retry_policy)
        had_failures = False
        for _ in range(self.START_ATTEMPTS):
            try:
                self.kz.start()
            except Exception as e:
                logging.warning("KazooClient.start failed: " + str(e))
                had_failures = True
            else:
                if had_failures:
                    logging.warning("KazooClient has connected after several attempts.")
                return self.kz
        raise Exception('Failed to establish connection to zookeeper')

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Stop and close the client; never suppresses an exception from the ``with`` body."""
        if exc_type is not None:
            logging.debug("KazooClient context manager is destroying, exception: [{}] {}".format(exc_type, exc_value))
        try:
            self.kz.stop()
        except Exception as e:
            logging.warning("KazooClient.stop failed: " + str(e))
        else:
            # close() releases the client's resources; it is meant for an already stopped client.
            try:
                self.kz.close()
            except Exception as e:
                logging.warning("KazooClient.close failed: " + str(e))
        return False  # rethrows


def create_zk_node(path, data=''):
    """Create the ZooKeeper node at ``path`` (with any missing parents) and set its data.

    :param path: absolute ZooKeeper path.
    :param data: value to store; an existing node's data is overwritten.
    """
    with ZKClient() as zk:
        # ensure_path is a no-op for an existing node and tolerates a concurrent creator,
        # unlike an exists()/create() pair, which can raise NodeExistsError in between.
        zk.ensure_path(path)
        zk.set(path, data)

def remove_zk_node(path, recursive=False):
    """Delete the ZooKeeper node at ``path`` if it exists.

    :param path: absolute ZooKeeper path.
    :param recursive: also delete all children of the node.
    """
    mode_note = ' (recursively!)' if recursive else ''
    logging.info('Removing zk node' + mode_note + ': ' + path)
    with ZKClient() as zk:
        if not zk.exists(path):
            logging.info('ZK node does not exist.')
            return
        zk.delete(path, recursive=recursive)
        logging.info('ZK node was removed.')

def get_zk_data(path, description=''):
    """Read and JSON-decode the data stored in the ZooKeeper node at ``path``.

    :param path: absolute ZooKeeper path.
    :param description: free-form label used only in log messages.
    :returns: the decoded object, or None when the node is missing or holds no data.
    """
    logging.info('Getting data from zk: {} path={}...'.format(description, path))
    with ZKClient() as zk:
        if zk.exists(path):
            raw = zk.get(path)[0]
            logging.debug('Data from zk raw: {}'.format(raw))
            # Truthiness covers both an empty value and None, which kazoo may return for a
            # node without data (len(None) would raise TypeError).
            if raw:
                data = json.loads(raw)
                logging.info('Received data ({}): {}'.format(description, hr_json(data)))
                return data
        else:
            logging.debug('ZK path does not exist: {}'.format(path))
    logging.info('Failed to get data from zk: {} path={}.'.format(description, path))
    return None

def timestamp_from_string(ts_str):
    """Parse a table-name / YT time string into a Unix timestamp in seconds.

    Accepted forms:
      - 'YYYY-MM-DD' (date only, meaning local midnight);
      - 'YYYY-MM-DDTHH:MM:SS', optionally followed by '.ffffff' and/or a trailing 'Z'.

    Strings without 'Z' are interpreted in the host's local time zone (time.mktime).

    NOTE(review): for 'Z' (UTC) strings the code still applies local-time mktime and then adds
    a fixed 3 hours. That is only correct when the host runs at UTC+3 (presumably Moscow time).
    It is kept unchanged because these values become Ferryman table timestamps, and shifting
    them could make them move backwards. Confirm before switching to calendar.timegm.

    :raises ValueError: if ``ts_str`` matches none of the accepted forms.
    """
    timezone_shift = 0
    # Date-only strings have no 'T' separator; a string containing 'T' never matched '%Y-%m-%d'
    # anyway, so all previously accepted inputs parse exactly as before.
    template = '%Y-%m-%dT%H:%M:%S' if 'T' in ts_str else '%Y-%m-%d'
    if '.' in ts_str:
        template += '.%f'
    if 'Z' in ts_str:
        template += 'Z'
        timezone_shift = 3 * 60 * 60
    return int(time.mktime(datetime.datetime.strptime(ts_str, template).timetuple())) + timezone_shift

def timestamp_to_string(ts, with_ms=False):
    """Format a Unix timestamp in local time as 'YYYY-MM-DDTHH:MM:SS'.

    :param ts: seconds since the epoch (int or float).
    :param with_ms: append '.ffffff' (microseconds) to the result.
    """
    fmt = '%Y-%m-%dT%H:%M:%S.%f' if with_ms else '%Y-%m-%dT%H:%M:%S'
    moment = datetime.datetime.fromtimestamp(ts)
    return moment.strftime(fmt)

class TUpdater:
    """Send new YT log tables of one SaaS service/ctype to Ferryman, one batch per do() call.

    Tables come from Logfeller folders under //logs (selected by shard folder name) or, when
    ``lb_delivery_path`` is set, from Logbroker delivery folders 'shard-*' under that path.
    Progress is kept in ZooKeeper under '<zk_path>/<service>/<ctype>':
      - 'state': {'shards': {shard: timestamp of the last consumed table}, 'last_batch': batch};
      - 'shards': the cached shard list plus the time it was fetched;
      - 'last_batch': the batch most recently sent to Ferryman ('' once it has been handled).
    """

    def __init__(self, service, ctype, zk_path, fm_host, period, tables_per_shard,
                 task=None, yt_token=None, from_timestamp=0, lb_delivery_path='', ignore_cache=False,
                 allow_empty_shard_on_skip=False, min_sending_ts=0):
        """
        :param service: SaaS service name, used in the ZK path and in shard folder names.
        :param ctype: SaaS ctype, used in the ZK path and in shard folder names.
        :param zk_path: ZooKeeper root under which '<service>/<ctype>' data is stored.
        :param fm_host: Ferryman host, with or without scheme.
        :param period: table granularity, '5min' or '1d'.
        :param tables_per_shard: number of non-empty tables per shard to put into one batch.
        :param task: optional Sandbox task, used for set_info() reporting.
        :param yt_token: token for the 'arnold' YT cluster.
        :param from_timestamp: when non-zero, ignore the saved per-shard position and start from
            this timestamp; it must be a multiple of the period length.
        :param lb_delivery_path: YT path of the Logbroker delivery; empty means Logfeller mode.
        :param ignore_cache: bypass the ZK shard-list cache.
        :param allow_empty_shard_on_skip: let timestamp alignment drop every table of a shard.
        :param min_sending_ts: lower bound for the timestamps sent to Ferryman, in seconds.
        :raises Exception: on an unsupported period or a misaligned from_timestamp.
        """
        # A trailing slash (the task's default zk_path is '/logbroker/ferryman/') would otherwise
        # produce an empty path component ('//'), which is not a valid ZooKeeper path.
        self._zk_service_path = '{}/{}/{}'.format(zk_path.rstrip('/'), service, ctype)
        self._shards_folder_re = self._gen_shards_folder_re(service, ctype, period)
        self._fm_host = fm_host
        self._period = period
        self._period_sec = 300 if period == '5min' else (24 * 60 * 60)
        self._tables_per_shard = tables_per_shard
        self._task = task
        # Maximum lag of the consumed data, in seconds. It is filled by _fill_age() and
        # initialized here so callers can read it even if do() fails early.
        self.age = None
        self._init_yt(yt_token)
        self._from_timestamp = from_timestamp
        if self._from_timestamp % self._period_sec != 0:
            raise Exception('from_timestamp mod period != 0')

        self._state_path = '{}/state'.format(self._zk_service_path)
        self._shards_cache_path = '{}/shards'.format(self._zk_service_path)
        self._last_batch_path = '{}/last_batch'.format(self._zk_service_path)
        self._lb_delivery = (lb_delivery_path != '')
        self._lb_delivery_path = lb_delivery_path
        self._ignore_cache = ignore_cache
        self._allow_empty_shard_on_skip = allow_empty_shard_on_skip
        self._min_sending_ts = min_sending_ts
        # Ferryman table 'Format' value.
        self._processor = 'logbroker' if self._lb_delivery else 'logfeller'
        # Used only to build the link to the YT operations page on failure.
        self._yt_subpath = 'ferryman-' + ctype.replace('_', '-') + '/' + service

    def _gen_shards_folder_re(self, service, ctype, period):
        """Compile the regex that matches Logfeller shard folder names under //logs.

        :raises Exception: on an unsupported period.
        """
        # Escape the user-provided parts so that regex metacharacters in them match literally.
        prefix = 'saas-cloud-{}-{}-'.format(re.escape(service), re.escape(ctype))
        if period == '1d':
            pattern = '^' + prefix + 'shards-all$'
        elif period == '5min':
            pattern = '^' + prefix + r'shard-[0-9]+-[0-9]+$'
        else:
            raise Exception('Unsupported period: {}'.format(period))
        return re.compile(pattern)

    def _init_yt(self, yt_token):
        """Create the YT client for the 'arnold' cluster (imported lazily)."""
        from yt.wrapper import YtClient
        self._yt = YtClient(proxy='arnold', token=yt_token)

    def _get_shards_from_cache(self):
        """Return the cached shard list from ZK, or None if it is absent or older than SHARDS_CACHE_TTL."""
        logging.info('Getting shards from ZK cache...')
        shards = None
        data = get_zk_data(self._shards_cache_path, 'shards from cache')
        if data:
            delay = time.time() - data['timestamp']
            logging.info('Cached shards list updated {:.1f} minutes ago'.format(delay / 60))
            if delay < SHARDS_CACHE_TTL:
                shards = data
            else:
                logging.info('Shards cache expired.')
        logging.info('Shards from cache: {}'.format(hr_json(shards)))
        return shards

    def _save_shards_to_cache(self, shards):
        """Store the shard list (with its fetch timestamp) in the ZK cache node."""
        logging.info('Saving shards to cache...')
        create_zk_node(self._shards_cache_path, json.dumps(shards))
        logging.info('Saved.')

    def _get_shards_from_yt(self):
        """List shard folders in YT.

        :returns: {'timestamp': now, 'shards': [folder, ...]}.
        :raises Exception: if no shard folders are found, or a Logbroker folder is not 'shard-*'.
        """
        logging.info('Getting shards list from YT...')
        shards = {
            'timestamp': int(time.time()),
            'shards': []
        }
        if self._lb_delivery:
            folders = self._yt.list(self._lb_delivery_path)
            for folder in folders:
                if not folder.startswith('shard-'):
                    raise Exception('Incorrect shard prefix: {}. Received shards: {}'.format(
                        folder, ','.join(str(f) for f in folders)))
                shards['shards'].append(str(folder))
        else:
            folders = json.loads(self._yt.list('//logs', format='json'))
            for folder in folders:
                if self._shards_folder_re.match(folder):
                    logging.debug('Matched shard folder: {}'.format(folder))
                    shards['shards'].append(folder)
        if not shards['shards']:
            raise Exception('Unable to get shards from YT: no folder matches')
        logging.info('Shards from yt: {}'.format(hr_json(shards)))
        return shards

    def _get_shards(self, ignore_cache=False):
        """Return the shard list, preferring the ZK cache unless ``ignore_cache`` is set."""
        if not ignore_cache:
            shards = self._get_shards_from_cache()
            if shards:
                return shards
        shards = self._get_shards_from_yt()
        if shards is None or len(shards['shards']) == 0:
            raise Exception('Shard data is incorrect: {}'.format(shards))
        self._save_shards_to_cache(shards)
        return shards

    def _get_last_batch_info(self):
        """Return the last sent batch from ZK, or None if there is none."""
        logging.info('Loading last_batch_info from ZK')
        return get_zk_data(self._last_batch_path, 'last batch info')

    def _save_last_batch(self, batch):
        """Store the batch just sent in ZK, so the next run can check on it."""
        logging.info('Saving last_batch_info to ZK...')
        create_zk_node(self._last_batch_path, json.dumps(batch))
        logging.info('Saved.')

    def _remove_last_batch(self):
        """Mark that no batch is pending."""
        # We prefer setting an empty value to removal here, because this way it is less prone to failures
        create_zk_node(self._last_batch_path, '')

    def _get_batch_status(self, batch_id):
        """Ask Ferryman for the raw status string of ``batch_id``."""
        result = http_request('{}/get-batch-status?batch={}'.format(self._fm_host, batch_id), timeout=120, attempts=2)
        status = result['status']
        logging.info('Ferryman status for batch {} is: {}'.format(batch_id, status))
        return status

    def _get_batch_result(self, batch_id):
        """Map the Ferryman status of ``batch_id`` to COMPLETED, IN_PROGRESS or FAILED."""
        if batch_id == NO_TABLES_TO_PROCESS:
            return COMPLETED
        status = self._get_batch_status(batch_id)
        if status in ('processing', 'queue'):
            return IN_PROGRESS
        if status in ('final', 'searchable'):
            return COMPLETED
        logging.error('Batch {} is broken, ferryman status: {}'.format(batch_id, status))
        return FAILED

    def _get_state(self):
        """Load the state from ZK, or return an empty state {'shards': {}} if none is stored."""
        logging.info('Loading state from ZK...')
        data = get_zk_data(self._state_path)
        if data is None:
            logging.info('Will create the state anew in ZK: ' + self._state_path)
            data = {'shards': {}}
        return data

    def _save_state(self, state):
        """Write the state to ZK."""
        logging.info('Saving state to ZK...')
        create_zk_node(self._state_path, json.dumps(state))
        logging.info('Saved.')

    def _update_state(self, batch):
        """Advance the per-shard positions in the state past a completed ``batch`` and save it."""
        state = self._get_state()
        for shard_folder in state['shards']:
            if shard_folder not in batch['shards']:
                logging.warning('A shard is missing from the batch: {}'.format(shard_folder))
        for shard_folder, tables in batch['shards'].items():
            if shard_folder not in state['shards']:
                logging.warning('A new shard detected (it was not there last time): {}'.format(shard_folder))
            prev_ts = state['shards'].get(shard_folder, 0)
            # The position never moves backwards: tables older than prev_ts do not lower it.
            last_table_ts_in_batch = max([prev_ts] + [table['timestamp'] for table in tables])
            # NOTE(review): because last_table_ts_in_batch starts from prev_ts, this condition can
            # never be true. It is kept as written because a real check would reject from_timestamp
            # reruns that replay older tables; decide on the intended check before changing it.
            if prev_ts > last_table_ts_in_batch:
                logging.error('State: {}'.format(hr_json(state)))
                logging.error('Batch: {}'.format(hr_json(batch)))
                raise Exception('Cannot mark batch {} as completed: the last timestamp in batch is {} for shard {}, but the state assumes that we have already consumed everything up to {}. Broken state? Please rerun manually.'.format(
                    batch['id'],
                    last_table_ts_in_batch,
                    shard_folder,
                    prev_ts
                ))
            state['shards'][shard_folder] = last_table_ts_in_batch

        state['last_batch'] = batch
        logging.info('Batch completion is recorded. Batch: {}, the new state: {}'.format(batch['id'], hr_json(state)))
        self._save_state(state)

    def _gen_tables_path(self, shard_folder):
        """Return the YT folder that holds the tables of ``shard_folder``."""
        if self._lb_delivery:
            return os.path.join(self._lb_delivery_path, shard_folder)
        if self._period == '1d':
            period_path = self._period
        elif self._period == '5min':
            period_path = 'stream/' + self._period
        else:
            raise Exception('Unsupported period: {}'.format(self._period))
        return '//logs/{}/{}'.format(shard_folder, period_path)

    def _gen_table_name(self, timestamp):
        """Return the expected YT table name for a table starting at ``timestamp``."""
        if self._lb_delivery:
            return '{}-300'.format(timestamp)
        if self._period == '1d':
            # Assumes daily tables are named 'YYYY-MM-DD' in local time. The helper previously
            # called here (timestamp_to_date_str) is not defined in this module.
            return datetime.datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d')
        if self._period == '5min':
            return timestamp_to_string(timestamp)
        raise Exception('Unsupported period: {}'.format(self._period))

    def _timestamp_from_table(self, table_name):
        """Inverse of _gen_table_name: extract the start timestamp from a table name."""
        if self._lb_delivery:
            return int(table_name.split('-')[0])
        return timestamp_from_string(table_name)

    def _get_tables(self, shard, ts, max_last_ts):
        """Select the next tables of one shard, continuing right after timestamp ``ts``.

        Consecutive names (ts + period, ts + 2 * period, ...) are looked up one by one. Selection
        stops at the first missing table, at a table newer than ``max_last_ts``, or once
        ``tables_per_shard`` non-empty tables have been taken. An empty table that is followed by
        another selected table is replaced by it, and the replacement counts the dropped tables in
        'skipped_empty_before'.

        :param shard: shard folder name.
        :param ts: timestamp of the last consumed table; 0 means start from the oldest table.
            Ignored when from_timestamp was given.
        :param max_last_ts: tables with a larger timestamp are left for a later run.
        :returns: list of table descriptor dicts, possibly empty.
        :raises Exception: when the listing or the gap pattern indicates lost data or a broken state.
        """
        logging.info('Getting tables for {} that are not older than timestamp {}'.format(shard, ts))
        if self._from_timestamp != 0:
            logging.warning('Will change the timestamp at user request: old={} new={}'.format(ts, self._from_timestamp))
            # Step one period back so the table at from_timestamp is the first one selected.
            ts = self._from_timestamp - self._period_sec
        tables = []
        not_empty_tables = 0
        tables_path = self._gen_tables_path(shard)
        logging.debug('YT path to tables: {}'.format(tables_path))

        tables_info = self._yt.list(tables_path, attributes=['id', 'row_count', 'creation_time', 'modification_time'], max_size=MAX_TABLES_COUNT)
        if len(tables_info) == 0:
            logging.error('Found no logbroker tables in {}'.format(tables_path))
            return []
        # Assumes table names sort in timestamp order, so the list can be walked by index below.
        tables_info.sort()
        logging.debug('Got tables_info count={} for: {} ... {}'.format(len(tables_info), tables_info[0], tables_info[-1]))
        if len(tables_info) == MAX_TABLES_COUNT:
            raise Exception('More than {} tables were found: {}. Consult with Logbroker team.'.format(MAX_TABLES_COUNT, len(tables_info)))
        table_index = None
        while not_empty_tables < self._tables_per_shard:
            if ts == 0:
                # No saved position: start at the oldest table (index -1 becomes 0 below).
                table_index = -1
                ts = self._timestamp_from_table(tables_info[0])
                logging.info('First table timestamp is: {}'.format(ts))
            else:
                ts += self._period_sec

            table_name = self._gen_table_name(ts)
            if table_index is None:
                # First step from a saved position: locate the expected table by name.
                for i, t in enumerate(tables_info):
                    if str(t) == table_name:
                        table_index = i
                        break
                if table_index is None:
                    logging.info('The next table is not found: no table={} in list.'.format(table_name))
                    first_table_ts = self._timestamp_from_table(tables_info[0])
                    last_table_ts = self._timestamp_from_table(tables_info[-1])
                    if first_table_ts > ts:
                        raise Exception('Irrecoverable failure, SOME DATA WAS PROBABLY LOST. '
                                        'All tables are newer than the expected timestamp: {}/{} > {}/{}'.format(
                            first_table_ts,
                            tables_info[0],
                            ts,
                            table_name
                        ))
                    if last_table_ts >= ts:
                        raise Exception('Incorrect state: next table (ts={}) is not found, but the last table is newer than that: {}/{} > {}/{}'.format(
                            ts,
                            last_table_ts,
                            tables_info[-1],
                            ts,
                            table_name
                        ))
                    # The expected table is newer than everything listed: nothing to take yet.
                    table_index = len(tables_info)
            else:
                table_index += 1

            # A gap in the sequence (or the end of the listing) ends the selection.
            if table_index == len(tables_info) or table_name != str(tables_info[table_index]):
                logging.info('No next table to process: {}/{}'.format(ts, table_name))
                break

            if ts > max_last_ts:
                logging.info('A table is too new to be processed, will ignore it and all newer tables: {}/{}'.format(ts, table_name))
                break

            info = tables_info[table_index]
            path = os.path.join(tables_path, table_name)
            table = {
                'path': path,
                'name': str(info),
                'timestamp': ts,
                'id': info.attributes['id'],
                'row_count': info.attributes['row_count'],
                'creation_time': info.attributes['creation_time'],
                'modification_time': info.attributes['modification_time'],
            }
            table['modification_timestamp'] = timestamp_from_string(table['modification_time'])
            if table['row_count']:
                not_empty_tables += 1
            if tables and tables[-1]['row_count'] == 0:
                # Collapse a preceding empty table into this one, keeping count of what was skipped.
                table['skipped_empty_before'] = tables[-1].get('skipped_empty_before', 0) + 1
                tables[-1] = table
            else:
                tables.append(table)
        logging.info('Selected tables: {}'.format(hr_json({'tables': tables})))
        return tables

    def _get_max_table_ts(self):
        """Newest table timestamp allowed in a batch: now minus min(30 minutes, one period)."""
        return int(time.time() - min(1800, self._period_sec))

    def _truncate_tables_list(self, shard_tables, max_ts):
        """Drop trailing tables whose timestamp exceeds ``max_ts``.

        The first table is always kept unless allow_empty_shard_on_skip is set.
        The input list is returned unchanged when nothing is dropped.
        """
        keep = len(shard_tables)
        # range() stops before `stop`, so with stop == 0 index 0 is never dropped.
        stop = -1 if self._allow_empty_shard_on_skip else 0
        for i in range(len(shard_tables) - 1, stop, -1):
            if shard_tables[i]['timestamp'] <= max_ts:
                break
            keep = i
        if keep < len(shard_tables):
            return shard_tables[:keep]
        return shard_tables

    def _create_batch(self):
        """Collect the next tables of every shard into a batch dict {'shards': {...}, 'timestamp': now}."""
        logging.info('Creating batch...')
        max_last_ts = self._get_max_table_ts()
        state = self._get_state()
        batch = {'shards': {}, 'timestamp': int(time.time())}
        for shard in self._shards['shards']:
            ts = state['shards'].get(shard, 0)
            batch['shards'][shard] = self._get_tables(shard, ts, max_last_ts)

        # Ensure that the "last timestamp" is the same for all shards, to avoid rejection from Ferryman (SAAS-5964)
        last_timestamps = [tables[-1]['timestamp'] for tables in batch['shards'].values() if tables]
        min_last_ts = min(last_timestamps) if last_timestamps else None
        if min_last_ts is None:
            return batch

        if self._task:
            self._task.set_info('Timestamp to be sent: {}'.format(timestamp_to_string(min_last_ts)))

        for shard in self._shards['shards']:
            shard_tables = batch['shards'][shard]
            if not shard_tables:
                continue
            last_ts = shard_tables[-1]['timestamp']
            if last_ts > min_last_ts:
                shard_tables = self._truncate_tables_list(shard_tables, min_last_ts)
                batch['shards'][shard] = shard_tables
                new_last_ts = shard_tables[-1]['timestamp'] if shard_tables else 'All tables skipped'
                logging.warning("Had to exclude some tables from the batch. This changes last_ts in shard {}: {} -> {}".format(shard, last_ts, new_last_ts))

        return batch

    def _send_batch(self, batch):
        """Submit the non-empty tables of ``batch`` to Ferryman and record the batch in ZK.

        Sets batch['id'] to the Ferryman batch id, or to NO_TABLES_TO_PROCESS if there was nothing to send.
        """
        logging.info('Sending batch to Ferryman: {}'.format(hr_json(batch)))
        tables = []
        timestamps = set()
        for shard_tables in batch['shards'].values():
            for table in shard_tables:
                if table['row_count'] > 0:
                    logging.debug('Add table to batch: {}'.format(table))
                    ts = max(table['modification_timestamp'], self._min_sending_ts)
                    # Make every table timestamp in the request unique.
                    while ts in timestamps:
                        ts += 1
                    timestamps.add(ts)
                    tables.append({
                        'Path': table['path'],
                        'Timestamp': ts * 1000 * 1000,
                        'Delta': True,
                        'Format': self._processor
                    })
        logging.info(hr_json({'Tables to process': tables}))
        if not tables:
            batch['id'] = NO_TABLES_TO_PROCESS
            if self._task:
                self._task.set_info('Nothing to do - there are no tables to process.')
        else:
            encoded_tables = urllib.quote_plus(json.dumps(tables))
            post_data = None
            # Small table lists go in the query string; large ones are POSTed as a JSON body.
            if len(encoded_tables) < 64000:
                batch['request'] = '{}/add-full-tables?tables={}'.format(self._fm_host, encoded_tables)
            else:
                batch['request'] = '{}/add-full-tables?'.format(self._fm_host)
                post_data = tables
            logging.debug('Filled batch: {}'.format(hr_json(batch)))
            logging.debug('Length of request: {}'.format(len(batch['request'])))
            result = http_request(batch['request'], timeout=150, attempts=2, post_data=post_data)
            logging.info('Add tables result: {}'.format(hr_json(result)))
            if 'batch' not in result:
                raise Exception('Failed to start a Ferryman batch, reason: {}'.format(result))
            # Do not double the scheme when fm_host already carries one.
            if self._fm_host.startswith(('http://', 'https://')):
                fm_url = self._fm_host
            else:
                fm_url = 'http://' + self._fm_host
            batch_status_url = '{}/get-batch-status?batch={}'.format(fm_url, result['batch'])
            if self._task:
                self._task.set_info('A Ferryman batch was created: <a href="{url}">{url}</a>'.format(url=batch_status_url), do_escape=False)

            batch['id'] = result['batch']
        self._save_last_batch(batch)

    def _fill_age(self):
        """Set self.age to the largest lag (seconds) between now and the saved shard positions.

        Only shards named 'shard-*' (Logbroker delivery folders) are considered; age stays None otherwise.
        """
        self.age = None
        now = time.time()
        state = self._get_state()
        for shard, ts in state['shards'].items():
            if shard.startswith('shard-'):
                if self.age is None or self.age < (now - ts):
                    self.age = now - ts
        logging.info('Age updated: {}'.format(self.age))

    def do(self):
        """Run one cycle: settle the previous batch, then build and send the next one.

        Returns early while the previous batch is still in progress.

        :raises Exception: if the previous batch failed and no from_timestamp was given.
        """
        self._fill_age()
        batch = self._get_last_batch_info()
        if batch:
            result = self._get_batch_result(batch['id'])
            if self._task and batch.get('timestamp'):
                self._task.set_info('Last batch creation time: {}'.format(timestamp_to_string(batch['timestamp'])))
                try:
                    for sh_tables in batch['shards'].values():
                        if sh_tables:
                            self._task.set_info('Last batch table ts: {}'.format(sh_tables[0]['modification_time']))
                            break
                except Exception as e:
                    # Purely informational: a malformed batch record must not abort the run.
                    logging.warning('Failed to report last batch table ts: {}'.format(e))
            if result == IN_PROGRESS:
                logging.info('The last batch is still in progress.')
                return
            if result == FAILED and self._from_timestamp == 0:
                if self._task:
                    self._task.set_info(
                        'Probable <a href="https://yt.yandex-team.ru/arnold/operations?filter=/{subpath}&user=robot-saas-ferryman">operations</a>'.format(subpath=self._yt_subpath),
                        do_escape=False)
                raise Exception('The last batch failed! please rerun the task manually with a timestamp')
            if result == COMPLETED:
                self._update_state(batch)
            else:
                logging.warning("Will drop the record about last batch from Zookeeper")
            self._remove_last_batch()
        else:
            logging.info('There was no previous batch')

        self._shards = self._get_shards(self._ignore_cache)
        batch = self._create_batch()
        self._send_batch(batch)
        logging.info('Finished.')

    def clear_shards_cache(self):
        """Delete the cached shard list from ZK."""
        remove_zk_node(self._shards_cache_path)
        remove_zk_node(self._shards_cache_path)


class UploadLogfellerData2(sdk2.Task):
    """ Upload logfeller data to ferryman v2 """

    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 100

        environments = [PipEnvironment('yandex-yt')]

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Parameters):
        kill_timeout = 900

        service = sdk2.parameters.String('Service', required=True)
        ctype = sdk2.parameters.String('Ctype', required=True)
        zk_path = sdk2.parameters.String('zk_path', required=True, default='/logbroker/ferryman/')
        fm_host = sdk2.parameters.String('Ferryman host', required=True)
        period = sdk2.parameters.String('Period', required=True, default='5min')
        tables_per_shard = sdk2.parameters.Integer('Max tables per shard in batch', required=True, default=1)
        from_timestamp = sdk2.parameters.Integer('Upload tables from timestamp', default=0)
        lb_delivery_path = sdk2.parameters.String('Logbroker YT delivery path', default='')
        ignore_cache = sdk2.parameters.Bool('Ignore cache', default=False)
        allow_empty_shard_on_skip = sdk2.parameters.Bool('Allow empty shard on skip', default=True)
        min_sending_ts = sdk2.parameters.Integer('Set timestamp on send >=', default=0)

    def on_execute(self):
        """Run one TUpdater cycle.

        A failure is tolerated, and the task ends successfully, only while the updater's age is
        known and at most one hour. Otherwise the exception is re-raised.
        """
        yt_token = sdk2.Vault.data(self.owner, 'YT_TOKEN_ARNOLD')

        updater = TUpdater(
            self.Parameters.service,
            self.Parameters.ctype,
            self.Parameters.zk_path,
            self.Parameters.fm_host,
            self.Parameters.period,
            self.Parameters.tables_per_shard,
            self,
            yt_token,
            self.Parameters.from_timestamp,
            self.Parameters.lb_delivery_path,
            self.Parameters.ignore_cache,
            self.Parameters.allow_empty_shard_on_skip,
            self.Parameters.min_sending_ts
        )
        try:
            updater.do()
        except Exception as e:
            # getattr: the updater may not have computed its age before failing.
            age = getattr(updater, 'age', None)
            if age is None or age > 60 * 60:
                raise
            logging.exception(e)
            msg = "Task failed! Forcing OK state due to small updater age. Exception: [{}] {}".format(type(e), e)
            # The message carries arbitrary exception text (and '<type ...>'), so it must be HTML-escaped.
            self.set_info(msg)

