import json
import urllib2
import socket
import sys

from sandbox import sdk2
from time import sleep, mktime
from datetime import datetime, timedelta
import sandbox.sandboxsdk.environments as environments

# Call statuses that count as a finished escalation; an escalation whose
# selected call has any other status is reported as skipped.
ESCALATION_STATUSES = [
    'SUCCESS',
    'NO_ANSWER',
    'USER_BUSY',
    'BAD_PHONE',
    'REJECTED',
    'ANSWER_TIMEOUT',
    'STOP_ESCALATION',
    'TIMED_OUT',
    'SKIPPED',
    'DOES_NOT_MATTER',
]


def fetch_url(url, parse_json=True, timeout=15, retries=2, post=None, headers=None):
    """Fetch ``url`` with simple retrying and optional JSON decoding.

    :param url: URL to request.
    :param parse_json: if true, join the response lines and decode them with
        ``json.loads``.
    :param timeout: per-attempt timeout in seconds passed to ``urlopen``.
    :param retries: maximum number of attempts (not additional retries).
    :param post: if truthy, JSON-encoded and sent as the request body.
    :param headers: optional dict of request headers.
    :returns: the decoded JSON value, or the raw list of response lines when
        ``parse_json`` is false.  An HTTP 404 is treated as the body ``[]``.
        Returns ``None`` when every attempt failed or the body is not valid
        JSON.
    """
    body = json.dumps(post) if post else None
    # urllib2.Request iterates ``headers.items()``, so it must never get None.
    req = urllib2.Request(url, body, headers or {})

    data = None
    while data is None and retries > 0:
        retries -= 1
        try:
            response = urllib2.urlopen(req, timeout=timeout)
            try:
                data = response.readlines()
            finally:
                response.close()
        except urllib2.HTTPError as e:
            # HTTPError must be handled before its base class URLError.
            sys.stderr.write('WARNING: fetch %s failed [%s]\n' % (url, e))
            if e.code == 404:
                # "Not found" is final: report it as an empty JSON list.
                data = ['[]']
        except (urllib2.URLError, socket.error) as e:
            # Connection problems and timeouts (socket.timeout is a
            # socket.error) are transient: log and try again.
            sys.stderr.write('WARNING: fetch %s failed [%s]\n' % (url, e))

        if data is None and retries > 0:
            sleep(1)

    if parse_json:
        try:
            data = json.loads(''.join(data))
        except (TypeError, ValueError):
            # TypeError: every attempt failed (data is None);
            # ValueError: the body is not valid JSON.
            data = None

    return data


def get_escalations_log(filters, start_time, end_time):
    """Collect Juggler escalations whose ``start_time`` is in ``[start_time, end_time]``.

    Pages through the ``get_escalations_log`` endpoint (100 entries per page,
    including escalations that are no longer running).  Paging stops when a
    page is empty or its first entry started before ``start_time``.  Entries
    are de-duplicated by ``escalation_id``.

    :param filters: list of Juggler filter dicts, sent to the API unchanged.
    :param start_time: inclusive lower bound on an escalation's ``start_time``
        (callers pass ``time.mktime`` values, i.e. Unix seconds).
    :param end_time: inclusive upper bound on an escalation's ``start_time``.
    :returns: list of escalation dicts sorted by ``start_time`` ascending.
    :raises RuntimeError: if a page could not be fetched or decoded.
    """
    escalations = []
    seen_ids = set()

    payload = {
        'filters': filters,
        'only_running': False,
        'page_size': 100,
        'page': 0,
    }

    while True:
        data = fetch_url('https://juggler-api.search.yandex.net/v2/escalations/get_escalations_log',
                         post=payload, headers={'Content-Type': 'application/json'})

        if data is None:
            # fetch_url gives None after exhausting its attempts or on invalid
            # JSON; carrying on would silently yield an incomplete result.
            raise RuntimeError('Failed to fetch escalations log page {}'.format(payload['page']))

        page = data.get('escalations') if isinstance(data, dict) else None

        # Only the first entry of a page is compared against start_time; this
        # presumably relies on the API returning pages newest-first — TODO confirm.
        if not page or page[0]['start_time'] < start_time:
            break

        for e in sorted(page, key=lambda item: -item['start_time']):
            if start_time <= e['start_time'] <= end_time and e['escalation_id'] not in seen_ids:
                escalations.append(e)
                seen_ids.add(e['escalation_id'])

        payload['page'] += 1

    return sorted(escalations, key=lambda e: e['start_time'])


class PortalNotifications(sdk2.Task):
    """Export Juggler escalations and notification messages to YT.

    ``write_escalations`` stores one table of finished escalations per day
    under ``path_to_escalations``.  YQL queries then:

    * union those daily tables into ``resulting_table_name_escalations``;
    * copy today's and yesterday's CRIT/WARN ``message_processed`` events for
      ``abc_service`` from ``//statbox/juggler-banshee-log`` into daily tables
      under ``path_to_messages``;
    * union the daily message tables into ``resulting_table_name_messages``.
    """

    class Parameters(sdk2.Task.Parameters):
        # The path patterns accept ``//home/`` followed by word characters,
        # ``/`` and ``-``.  They avoid a nested ``(...+)+`` quantifier, which
        # matches the same strings but can backtrack exponentially on bad input.
        yt_token = sdk2.parameters.String('Name of sdk2.Vault secret with YT token',
                                          default='YT_TOKEN_MORDA')
        yql_token = sdk2.parameters.String('Name of sdk2.Vault secret with YQL token',
                                           default='YQL_TOKEN_MORDA')
        path_to_escalations = sdk2.parameters.StrictString('Path to tables for escalations',
                                                           regexp=r'^//home/[/\w-]+$', required=True,
                                                           default='//home/morda/SRE/escalations/')
        path_to_messages = sdk2.parameters.StrictString('Path to tables for messages', regexp=r'^//home/[/\w-]+$',
                                                        required=True,
                                                        default='//home/morda/SRE/messages/')
        resulting_table_name_escalations = sdk2.parameters.StrictString(
            'Name of the resulting table for escalations', regexp=r'^\w+$', required=True,
            default='escalations')
        resulting_table_name_messages = sdk2.parameters.StrictString('Name of the resulting table for messages',
                                                                     regexp=r'^\w+$', required=True,
                                                                     default='messages')
        # Comma-separated list of Juggler namespaces.
        namespaces = sdk2.parameters.StrictString('Namespace',
                                                  regexp=r'^[\w\.,]+$', required=True, default='portal')
        abc_service = sdk2.parameters.StrictString('Service in ABC',
                                                   regexp=r'^\w+$', required=True, default='svc_home')

        # Anchored so that values with extra text are rejected at input time.
        start_date = sdk2.parameters.StrictString(
            'Start Date YYYY-MM-DD',
            regexp=r'^\d{4}-\d{2}-\d{2}$',
            required=True,
        )
        end_date = sdk2.parameters.StrictString(
            'End Date YYYY-MM-DD',
            regexp=r'^\d{4}-\d{2}-\d{2}$',
            required=True,
        )

    class Requirements(sdk2.Task.Requirements):
        environments = [
            environments.PipEnvironment('yandex-yt'),
            environments.PipEnvironment('yql'),
        ]

    def on_execute(self):
        """Write daily escalation tables, then build the combined tables via YQL."""
        from yql.api.v1.client import YqlClient
        yql_client = YqlClient(db='hahn', token=sdk2.Vault.data(self.Parameters.yql_token))

        # strptime raises on trailing characters, so the raw parameter strings
        # that are formatted into the YQL below cannot carry extra text.
        start_date = datetime.strptime(self.Parameters.start_date, '%Y-%m-%d')
        end_date = datetime.strptime(self.Parameters.end_date, '%Y-%m-%d')

        self.write_escalations(start_date, end_date)

        path_to_escalations = self.Parameters.path_to_escalations.rstrip('/')
        path_to_messages = self.Parameters.path_to_messages.rstrip('/')

        # Union the daily escalation tables in the date range into one table.
        query = '''
            $from_date = '{start_date}';
            $to_date = '{end_date}';

            $path = '{path_to_escalations}/{table_name}';

            INSERT INTO $path
            WITH TRUNCATE
            SELECT
                stopped_time,
                check,
                login,
                abc_duty,
                status,
                session_id
            FROM range(`{path_to_escalations}`, $from_date, $to_date)
        '''.format(start_date=self.Parameters.start_date, end_date=self.Parameters.end_date,
                   path_to_escalations=path_to_escalations,
                   table_name=self.Parameters.resulting_table_name_escalations)

        yql_client.query(query).run()

        # NOTE(review): this always refreshes today's and yesterday's message
        # tables, independent of start_date/end_date.
        for day in range(2):
            today = datetime.now() - timedelta(days=day)
            query = '''
                INSERT INTO `{path_to_messages}/{date}`
                WITH TRUNCATE
                SELECT
                    iso_eventtime,
                    checks,
                    status,
                    login,
                    message
                FROM range(`//statbox/juggler-banshee-log`, '{date}', '{date}')
                WHERE status in('CRIT', 'WARN')
                AND abc_service = '{abc_service}'
                AND event_type = 'message_processed'
            '''.format(path_to_messages=path_to_messages, date=today.strftime('%Y-%m-%d'),
                       abc_service=self.Parameters.abc_service)

            yql_client.query(query).run()

        # Union the daily message tables in the date range into one table.
        query = '''
            $from_date = '{start_date}';
            $to_date = '{end_date}';

            $path = '{path_to_messages}/{table_name}';

            INSERT INTO $path
            WITH TRUNCATE
            SELECT
                iso_eventtime,
                checks,
                status,
                login,
                message
            FROM range(`{path_to_messages}`, $from_date, $to_date)
        '''.format(start_date=self.Parameters.start_date, end_date=self.Parameters.end_date,
                   path_to_messages=path_to_messages,
                   table_name=self.Parameters.resulting_table_name_messages)

        yql_client.query(query).run()

    @staticmethod
    def _pick_escalation_call(escalation):
        """Return the call dict that describes ``escalation``, or ``{}`` if none.

        Log items are examined in order.  The first call seen is selected.
        After that, a call replaces the selection only if its status is
        ``STOP_ESCALATION``, and the first non-STOP call in an item ends the
        scan of that item.  No further log items are examined once a
        ``STOP_ESCALATION`` call is selected.

        The selected call dict is mutated in place.  ``login`` and
        ``abc_duty`` are copied from its log item when the item has
        ``abc_duty_login`` or ``simple_login``; otherwise they are left unset.
        """
        c = {}
        for item in escalation.get('log') or []:
            if c.get('status') == 'STOP_ESCALATION':
                break

            for call in item['calls']:
                if call['status'] == 'STOP_ESCALATION' or not c:
                    c = call
                    if 'abc_duty_login' in item:
                        c['login'] = item['abc_duty_login']['value']
                        c['abc_duty'] = item['abc_duty_login']['abc_duty']
                    elif 'simple_login' in item:
                        c['login'] = item['simple_login']
                        c['abc_duty'] = 'none'
                else:
                    break
        return c

    def write_escalations(self, start_date, end_date):
        """Write one YT table of finished escalations per day.

        Days run from ``start_date`` to ``min(end_date, today)``; each table is
        named ``YYYY-MM-DD`` under ``path_to_escalations``.

        * For days other than today and yesterday, a table that already
          exists is left untouched.
        * For today and yesterday, the table is recreated.  Any rows already
          stored there are kept, and newly fetched rows not yet present are
          appended.

        Escalations with no usable call (unknown status, or no login
        information) are counted and printed as skipped.
        """
        from yt.wrapper import YtClient
        yt_client = YtClient(proxy='hahn', config={'tabular_data_format': 'dsv'},
                             token=sdk2.Vault.data(self.Parameters.yt_token))

        schema = [
            {'name': 'escalation_id', 'type': 'string'},
            {'name': 'stopped_time', 'type': 'string'},
            {'name': 'check', 'type': 'string'},
            {'name': 'login', 'type': 'string'},
            {'name': 'abc_duty', 'type': 'string'},
            {'name': 'status', 'type': 'string'},
            {'name': 'session_id', 'type': 'string'},
        ]
        column_names = [column['name'] for column in schema]

        today = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
        yesterday = today - timedelta(days=1)
        path_to_escalations = self.Parameters.path_to_escalations.rstrip('/')
        filters = [{'namespace': namespace} for namespace in self.Parameters.namespaces.split(',')]

        escalations_amount = 0
        skipped_escalations = []

        day = start_date
        while day <= end_date and day <= today:
            end_time = day.replace(hour=23, minute=59, second=59, microsecond=999999)

            escalations = get_escalations_log(
                filters,
                start_time=mktime(day.timetuple()),
                end_time=mktime(end_time.timetuple())
            )
            escalations_amount += len(escalations)

            data = []
            for e in escalations:
                c = self._pick_escalation_call(e)
                # login/abc_duty are only filled in when the log item carries a
                # login, so a call without them cannot be turned into a row.
                if c.get('status') in ESCALATION_STATUSES and 'login' in c and 'abc_duty' in c:
                    stopped_time = datetime.fromtimestamp(c['end_time'])
                    data.append({'escalation_id': e['escalation_id'],
                                 'stopped_time': stopped_time.strftime('%Y-%m-%d %H:%M:%S.%f'),
                                 'check': e['host'] + ':' + e['service'],
                                 'login': c['login'],
                                 'abc_duty': c['abc_duty'],
                                 'status': c['status'],
                                 'session_id': c['session_id']})
                else:
                    skipped_escalations.append(e)

            table_path = '{path_to_escalations}/{table_name}'.format(
                path_to_escalations=path_to_escalations,
                table_name=day.strftime('%Y-%m-%d'))

            update = False
            if day != today and day != yesterday:
                if yt_client.exists(table_path):
                    print('The table "{}" already exists!'.format(table_path))
                    day = day + timedelta(days=1)
                    continue
            elif yt_client.exists(table_path):
                # Keep rows already stored for this day and append only the new
                # ones.  Read errors propagate: ignoring them would drop the
                # table below and lose its existing rows.
                merged = [{name: row[name] for name in column_names}
                          for row in yt_client.read_table(table_path)]
                for item in data:
                    if item not in merged:
                        merged.append(item)
                data = merged
                update = True

            if yt_client.exists(table_path):
                yt_client.remove(table_path)

            yt_client.create('table', table_path, attributes={'schema': schema})

            if update:
                print('Update table "{}"'.format(table_path))
            else:
                print('Write table "{}"'.format(table_path))
            yt_client.write_table(table_path, data)

            day = day + timedelta(days=1)

        print('\nTotal amount of escalations: {}'.format(escalations_amount))
        print('Skipped escalations amount: {}'.format(len(skipped_escalations)))
        print('\nSkipped escalations:')
        for e in skipped_escalations:
            print(e)
