from functools import wraps
import logging
from contextlib import contextmanager
from datetime import datetime
from json import loads

from psycopg2 import IntegrityError, OperationalError

from mail.husky.husky import tasks
from mail.pypg.pypg.query_handler import ExpectOneItemError
from dateutil.tz import tzlocal as get_localzone

from werkzeug import routing
from werkzeug.exceptions import HTTPException, NotFound
from werkzeug.routing import Rule, BaseConverter
from werkzeug.wrappers import Request, Response

from ora2pg.app import make_app_from_env
from ora2pg.tools.find_master_helpers import find_huskydb, get_huskydb_pooled_conn

from .api_queries import ApiQueries
from .sharddb import SharddbAdaptor, GetUserShardIdError
from .types import Task, Status
from .web_tools import json_response, jdumps
from .web_tools.middleware import make_error, make_custom_error
from mail.python.tvm_requests import Tvm, TvmCheckError


# String values of the attributes declared on ``Task``.  Dunder entries are
# skipped: ``__module__`` (and ``__doc__`` when it is a string) are class
# metadata, not task type names.
task_types = set(
    value for name, value in vars(Task).items()
    if type(value) is str and not name.startswith('__')
)
log = logging.getLogger(__name__)


def transfer_info_as_dict(transfer_info):
    """Convert a transfer-info record into a plain dict for JSON responses.

    The attributes ``status``, ``error_type``, ``tries``, ``try_notices``,
    ``last_update`` and ``transfer_id`` are copied, in that key order.
    """
    exported = ('status', 'error_type', 'tries', 'try_notices', 'last_update', 'transfer_id')
    return {attr: getattr(transfer_info, attr) for attr in exported}


def handle_delete_user_exceptions(func):
    """Decorator turning known storage/user errors into JSON error responses.

    * ``IntegrityError`` (psycopg2) -> HTTP 400 with a ``make_error`` body.
    * ``GetUserShardIdError`` -> HTTP 404 with a ``make_error`` body.

    Any other exception propagates unchanged.  ``functools.wraps`` keeps the
    wrapped function's name and docstring, as ``check_tvm`` does.
    """
    @wraps(func)
    def impl(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except IntegrityError as e:
            return Response(
                response=jdumps(make_error(exception=e, traceback=None)),
                status=400,
                content_type='application/json',
            )
        except GetUserShardIdError as e:
            return Response(
                response=jdumps(make_error(exception=e, traceback=None)),
                status=404,
                content_type='application/json',
            )
    return impl


def _verify_tvm_ticket(allowed_client_ids, request):
    """Validate the TVM ticket stored in ``request.environ['tvm_ticket']``.

    :param allowed_client_ids: accepted TVM source client ids, or ``None`` to
        skip the check entirely.
    :param request: incoming werkzeug request (the ticket is presumably put
        into the environ by ``TvmMiddleware`` — confirm).
    :raises TvmCheckError: if no ticket is present or its ``src`` is not allowed.
    """
    if allowed_client_ids is None:
        log.info('TVM check skipped cause no tvm_client_ids are provided')
        return
    ticket_info = request.environ.get('tvm_ticket')
    if not ticket_info:
        log.error('No TVM ticket provided')
        raise TvmCheckError('No TVM ticket provided')
    if ticket_info.src not in allowed_client_ids:
        log.error('Wrong src in ticket: %s', ticket_info.logging_string)
        raise TvmCheckError('Wrong src=%d in ticket' % ticket_info.src)


def check_tvm(func):
    """Decorator for ``HuskyApiApp`` endpoints that enforces the TVM allow-list.

    Reads ``self.tvm_client_ids``; on a failed check the endpoint is not
    called and a ``401 text/plain`` response with the error text is returned.
    Only the ticket check is guarded, so exceptions raised by the endpoint
    itself are never mistaken for an authorization failure.
    """
    @wraps(func)
    def impl(self, request, *args, **kwargs):
        try:
            _verify_tvm_ticket(self.tvm_client_ids, request)
        except TvmCheckError as e:
            return Response(
                response=str(e),
                status=401,
                content_type='text/plain',
            )
        return func(self, request, *args, **kwargs)
    return impl


class NowConverter(BaseConverter):
    """URL converter that matches only the path segment ``right_now``.

    A matched segment becomes ``True`` for the view.  It is registered as
    ``now_bool`` in ``HuskyApiApp.url_map``.
    """

    _TOKEN = 'right_now'

    def __init__(self, url_map):
        super(NowConverter, self).__init__(url_map)
        # Only the literal word is accepted in this path segment.
        self.regex = '(?:' + self._TOKEN + ')'

    def to_python(self, value):
        """Return whether ``value`` is the ``right_now`` token."""
        return value == self._TOKEN

    def to_url(self, value):
        """Render a truthy flag as ``right_now``; falsy flags give ``None``."""
        if not value:
            return None
        return self._TOKEN


class WithDeletedConverter(BaseConverter):
    """URL converter that matches only the path segment ``with_deleted_box``.

    A matched segment becomes ``True`` for the view.  It is registered as
    ``with_deleted_bool`` in ``HuskyApiApp.url_map``.
    """

    _TOKEN = 'with_deleted_box'

    def __init__(self, url_map):
        super(WithDeletedConverter, self).__init__(url_map)
        # Only the literal word is accepted in this path segment.
        self.regex = '(?:' + self._TOKEN + ')'

    def to_python(self, value):
        """Return whether ``value`` is the ``with_deleted_box`` token."""
        return value == self._TOKEN

    def to_url(self, value):
        """Render a truthy flag as ``with_deleted_box``; falsy flags give ``None``."""
        if not value:
            return None
        return self._TOKEN


class HuskyApiApp(object):
    """WSGI application serving the husky HTTP API.

    Every rule in ``url_map`` names an endpoint ``X`` that ``dispatch_request``
    routes to the method ``on_X``.  Tasks are stored through ``ApiQueries`` on
    a huskydb connection (``_connection``); user/shard placement is resolved
    through ``SharddbAdaptor``.
    """

    url_map = routing.Map([
        Rule('/', endpoint='describe'),
        Rule('/ping', endpoint='ping'),
        Rule('/pingdb', endpoint='pingdb'),
        Rule('/clone_user/<source_uid>/into/<dest_uid>',
             endpoint='add_clone_user_task',
             methods=['POST', 'PUT']),
        Rule('/clone_user/<source_uid>/into/<dest_uid>/<with_deleted_bool:with_deleted_box>',
             endpoint='add_clone_user_task',
             methods=['POST', 'PUT']),
        Rule('/clone_user/<int:source_uid>/into/<int:dest_uid>',
             endpoint='clone_user_status',
             methods=['GET']),
        Rule('/delete_user/<uid>',
             endpoint='delete_user',
             methods=['POST', 'PUT']),
        Rule('/delete_user/<uid>/<now_bool:right_now>',
             endpoint='delete_user',
             methods=['POST', 'PUT']),
        Rule('/add_task',
             methods=['POST'], endpoint="add_task"),
        Rule('/add_shard_task',
             methods=['POST'], endpoint="add_shard_task"),
    ], converters={
        'now_bool': NowConverter,
        'with_deleted_bool': WithDeletedConverter})

    # Keys a client may send in the JSON body of POST/PUT /clone_user/...
    _CLONE_USER_TASK_ARGS = frozenset(['min_received_date', 'max_received_date'])

    def __init__(self, app, config, tvm_conf=None):
        """Create the API application.

        :param app: ora2pg application; its ``args`` locate huskydb and sharddb.
        :param config: service configuration; must contain ``delete_user_shift``
            (added to the current time to schedule delayed user deletion).
        :param tvm_conf: optional TVM settings.  When non-empty, its
            ``src_client_ids`` become the allow-list enforced by ``check_tvm``;
            otherwise the TVM check is skipped.
        :raises ValueError: if ``delete_user_shift`` is missing from ``config``.
        """
        # Explicit check rather than ``assert``, which ``python -O`` strips.
        if 'delete_user_shift' not in config:
            raise ValueError("config is missing required key 'delete_user_shift'")
        self.config = config
        self.app = app
        # Any falsy tvm_conf (None or an empty section) must disable the check,
        # as create_app does.  ``tvm_conf and ...`` would leave ``{}`` here,
        # which check_tvm treats as an empty allow-list and rejects everyone.
        self.tvm_client_ids = tvm_conf.get('src_client_ids', None) if tvm_conf else None

    @classmethod
    def from_conf(cls, config, tvm_conf=None):
        """Build the ora2pg app described by ``config`` and wrap it in ``cls``.

        Reads ``config['env']`` plus the optional ``verbose`` (default False)
        and ``log_filepath`` (default None) keys.
        """
        app_kwargs = dict(verbose=config.get('verbose', False),
                          log_filepath=config.get('log_filepath', None))

        app = make_app_from_env(config['env'], **app_kwargs)
        return cls(app, config, tvm_conf)

    @contextmanager
    def _connection(self):
        """Yield a pooled connection to the huskydb found from ``self.app.args``."""
        dsn = find_huskydb(self.app.args)
        with get_huskydb_pooled_conn(dsn) as conn:
            yield conn

    @handle_delete_user_exceptions
    @check_tvm
    @json_response
    def on_add_task(self, request):
        """Enqueue one task of any supported type for a single user.

        JSON body: ``uid``, ``task`` (task type name) and ``task_args``.  The
        shard is the task's ``loaded_shard_id`` when set, otherwise the user's
        shard from sharddb; the priority comes from the config key
        ``<task>_priority`` (default 0).

        Returns ``{'status': 'ok', 'task': {...}}`` or
        ``{'status': 'error', 'error': <message>}``.
        """
        json_data = loads(request.data.decode('utf-8'))
        uid = json_data['uid']
        task_name = json_data['task']
        try:
            task = tasks.get_handler(task_name)(
                app=self.app,
                transfer_id=None,
                uid=uid,
                task_args=json_data['task_args'],
            )
        except tasks.errors.NotSupportedError:
            return {"status": "error", "error": "unknown task type"}
        shard_id = task.loaded_shard_id or SharddbAdaptor(self.app.args).get_user_shard_id(uid)
        with self._connection() as conn:
            q = ApiQueries(conn)
            try:
                transfer_info = q.add_task(
                    uid=uid,
                    task=task_name,
                    status='pending',
                    shard_id=shard_id,
                    priority=self.config.get(task_name + "_priority", 0),
                    task_args=jdumps(json_data['task_args']),
                )
            except Exception as exp:
                # ``Exception`` rather than ``BaseException``: interpreter
                # exits and interrupts must not become a JSON error body.
                return {"status": "error", "error": str(exp)}
        return {'status': 'ok', 'task': transfer_info_as_dict(transfer_info)}

    @handle_delete_user_exceptions
    @check_tvm
    @json_response
    def on_add_shard_task(self, request):
        """Enqueue one task of the given type for every user of a shard.

        JSON body: ``task``, ``shard_name`` and the optional ``task_args``
        (default ``{}``), ``with_deleted`` (default False), ``loaded_shard_id``
        (default: the shard id resolved from ``shard_name``) and ``priority``
        (default 0).

        Returns ``{'status': 'ok', 'tasks_count': ..., 'users_count': ...}``,
        or ``{'status': 'error', 'error': 'unknown task type'}``.
        """
        sharddb = SharddbAdaptor(self.app.args)

        json_data = loads(request.data.decode('utf-8'))
        task_name = json_data['task']
        shard_name = json_data['shard_name']
        task_args = json_data.get('task_args', {})
        with_deleted = json_data.get('with_deleted', False)

        try:
            # The instance is discarded: building it validates the task type
            # (and presumably its arguments) before fanning out to all users.
            tasks.get_handler(task_name)(
                app=self.app,
                transfer_id=None,
                uid=0,
                task_args=task_args,
            )
        except tasks.errors.NotSupportedError:
            # Same answer as on_add_task for an unknown task type.
            return {"status": "error", "error": "unknown task type"}

        loaded_shard_id = json_data.get('loaded_shard_id', None) or sharddb.get_shard_id(shard_name)
        # Loop invariants: identical for every batch of users.
        priority = json_data.get('priority', 0)
        serialized_task_args = jdumps(task_args)

        tasks_count = 0
        users_count = 0
        with self._connection() as conn:
            try:
                huskydb = ApiQueries(conn)
                # get_shard_users yields batches (sequences) of uids.
                for uids in sharddb.get_shard_users(shard_name, with_deleted):
                    users_count += len(uids)
                    tasks_count += huskydb.add_task_for_users(
                        uids=uids,
                        task=task_name,
                        shard_id=loaded_shard_id,
                        priority=priority,
                        task_args=serialized_task_args,
                    )
            except BaseException:
                # Undo partially inserted batches, then let the error propagate.
                log.exception("add_shard_task is faild!\nReason:")
                conn.rollback()
                raise
        return {'status': 'ok', 'tasks_count': tasks_count, 'users_count': users_count}

    def on_describe(self, request):
        """Return a plain-text dump of ``url_map``."""
        return Response(str(self.url_map))

    def on_ping(self, request):
        """Liveness probe: always answers ``pong``."""
        return Response('pong')

    def on_pingdb(self, request):
        """Check huskydb: ``ok`` (200), or the OperationalError text (500)."""
        try:
            with self._connection() as conn:
                self._ping_query(conn)
                return Response(response='ok', status=200)
        except OperationalError as outer:
            log.exception("pingdb is faild!\nReason:")
            return Response(response=str(outer), status=500)

    def _ping_query(self, conn):
        """Run ``select 1`` on ``conn``, always closing the cursor afterwards."""
        cursor = conn.cursor()
        try:
            cursor.execute("select 1")
        finally:
            cursor.close()

    def _check_task_args(self, task_args):
        """Return True when every key of ``task_args`` is an allowed clone_user argument."""
        return all(arg in self._CLONE_USER_TASK_ARGS for arg in task_args)

    @check_tvm
    @json_response
    def on_add_clone_user_task(self, request, source_uid, dest_uid, with_deleted_box=False):
        """Enqueue a ``clone_user`` task copying ``source_uid`` into ``dest_uid``.

        The optional JSON object body may carry ``min_received_date`` and/or
        ``max_received_date``.  The task goes to the shard configured as
        ``clone_user_dest_shard``; ``with_deleted_box`` is forwarded as the
        task's ``restore_deleted_box`` argument.  Each (source, dest) pair is
        recorded in the clone audit, so a repeated request returns
        ``already_exists`` with the earlier task (``{}`` if it is no longer found).
        """
        with self._connection() as conn:
            q = ApiQueries(conn)
            try:
                old_audit = q.get_clone_audit(
                    source_uid=source_uid,
                    dest_uid=dest_uid
                )
            except ExpectOneItemError:
                pass  # first request for this pair: create the task below
            else:
                old_task = {}
                try:
                    old_task = transfer_info_as_dict(
                        q.get_task(transfer_id=old_audit.transfer_id)
                    )
                except ExpectOneItemError:
                    pass
                return {
                    'status': 'already_exists',
                    'task': old_task,
                }
            shard_id = SharddbAdaptor(
                self.app.args
            ).resolve_endpoint(
                self.config['clone_user_dest_shard']
            ).shard_id

            # TODO :: add user in sharddb.users synchronously
            # that way any mail stores made after /clone_user will stay in queue
            # instead of triggering lazy user initialization via Sharpei

            task_args = {}
            if request.data:
                task_args = loads(request.data.decode('utf-8'))
            # A non-object body (list, string, null) cannot be merged below.
            if not isinstance(task_args, dict):
                return {
                    'status': 'error',
                    'reason': 'task arguments must be a JSON object',
                }
            if not self._check_task_args(task_args):
                return {
                    'status': 'error',
                    'reason': 'unknown task argument',
                }
            task_args.update({
                'source_user_uid': source_uid,
                'dest_shard': self.config['clone_user_dest_shard'],
                'restore_deleted_box': with_deleted_box,
            })
            new_transfer_info = q.add_task(
                uid=dest_uid,
                task='clone_user',
                status='pending',
                shard_id=shard_id,
                priority=self.config.get('clone_user_priority', 0),
                task_args=jdumps(task_args)
            )

            q.add_clone_audit(
                source_uid=source_uid,
                dest_uid=dest_uid,
                transfer_id=new_transfer_info.transfer_id,
                request_info=jdumps({
                    'user_agent': request.user_agent.string,
                    'task_args': task_args
                })
            )

            return {
                'status': 'ok',
                'task': transfer_info_as_dict(new_transfer_info),
            }

    @check_tvm
    @json_response
    def on_clone_user_status(self, request, source_uid, dest_uid):
        """Report the state of the clone_user task for (source_uid, dest_uid).

        Returns ``not_found`` when no audit row exists, the task status when
        the task is still queued, and, when the task is gone, ``complete`` or
        ``error`` depending on the audit row's ``complete`` flag.
        """
        with self._connection() as conn:
            q = ApiQueries(conn)
            try:
                audit_row = q.get_clone_audit(
                    source_uid=source_uid,
                    dest_uid=dest_uid
                )
            except ExpectOneItemError:
                return {
                    'status': 'not_found',
                    'reason': 'unknown clone_user task',
                }
            try:
                transfer_info = q.get_task(
                    transfer_id=audit_row.transfer_id
                )
            except ExpectOneItemError:
                if audit_row.complete:
                    return {
                        'status': 'complete',
                        'task': {
                            'transfer_id': audit_row.transfer_id,
                        }
                    }
                else:
                    return {
                        'status': 'error',
                        'reason': 'task does not exist, '
                                  'and clone does not complete '
                                  'probably already broken '
                                  'and queue was cleaned',
                        'task': {
                            'transfer_id': audit_row.transfer_id,
                        }
                    }

            return {
                'status': transfer_info.status,
                'task': transfer_info_as_dict(transfer_info),
            }

    @handle_delete_user_exceptions
    @check_tvm
    def on_delete_user(self, request, uid, right_now=False):
        """Schedule deletion of user ``uid``.

        Enqueues a ``DeleteShardsUser`` task and a ``DeleteMailUser`` task
        whose ``deleted_date`` is the epoch when ``right_now`` is set, else
        now + ``config['delete_user_shift']``.  Answers 400 (custom error
        code 4) if either task already exists for the user.
        """
        localzone = get_localzone()
        if right_now:
            deleted_date = (
                datetime.fromtimestamp(0, localzone)
            ).isoformat()
        else:
            deleted_date = (
                datetime.now(tz=localzone) + self.config['delete_user_shift']
            ).isoformat()

        with self._connection() as conn:
            q = ApiQueries(conn)
            existing_delete_mail_user = q.get_task_by_user_and_type(
                uid=uid,
                task=Task.DeleteMailUser,
            )
            existing_delete_shards_user = q.get_task_by_user_and_type(
                uid=uid,
                task=Task.DeleteShardsUser,
            )
            if existing_delete_mail_user or existing_delete_shards_user:
                log.warning('Found existing delete user tasks: delete_mail_user=%s delete_shards_user=%s',
                            existing_delete_mail_user, existing_delete_shards_user)
                return Response(
                    response=jdumps(make_custom_error(code=4, message='Tasks already exists')),
                    status=400,
                    content_type='application/json',
                )
            shard_id = SharddbAdaptor(self.app.args).get_user_shard_id(uid)
            delete_shards_user = transfer_info_as_dict(q.add_task(
                uid=uid,
                task=Task.DeleteShardsUser,
                status=Status.Pending,
                shard_id=shard_id,
                priority=self.config.get('delete_shards_user_priority', 0),
                task_args=jdumps(dict()),
            ))
            delete_shards_user['name'] = Task.DeleteShardsUser
            delete_mail_user = transfer_info_as_dict(q.add_task(
                uid=uid,
                task=Task.DeleteMailUser,
                status=Status.Pending,
                shard_id=shard_id,
                priority=self.config.get('delete_mail_user_priority', 0),
                task_args=jdumps(dict(
                    deleted_date=deleted_date,
                )),
            ))
            delete_mail_user['name'] = Task.DeleteMailUser
            return Response(
                response=jdumps(dict(
                    status='ok',
                    tasks=[delete_mail_user, delete_shards_user],
                )),
                status=200,
                content_type='application/json',
            )

    @staticmethod
    def error_404():
        """Plain-text 404 response used for unmatched URLs."""
        return Response(
            'This page Not Found',
            status=404,
            mimetype='text/plain')

    def dispatch_request(self, request):
        """Route ``request`` to ``on_<endpoint>``.

        Unmatched URLs get ``error_404``; other routing errors (e.g. a wrong
        method) are returned as the werkzeug ``HTTPException`` itself.
        """
        adapter = self.url_map.bind_to_environ(request.environ)
        try:
            endpoint, values = adapter.match()
            return getattr(self, 'on_' + endpoint)(request, **values)
        except NotFound:
            return self.error_404()
        except HTTPException as e:
            return e

    def __call__(self, environ, start_response):
        """WSGI entry point."""
        request = Request(environ)
        response = self.dispatch_request(request)
        return response(environ, start_response)


def create_app(config):
    """Build the husky API WSGI application from ``config``.

    When ``config`` has a non-empty ``tvm`` section, the app is wrapped in
    ``TvmMiddleware`` (daemon URL defaults to ``http://localhost:1``).  The
    result is always wrapped in ``ExceptionMiddleware``.

    Side effect: sets ``WERKZEUG_DEBUG_PIN=off`` in the process environment.
    """
    import os
    from .web_tools.middleware import ExceptionMiddleware, TvmMiddleware

    tvm_conf = config.get('tvm', None)
    app = HuskyApiApp.from_conf(config, tvm_conf)
    # Module logger instead of root-level ``logging.info``, which would also
    # call ``basicConfig()`` implicitly if the root logger had no handlers.
    if tvm_conf:
        log.info('TVM checks are enabled')
        tvm = Tvm(
            tvm_daemon_url=tvm_conf.get('daemon_url', 'http://localhost:1'),
            client_id=tvm_conf['client_id'],
            local_token=tvm_conf.get('local_token', None)
        )
        app = TvmMiddleware(app, tvm)
    else:
        log.info('TVM checks are disabled')
    app = ExceptionMiddleware(app)
    # Disable the PIN prompt of the werkzeug debugger.
    os.environ['WERKZEUG_DEBUG_PIN'] = 'off'
    return app
