from __future__ import unicode_literals
import logging
import os
import ujson
import socket

import flask
import flask_cors
from gevent import threadpool

from sepelib.core import config
import ydb
from infra.swatlib import rpc
from infra.swatlib import webserver
from infra.swatlib.auth import blackbox
from infra.swatlib.auth import tvm
from infra.swatlib.auth import oauth
from infra.swatlib.auth import passport
from infra.swatlib import metrics
from infra.dproxy.src.api import ydb_logs_service
from infra.dproxy.src.api import yav_service
from infra.dproxy.src.api import migration_service
from infra.dproxy.src.ydb_logs import controller as ydb_logs_controller

# Backend version string: sent to clients in the X-Backend-Version response
# header (see create_app) and passed to the web server in Application.run().
VERSION = '0.0.1'


class Ctx(object):
    """
    Shared state of the web application, exposed to request handlers as
    ``flask.g.ctx``.

    It is a plain attribute bag: every constructor argument is stored under
    an attribute of the same name, unchanged.
    """
    def __init__(self, blackbox_client, rpc_authenticator, tvm_client,
                 ydb_logs_ctl, olap_ydb_logs_ctl, awacs_ydb_logs_ctl, metrics_registry, qloud_ready_list,
                 platform_ready_list, yav_thread_pool):
        # Identity / authentication clients.
        self.blackbox_client, self.rpc_authenticator, self.tvm_client = (
            blackbox_client, rpc_authenticator, tvm_client,
        )
        # YDB logs controllers (Application.run() passes None for a disabled backend).
        self.ydb_logs_ctl, self.olap_ydb_logs_ctl, self.awacs_ydb_logs_ctl = (
            ydb_logs_ctl, olap_ydb_logs_ctl, awacs_ydb_logs_ctl,
        )
        # Migration coverage data.
        self.qloud_ready_list, self.platform_ready_list = qloud_ready_list, platform_ready_list
        # Infrastructure shared by handlers.
        self.metrics_registry, self.yav_thread_pool = metrics_registry, yav_thread_pool


class RpcAuthenticator(rpc.authentication.CachingPassportAuthenticator):
    """
    CachingPassportAuthenticator with an optional allow-list of logins that
    may authenticate via OAuth.

    While no allow-list is set (``allowed_oauth_users is None``), OAuth
    authentication behaves exactly like the base class.
    """
    def __init__(self, *args, **kwargs):
        super(RpcAuthenticator, self).__init__(*args, **kwargs)
        # None disables the OAuth login restriction.
        self.allowed_oauth_users = None

    def set_allowed_oauth_users(self, users):
        """
        Restrict OAuth logins to ``users``.

        A falsy value (None, empty list, ...) removes the restriction.
        """
        if not users:
            self.allowed_oauth_users = None
        else:
            self.allowed_oauth_users = set(users)

    def authenticate_via_oauth(self, *args, **kwargs):
        """
        Authenticate via the base class, then enforce the allow-list.

        :return: whatever the base class returned (may be None).
        :raises rpc.exceptions.ForbiddenError: the authenticated login is not
            in the configured allow-list.
        """
        subject = super(RpcAuthenticator, self).authenticate_via_oauth(*args, **kwargs)
        allowed = self.allowed_oauth_users
        # Nothing to check: either no subject or no restriction configured.
        if subject is None or allowed is None:
            return subject
        if subject.login in allowed:
            return subject
        raise rpc.exceptions.ForbiddenError(
            "User {!r} is not allowed via OAuth at this server".format(subject.login)
        )


def create_app(ctx, hostname, version, debug=False, origins='*'):
    """
    Create the dproxy Flask application (without service blueprints).

    :param ctx: application context stored into ``flask.g.ctx`` before every request.
    :param hostname: value of the ``X-Backend-Host`` response header; also stored as ``app.hostname``.
    :param version: value of the ``X-Backend-Version`` response header; also stored as ``app.version``.
    :param debug: value assigned to ``app.debug``.
    :param origins: allowed CORS origins, passed to ``flask_cors.CORS``.
    :return: the configured ``flask.Flask`` instance.
    """
    # Create application
    app = flask.Flask('dproxy')
    # By default Flask sorts keys in json to improve cache ability.
    # Our users don't cache responses based on content,
    # so let's try to improve performance.
    app.config['JSON_SORT_KEYS'] = False

    flask_cors.CORS(
        app,
        origins=origins,
        methods=['GET', 'POST', 'OPTIONS'],
        # One header name per item: flask_cors matches every header requested
        # in a preflight against each item separately, so a comma-joined item
        # like 'Accept, X-Request-Id' would match neither header.
        allow_headers=['Authorization', 'Content-Type', 'If-Match', 'Accept', 'X-Request-Id'],
        expose_headers=['Etag', 'X-Total-Items', 'X-Request-Id'],
        supports_credentials=True,
        vary_header=True,
    )

    @app.before_request
    def add_ctx():
        flask.g.ctx = ctx

    @app.before_request
    def handle_options_request():
        # Returning a value from a before_request hook short-circuits the
        # request, so OPTIONS requests never reach the authenticated handlers.
        if flask.request.method == 'OPTIONS':
            return ''  # skip authentication process

    @app.after_request
    def add_diagnostics_headers(response):
        response.headers['X-Backend-Version'] = version
        response.headers['X-Backend-Host'] = hostname
        return response

    app.hostname = hostname
    app.version = version
    # NOTE(review): presumably consumed by the web server's request logging — confirm.
    app.logged_headers = ('X-Request-Id',)

    app.debug = debug
    return app


class Application(object):
    """
    dproxy API server.

    ``run()`` builds the shared ``Ctx`` from config, registers the enabled
    service blueprints and then blocks in the HTTP server; ``stop()`` stops
    the HTTP server and every YDB driver that ``run()`` created.
    """
    name = 'api_server'

    def __init__(self, instance_id):
        self.instance_id = instance_id
        self.log = logging.getLogger(self.name)

        # Created by run(), released by stop().
        self.server = None
        self.ydb_driver = None
        self.olap_ydb_driver = None
        self.awacs_ydb_driver = None

    @staticmethod
    def setup_environment():
        """Adapt third-party libraries to gevent and quiet their noisy loggers."""
        # Patch requests connection pool to use gevent queue
        from requests.packages.urllib3.connectionpool import ConnectionPool
        from gevent.queue import LifoQueue

        ConnectionPool.QueueCls = LifoQueue
        # Disable requests spamming about:
        # "Connection pool is full, discarding connection".
        # This is how we use alemate-http to avoid broken connections;
        # there is nothing we can do about it, so simply mute it.
        logging.getLogger('requests.packages.urllib3.connectionpool').setLevel(logging.ERROR)
        logging.getLogger('ydb.connection').setLevel(logging.ERROR)

    @staticmethod
    def _create_ydb_driver(section):
        """
        Create a YDB driver from the ``<section>.endpoint``, ``<section>.db``
        and ``<section>.auth_token`` config values and wait for it (5 s).

        :return: the connected driver, or None when ``<section>.enabled`` is
            false (it defaults to True).
        """
        if not config.get_value(section + '.enabled', True):
            return None
        params = ydb.DriverConfig(
            endpoint=config.get_value(section + '.endpoint'),
            database=config.get_value(section + '.db'),
            auth_token=config.get_value(section + '.auth_token')
        )
        driver = ydb.Driver(params)
        try:
            driver.wait(timeout=5)
        except Exception:
            # The driver is not returned on failure, so stop() would never
            # see it: release it here before propagating the error.
            driver.stop()
            raise
        return driver

    @staticmethod
    def _create_logs_controller(controller_cls, driver, section, metrics_registry):
        """
        Build ``controller_cls`` on top of ``driver``, tuned by the
        ``<section>.*`` config values.

        :return: the controller, or None when ``driver`` is None (backend disabled).
        """
        if driver is None:
            return None
        return controller_cls(
            scheme_client=driver.scheme_client,
            table_client=driver.table_client,
            metrics_registry=metrics_registry,
            history_tables_count=config.get_value(section + '.history_tables_count_days', 3),
            pool_size=config.get_value(section + '.ydb_fetching_pool_size', 100),
            request_timeout=config.get_value(section + '.ydb_request_timeout', 90)
        )

    @staticmethod
    def _load_ready_list(path):
        """Read the JSON array stored at ``path`` and return its items as a set."""
        with open(path, 'r') as f:
            return set(ujson.load(f))

    def run(self):
        """
        Build the application from config and serve HTTP.

        Blocks in ``self.server.run()``. Every optional backend or service is
        switched by its ``<section>.enabled`` config flag (default True).
        """
        self.setup_environment()

        # Make app context.
        tvm_client = tvm.TvmClient(
            client_id=config.get_value('tvm.client_id'),
            # The environment variable takes precedence over the config value.
            secret=os.environ.get('TVM_SECRET') or config.get_value('tvm.secret'),
            api_url=config.get_value('tvm.api_url')
        )
        oauth_client = oauth.OAuth.from_config(config.get_value('oauth'))
        passport_client = passport.TvmPassportClient.from_config(
            d=config.get_value('passport'),
            tvm_client=tvm_client
        )
        rpc_authenticator = RpcAuthenticator(
            oauth_client=oauth_client,
            passport_client=passport_client,
            is_auth_disabled=not config.get_value('run.auth')
        )
        rpc_authenticator.set_allowed_oauth_users(config.get_value('oauth.allowed_users', None))
        blackbox_client = blackbox.BlackboxClient(
            url=config.get_value('passport.blackbox_url')
        )

        # Connect YDB backends; a driver stays None when its section is disabled.
        self.ydb_driver = self._create_ydb_driver('ydb_logs')
        self.olap_ydb_driver = self._create_ydb_driver('olap_ydb_logs')
        self.awacs_ydb_driver = self._create_ydb_driver('awacs_ydb_logs')

        qloud_ready_list = None
        platform_ready_list = None
        if config.get_value('migration_service.enabled', True):
            qloud_ready_list = self._load_ready_list(
                config.get_value('migration_service.qloud_coverage'))
            platform_ready_list = self._load_ready_list(
                config.get_value('migration_service.platform_coverage'))

        reg = metrics.Registry()
        ctx = Ctx(
            blackbox_client=blackbox_client,
            rpc_authenticator=rpc_authenticator,
            tvm_client=tvm_client,
            ydb_logs_ctl=self._create_logs_controller(
                ydb_logs_controller.YdbLogsController, self.ydb_driver, 'ydb_logs', reg),
            olap_ydb_logs_ctl=self._create_logs_controller(
                ydb_logs_controller.OlapYdbLogsController, self.olap_ydb_driver, 'olap_ydb_logs', reg),
            awacs_ydb_logs_ctl=self._create_logs_controller(
                ydb_logs_controller.AwacsYdbLogsController, self.awacs_ydb_driver, 'awacs_ydb_logs', reg),
            metrics_registry=reg,
            qloud_ready_list=qloud_ready_list,
            platform_ready_list=platform_ready_list,
            yav_thread_pool=threadpool.ThreadPool(
                maxsize=config.get_value('yav.thread_pool_size', 5),
            ),
        )

        # Make app.
        # NOTE(review): debug is hard-coded to True here regardless of
        # environment — confirm this is intended outside development.
        app = create_app(ctx,
                         socket.gethostname(),
                         version=VERSION,
                         origins=config.get_value('web.allowed_origins', '*'),
                         debug=True)

        # Register service blueprints, in this order, for each enabled section.
        blueprints = (
            ('yav', yav_service.yav_service_blueprint),
            ('ydb_logs', ydb_logs_service.ydb_logs_service_blueprint),
            ('olap_ydb_logs', ydb_logs_service.olap_ydb_logs_service_blueprint),
            ('awacs_ydb_logs', ydb_logs_service.awacs_ydb_logs_service_blueprint),
            ('migration_service', migration_service.yd_migrate_blueprint),
        )
        for section, blueprint in blueprints:
            if config.get_value(section + '.enabled', True):
                app.register_blueprint(blueprint)

        # Make http server.
        self.server = webserver.WebServer(config.get(), app, version=VERSION, metrics_registry=reg)
        # Run http server forever.
        self.log.info('Starting HTTP server...')
        self.server.run()

    def stop(self):
        """Stop the HTTP server (if started), then every YDB driver that exists."""
        self.log.info('Stopping HTTP server...')
        if self.server:
            self.server.stop()
        for driver in (self.ydb_driver, self.olap_ydb_driver, self.awacs_ydb_driver):
            if driver:
                driver.stop()
        self.log.info('Done.')
