# -*- coding: utf-8 -*-

from __future__ import print_function

import argparse
import yt.wrapper as yt

import time
from collections import OrderedDict


# Per-bus table settings, keyed by bus name (the name is also used as the
# table name under --path). Recognized fields:
#   'TTL'    - value written to the table's @max_data_ttl. Either one number
#              shared by all environments, or a dict keyed by env
#              ('prod', 'testing', 'dev'). Values are written as
#              <hours> * 60 * 60 * 1000, i.e. milliseconds.
#   'Keys'   - optional list of SCHEMA columns that get sort_order
#              'ascending'; when absent no column is a key.
#   'Medium' - optional value for @primary_medium (treated as 'default' when
#              absent); create_bus only applies it where SET_MEDIUMS allows.
BUS_CONFIGS = {
    'offer_bus': {
        'TTL': 6 * 60 * 60 * 1000,
    },
    'offerreq_bus': {
        'TTL': 2 * 60 * 60 * 1000,
    },
    'pricecheckreq_bus': {
        'TTL': 2 * 60 * 60 * 1000,
    },
    'offer_data_storage': {
        'TTL': 6 * 60 * 60 * 1000,
        'Keys': ['MessageId'],
    },
    'warmer_state_bus': {
        'TTL': 48 * 60 * 60 * 1000,
    },
    'pricechecker_state_bus': {
        'TTL': 6 * 60 * 60 * 1000,
    },
    'oc_interconnect_bus': {
        'TTL': 1 * 60 * 60 * 1000,
    },
    'search_flow_offer_data_storage': {
        'TTL': {
            # When changing these TTLs, also update the matching setting in the redir config: https://a.yandex-team.ru/arc/trunk/arcadia/travel/hotels/redir/proto/config.proto?rev=r8177855#L24
            'prod': 48 * 60 * 60 * 1000,
            'testing': 12 * 60 * 60 * 1000,
            'dev': 12 * 60 * 60 * 1000,
        },
        'Keys': ['MessageId'],
        'Medium': 'ssd_blobs',
    },
    'outdated_offer_bus': {
        'TTL': 24 * 60 * 60 * 1000,
    },
}

# Per env: whether create_bus applies a bus's 'Medium' setting
# (@primary_medium) when creating the table. In dev it is never set.
SET_MEDIUMS = {
    'dev': False,
    'testing': True,
    'prod': True,
}

# Tablet cell bundle per env and cluster name (as passed via --clusters).
# create_bus sets @tablet_cell_bundle only when the lookup yields a non-None
# value; update_bundles falls back to 'default' for clusters missing from the
# env's map (dev lists none).
CLUSTER_NAME_TO_BUNDLE = {
    'dev': {},
    'testing': {
        'hahn': None,
        'arnold': None,
        'seneca-sas': 'travel',
        'seneca-vla': 'travel',
        'seneca-man': 'travel'
    },
    'prod': {
        'hahn': 'travel-prod',
        'arnold': 'travel-prod',
        'seneca-sas': 'travel-prod',
        'seneca-vla': 'travel-prod',
        'seneca-man': 'travel-prod'
    },
}

# Column name -> YT type shared by every bus table, in column order.
# get_schema builds the table schema from it; alter_bus compares existing
# tables against it.
SCHEMA = OrderedDict([
    ('MessageId', 'string'),
    ('Timestamp', 'uint64'),
    ('ExpireTimestamp', 'uint64'),
    ('MessageType', 'string'),
    ('Codec', 'uint64'),
    ('Bytes', 'string'),
])


def create_dir(path):
    """Create a map_node at ``path`` unless a node already exists there.

    Only ``path`` itself is created (no ``recursive`` flag is passed).
    """
    if yt.exists(path):
        return
    print('Creating dir', path)
    yt.create('map_node', path)


def get_schema(bus_config):
    """Build the table schema for one bus.

    Columns follow SCHEMA in its declared order. Every column named in
    ``bus_config['Keys']`` (if present) gets ``sort_order: 'ascending'``.

    :param bus_config: one value of BUS_CONFIGS.
    :return: list of column dicts with 'name', 'type' and optional
        'sort_order', suitable for the table's ``schema`` attribute.
    """
    key_columns = bus_config.get('Keys', [])
    columns = []
    for column_name, column_type in SCHEMA.items():
        column = {'name': column_name, 'type': column_type}
        if column_name in key_columns:
            column['sort_order'] = 'ascending'
        columns.append(column)
    return columns


def with_retry(action, attempts=10, backoff=5):
    """Call ``action()`` until it succeeds or ``attempts`` calls have failed.

    :param action: zero-argument callable.
    :param attempts: maximum number of calls; values below 1 still make one call.
    :param backoff: seconds to sleep between a failed call and the next one.
    :return: whatever ``action()`` returned on the successful call.
    :raises: the exception from the last failed call once attempts are used up.
    """
    failures = 0
    while True:
        try:
            return action()
        # Catch Exception, not everything: KeyboardInterrupt/SystemExit must
        # stop this interactive script instead of being retried.
        except Exception:
            failures += 1
            if failures >= attempts:
                # Out of attempts: re-raise right away instead of announcing
                # (and sleeping before) a retry that will never happen.
                raise
            print("Failed to do action. Wait %s sec and retry" % backoff)
            time.sleep(backoff)


def get_ttl(bus_config, env):
    """Return the TTL configured for ``env``.

    ``bus_config['TTL']`` is either a single value shared by every
    environment or a dict keyed by environment name.

    :raises KeyError: if 'TTL' is missing, or the per-env dict lacks ``env``.
    """
    ttl = bus_config['TTL']
    return ttl[env] if isinstance(ttl, dict) else ttl


def create_bus(cluster, table_path, bus_config, args):
    """Create the bus as a dynamic table, configure it and mount it.

    The table (plus missing parent nodes, via ``recursive=True``) is created
    with the TTL for ``args.env`` and the schema from ``get_schema``. Then:
      * ``@tablet_cell_bundle`` is set when CLUSTER_NAME_TO_BUNDLE maps
        ``cluster`` to a non-None bundle for ``args.env``;
      * ``@primary_medium`` is set when SET_MEDIUMS enables mediums for
        ``args.env`` and the bus asks for a medium other than 'default'.
    Finally the table is mounted.
    """
    print('Creating table', table_path)
    env = args.env
    table_attributes = {
        'dynamic': True,
        'min_data_versions': 0,
        'max_data_versions': 1,
        'min_data_ttl': 0,
        'max_data_ttl': get_ttl(bus_config, env),
        'schema': get_schema(bus_config),
    }
    yt.create('table', table_path, recursive=True, attributes=table_attributes)

    bundle = CLUSTER_NAME_TO_BUNDLE[env].get(cluster)
    if bundle is not None:
        print('Setting bundle to {} for {}'.format(bundle, table_path))
        yt.set(yt.ypath_join(table_path, '@tablet_cell_bundle'), bundle)

    medium = bus_config.get('Medium', 'default')
    if SET_MEDIUMS[env] and medium != 'default':
        print('Setting medium to {} for {}'.format(medium, table_path))
        yt.set(yt.ypath_join(table_path, '@primary_medium'), medium)

    print('Mounting', table_path)
    yt.mount_table(table_path)


def remove_bus(cluster, table_path, bus_config, args):
    """Remove the bus table, but only after the user types "YES".

    ``bus_config`` and ``args`` are accepted but not used here.
    """
    prompt = 'Do you really want to remove {} on {}? Type "YES" to confirm: '.format(table_path, cluster)
    if input(prompt) != 'YES':
        return
    print('Removing', table_path)
    yt.remove(table_path)


def alter_bus(cluster, table_path, bus_config, args):
    """Bring an existing bus table's schema and TTL in line with its config.

    Compares the table's ascending key columns, column name -> type mapping
    and ``@max_data_ttl`` with ``bus_config``/SCHEMA. If nothing differs, says
    so and returns. Otherwise, after "YES" confirmation, it unmounts the
    table, applies the TTL and/or schema change (each wrapped in
    ``with_retry``) and mounts the table again.
    """
    current_columns = yt.get(yt.ypath_join(table_path, '@schema'))
    current_keys = [col['name'] for col in current_columns if col.get('sort_order') == 'ascending']
    current_types = dict((col['name'], col['type']) for col in current_columns)
    keys_differ = current_keys != bus_config.get('Keys', [])
    schema_changed = keys_differ or current_types != SCHEMA

    current_ttl = yt.get(yt.ypath_join(table_path, '@max_data_ttl'))
    wanted_ttl = get_ttl(bus_config, args.env)
    ttl_changed = current_ttl != wanted_ttl

    if not (schema_changed or ttl_changed):
        print("Schema and TTL not changed")
        return

    prompt = 'Do you really want to alter {} on {}? Type "YES" to confirm: '.format(table_path, cluster)
    if input(prompt) != 'YES':
        return

    print('Unmounting %s' % table_path)
    yt.unmount_table(table_path)

    if ttl_changed:
        print("Changing TTL to %s" % wanted_ttl)
        with_retry(lambda: yt.set(yt.ypath_join(table_path, '@max_data_ttl'), wanted_ttl))
        print("TTL changed successfully")

    if schema_changed:
        print("Altering table schema")
        # get_schema stays inside the lambda so it is re-evaluated on each retry.
        with_retry(lambda: yt.alter_table(table_path, get_schema(bus_config)))
        print("Schema altered successfully")

    print('Mounting %s' % table_path)
    with_retry(lambda: yt.mount_table(table_path))


def remount_bus(cluster, table_path):
    """Unmount and re-mount the bus table after the user types "YES".

    The mount step is wrapped in ``with_retry``; the unmount is not.
    """
    question = 'Do you really want to remount {} on {}? Type "YES" to confirm: '.format(table_path, cluster)
    confirmed = input(question) == 'YES'
    if confirmed:
        print('Unmounting %s' % table_path)
        yt.unmount_table(table_path)
        print('Mounting %s' % table_path)
        with_retry(lambda: yt.mount_table(table_path))


def update_bundles(cluster, table_path, args):
    """Move a bus table to the tablet cell bundle configured for its cluster.

    The target comes from CLUSTER_NAME_TO_BUNDLE[args.env]. Clusters that are
    missing there, or mapped to None, target the 'default' bundle (create_bus
    does not set a bundle for them either). If the table already uses the
    target bundle nothing happens. Otherwise, after "YES" confirmation, the
    table is unmounted, ``@tablet_cell_bundle`` is set and the table is
    mounted again (both latter steps wrapped in ``with_retry``).
    """
    # An explicit None (e.g. hahn/arnold in testing) must behave like a missing
    # entry: `.get(cluster, 'default')` would return None, which never equals
    # the current bundle name, so every run would remount the table and try to
    # write None into @tablet_cell_bundle.
    bundle = CLUSTER_NAME_TO_BUNDLE[args.env].get(cluster) or 'default'
    bundle_attr = yt.ypath_join(table_path, '@tablet_cell_bundle')
    curr_bundle = yt.get(bundle_attr)

    if bundle == curr_bundle:
        print('Skipping %s, bundle is correct' % table_path)
        return

    if input('Do you really want to update bundles of {} on {}? Table will be REMOUNTED! Type "YES" to confirm: '.format(table_path, cluster)) != 'YES':
        return

    print('Unmounting %s' % table_path)
    yt.unmount_table(table_path)

    print('Setting bundle to {} for {}'.format(bundle, table_path))
    with_retry(lambda: yt.set(bundle_attr, bundle))

    print('Mounting %s' % table_path)
    with_retry(lambda: yt.mount_table(table_path))


def update_medium(cluster, table_path, bus_config):
    """Set the bus table's ``@primary_medium`` to the medium from ``bus_config``.

    The target is ``bus_config['Medium']``, or 'default' when absent. If the
    table already uses it, nothing happens. Otherwise, after "YES"
    confirmation, the table is unmounted, the medium is changed and the table
    is mounted again (both latter steps wrapped in ``with_retry``).

    NOTE(review): unlike create_bus this does not consult SET_MEDIUMS, so it
    applies the configured medium in every environment, dev included.
    """
    target_medium = bus_config.get('Medium', 'default')
    medium_path = yt.ypath_join(table_path, '@primary_medium')
    current_medium = yt.get(medium_path)

    if current_medium == target_medium:
        print('Skipping %s, medium is correct' % table_path)
        return

    question = 'Do you really want to update medium of {} on {}? Table will be REMOUNTED! Type "YES" to confirm: '.format(table_path, cluster)
    if input(question) != 'YES':
        return

    print('Unmounting %s' % table_path)
    yt.unmount_table(table_path)

    print('Setting medium to {} for {}'.format(target_medium, table_path))
    with_retry(lambda: yt.set(medium_path, target_medium))

    print('Mounting %s' % table_path)
    with_retry(lambda: yt.mount_table(table_path))


def main():
    """Parse the command line and apply the chosen action to every bus on every cluster.

    Clusters are processed one after another by pointing the YT proxy at each
    of them. Bus names are validated against BUS_CONFIGS before any cluster is
    touched.

    :raises Exception: if a name in --buses is not a key of BUS_CONFIGS.
    """
    parser = argparse.ArgumentParser(add_help=True, description='Script for creating/removing/altering/remounting message buses on several clusters')
    parser.add_argument('action', choices=('create', 'remove', 'alter', 'remount', 'update-bundles', 'update-medium'), help='action to perform on every listed bus')
    parser.add_argument('--env', '-e', choices=('dev', 'testing', 'prod'), required=True, help='env (selects TTLs, bundles and mediums)')
    parser.add_argument('--clusters', nargs='+', required=True, help='names of clusters')
    parser.add_argument('--path', required=True, help='path to bus directory (will be created if does not exist)')
    parser.add_argument('--buses', nargs='+', required=True, help='names of buses')
    args = parser.parse_args()

    # Resolve every bus name before doing anything, so an unknown name aborts
    # the run up front instead of after earlier buses/clusters were modified.
    buses = []
    for bus in args.buses:
        bus_config = BUS_CONFIGS.get(bus)
        if bus_config is None:
            raise Exception('Unknown bus name: "{}"'.format(bus))
        buses.append((bus, bus_config))

    for cluster in args.clusters:
        print('Working with', cluster)
        yt.config['proxy']['url'] = cluster
        create_dir(args.path)

        for bus, bus_config in buses:
            table_path = yt.ypath_join(args.path, bus)
            if args.action == 'create':
                create_bus(cluster, table_path, bus_config, args)
            elif args.action == 'remove':
                remove_bus(cluster, table_path, bus_config, args)
            elif args.action == 'alter':
                alter_bus(cluster, table_path, bus_config, args)
            elif args.action == 'remount':
                remount_bus(cluster, table_path)
            elif args.action == 'update-bundles':
                update_bundles(cluster, table_path, args)
            elif args.action == 'update-medium':
                update_medium(cluster, table_path, bus_config)


if __name__ == '__main__':
    main()
