# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

import json
import logging
import re
from collections import OrderedDict

import pymongo
import pytz
from django.conf import settings
from enum import Enum
from marshmallow import Schema, fields, post_load
from rest_framework import status
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.response import Response

from common.db.mongo import ConnectionProxy
from travel.rasp.train_api.helpers.ydb import ydb_provider
from travel.rasp.train_api.train_purchase.backoffice.base import MongoCursorLimitOffsetPagination, BackofficeAdminViewSet
from travel.rasp.train_api.train_purchase.utils.logs import YDB_ORDER_LOGS_TABLE_NAME

log = logging.getLogger(__name__)
# Mongo database handle (project-defined ConnectionProxy); OrderLogsViewSet reads `database.order_logs`.
# NOTE: OrderLogsQuery.build_ydb_query has a parameter also named `database` that shadows this name there.
database = ConnectionProxy('train_purchase')


class OrderLogsQueryMode(Enum):
    """Storage backends selectable via the ``mode`` query parameter of the order-logs view.

    The view compares the raw ``mode`` string against these ``.value`` strings.
    """
    # Read only from the Mongo ``order_logs`` collection; this is the schema default.
    MONGO_ONLY = 'mongo'
    # Read from YDB, with Mongo used as a fallback (see OrderLogsViewSet.list).
    YDB_WITH_FALLBACK = 'ydb_with_fallback'
    # Read only from the YDB order-logs table.
    YDB_ONLY = 'ydb'


def _format_log_time(dt):
    if dt.tzinfo is not None:
        dt = dt.astimezone(pytz.UTC).replace(tzinfo=None)

    return dt.isoformat() + 'Z'


class OrderLogsQuery(object):
    """Order-log search filters that can be rendered as a Mongo or a YDB query.

    ``data`` is the dict produced by ``OrderLogsQuerySchema``. Recognized keys are
    ``time_from``, ``time_to``, ``order_uid``, ``name``, ``levelname``, ``message``
    and ``reversed``. A missing or falsy value means "no filter on that key".
    """

    def __init__(self, data):
        self.data = data

    def build_query(self):
        """Build keyword arguments for ``pymongo`` ``Collection.find``.

        :return: dict with ``filter`` (Mongo filter document) and ``sort``. The sort is
            ascending by ``created_utc``, or descending when ``reversed`` is set.
        """
        data = self.data
        query_filter = {}

        # Both bounds must live in the same sub-document. Assigning them separately
        # would let ``time_to`` silently discard the ``time_from`` bound.
        time_range = {}
        if data.get('time_from'):
            time_range['$gte'] = _format_log_time(data['time_from'])
        if data.get('time_to'):
            time_range['$lte'] = _format_log_time(data['time_to'])
        if time_range:
            query_filter['created_utc'] = time_range

        if data.get('order_uid'):
            query_filter['context.order_uid'] = data['order_uid']

        if data.get('name'):
            name = data['name']
            # A trailing dot matches any name starting with that prefix. Otherwise match
            # the name itself or its dotted children: "a.b" matches "a.b" and "a.b.c",
            # but not "a.bc".
            pattern = r'^{}' if name.endswith('.') else r'^{}(?:$|\.)'
            query_filter['name'] = {
                '$regex': re.compile(pattern.format(re.escape(name)), re.U),
            }

        if data.get('levelname'):
            query_filter['levelname'] = data['levelname']

        if data.get('message'):
            message = data['message']
            # Case-insensitive substring search in the message, or exact process name.
            query_filter['$or'] = [
                {'message': {'$regex': re.compile(r'.*{}.*'.format(re.escape(message)), re.U | re.I)}},
                {'context.process_name': message},
            ]

        direction = pymongo.DESCENDING if data.get('reversed') else pymongo.ASCENDING
        return {'filter': query_filter, 'sort': [('created_utc', direction)]}

    def build_ydb_query(self, database=settings.YDB_DATABASE, table_name=YDB_ORDER_LOGS_TABLE_NAME,
                        limit=None, offset=0):
        """Build a parameterized YQL script for the order-logs table.

        The script contains two statements. The first result set is ``Count(1) as count``
        of matching rows. The second is the requested page of rows, ordered by ``created_utc``.

        :param database: value for ``PRAGMA TablePathPrefix``.
        :param table_name: table to read from.
        :param limit: page size. When falsy, no LIMIT/OFFSET clause is emitted.
        :param offset: number of rows to skip. Only used together with ``limit``.
        :return: ``(query_string, query_params)``, where ``query_params`` maps each
            ``$param`` declared in the script to its value.
        """
        data = self.data
        # A list (not a set) keeps the DECLARE order stable, so identical filters always
        # produce byte-identical query text.
        declarations = []
        filters = []
        query_params = {}

        def declare(param, yql_type, value):
            """Register a typed YQL parameter declaration together with its value."""
            declarations.append('DECLARE {} AS "{}";'.format(param, yql_type))
            query_params[param] = value

        # The time bounds are declared Utf8, so they must be passed as text rather than
        # datetime objects. Use the same string form as the Mongo query.
        if data.get('time_from'):
            declare('$time_from', 'Utf8', _format_log_time(data['time_from']))
            filters.append('created_utc >= $time_from')
        if data.get('time_to'):
            declare('$time_to', 'Utf8', _format_log_time(data['time_to']))
            filters.append('created_utc <= $time_to')
        if data.get('order_uid'):
            declare('$order_uid', 'Utf8', data['order_uid'])
            filters.append('order_uid = $order_uid')
        if data.get('name'):
            # NOTE(review): this is a plain prefix match. Unlike the Mongo regex it has no
            # dotted-boundary check, and '%'/'_' inside the name act as LIKE wildcards.
            declare('$name', 'Utf8', '{}%'.format(data['name']))
            filters.append('name LIKE $name')
        if data.get('levelname'):
            # Exact match, consistent with the Mongo query (LIKE would treat '_' as a wildcard).
            declare('$levelname', 'Utf8', data['levelname'])
            filters.append('levelname = $levelname')
        if data.get('message'):
            # NOTE(review): LIKE is case-sensitive, whereas the Mongo search uses re.I.
            declare('$message', 'Utf8', '%{}%'.format(data['message']))
            filters.append('(message LIKE $message OR cast(context as Utf8) LIKE $message)')

        query_limit = ''
        if limit:
            declare('$limit', 'Int32', limit)
            declare('$offset', 'Int32', offset)
            query_limit = 'LIMIT $limit OFFSET $offset'

        # With no filters, drop WHERE entirely; "WHERE ;" would be a syntax error.
        query_where = ('WHERE ' + '\n AND '.join(filters)) if filters else ''
        query_order = 'created_utc DESC' if data.get('reversed') else 'created_utc'
        query_string = """
{declarations}
PRAGMA TablePathPrefix("{database}");

SELECT Count(1) as count
FROM {table_name}
{where};

SELECT created_utc, name, levelname, message, created, [process], context, exception
FROM {table_name}
{where}
ORDER BY {order}
{limit};
""".format(database=database, table_name=table_name, declarations='\n'.join(declarations),
           where=query_where, order=query_order, limit=query_limit)
        return query_string, query_params


class OrderLogsQuerySchema(Schema):
    """Deserializes order-log search parameters (marshmallow 2 API: ``load_from``/``missing``).

    ``load`` yields an ``OrderLogsQuery`` wrapping the parsed dict (see ``build_query``).
    """
    # Optional time bounds; the camelCase keys are also accepted via ``load_from``.
    time_from = fields.DateTime(load_from='timeFrom')
    time_to = fields.DateTime(load_from='timeTo')
    name = fields.String()
    levelname = fields.String()
    message = fields.String()
    order_uid = fields.String(load_from='orderUID')
    # When true, results are sorted by created_utc descending (newest first).
    reversed = fields.Boolean(missing=False)
    # Storage selector, expected to be an OrderLogsQueryMode value. It is not validated
    # here; the view answers 400 for unknown values.
    mode = fields.String(missing=OrderLogsQueryMode.MONGO_ONLY.value)

    @post_load
    def build_query(self, data):
        """Wrap the deserialized dict in an ``OrderLogsQuery``."""
        return OrderLogsQuery(data)


class OrderLogsViewSet(BackofficeAdminViewSet):
    """Backoffice endpoint listing order log records from Mongo and/or YDB.

    The ``mode`` query parameter (an ``OrderLogsQueryMode`` value) selects the storage.
    """
    _paginator = None
    _ydb_paginator = None

    @property
    def paginator(self):
        """Lazily created paginator for Mongo cursors (default page size 100)."""
        if not self._paginator:
            self._paginator = MongoCursorLimitOffsetPagination()
            self._paginator.default_limit = 100

        return self._paginator

    @property
    def ydb_paginator(self):
        """Lazily created DRF limit/offset paginator used for YDB results (default page size 100)."""
        if not self._ydb_paginator:
            self._ydb_paginator = LimitOffsetPagination()
            self._ydb_paginator.default_limit = 100

        return self._ydb_paginator

    def list(self, request):
        """Validate query parameters and dispatch to the storage chosen by ``mode``.

        :return: a paginated ``Response``, or HTTP 400 on invalid parameters or unknown mode.
        """
        query, errors = OrderLogsQuerySchema().load(request.query_params)
        if errors:
            return Response({
                'errors': errors,
            }, status=status.HTTP_400_BAD_REQUEST)

        mode = query.data['mode']
        if mode == OrderLogsQueryMode.MONGO_ONLY.value:
            return self.get_from_mongo(query, request)
        if mode == OrderLogsQueryMode.YDB_ONLY.value:
            return self.get_from_ydb(query, request)
        if mode == OrderLogsQueryMode.YDB_WITH_FALLBACK.value:
            return self._get_from_ydb_with_fallback(query, request)
        return Response({
            'errors': {'mode': 'Unknown mode'},
        }, status=status.HTTP_400_BAD_REQUEST)

    def _get_from_ydb_with_fallback(self, query, request):
        """Prefer YDB; use Mongo if YDB fails or Mongo returns more rows for a short page."""
        try:
            ydb_result = self.get_from_ydb(query, request)
        except Exception:
            # This mode exists to tolerate YDB problems, so log and serve from Mongo
            # instead of failing the request.
            log.exception('Failed to load order logs from YDB, falling back to Mongo')
            return self.get_from_mongo(query, request)

        ydb_rows = ydb_result.data['results']
        # A page shorter than the limit may simply be the last page, or YDB may be missing
        # records. Ask Mongo and keep whichever source returned more rows.
        if len(ydb_rows) < self.ydb_paginator.limit:
            mongo_result = self.get_from_mongo(query, request)
            if len(mongo_result.data['results']) > len(ydb_rows):
                return mongo_result
        return ydb_result

    def get_from_mongo(self, query, request):
        """Run the query against ``database.order_logs`` and paginate the cursor."""
        cursor = database.order_logs.find(**query.build_query())
        data = self.paginator.paginate_queryset(cursor, request)
        for d in data:
            # ObjectId is not JSON-serializable.
            d['_id'] = str(d['_id'])

        return self.paginator.get_paginated_response(data)

    def get_from_ydb(self, query, request):
        """Run the query against YDB and build a DRF limit/offset style response.

        The first result set of the script holds the total count; the second holds the page rows.
        """
        paginator = self.ydb_paginator
        # Populate the attributes DRF's get_next_link/get_previous_link rely on, since
        # paginate_queryset is bypassed here.
        paginator.request = request
        paginator.limit = paginator.get_limit(request)
        paginator.offset = paginator.get_offset(request)
        query_string, query_params = query.build_ydb_query(
            limit=paginator.limit,
            offset=paginator.offset
        )
        # NOTE(review): the session's lifecycle is owned by ydb_provider; confirm it does
        # not need to be released explicitly.
        session = ydb_provider.get_session()
        prepared_query = session.prepare(query_string)
        result_sets = session.transaction().execute(
            prepared_query, query_params, commit_tx=True,
        )
        paginator.count = result_sets[0].rows[0]['count']
        data = list(result_sets[1].rows)
        for row in data:
            context = row.get('context')
            exception = row.get('exception')
            # context/exception are stored as JSON text. A stored "null" context must
            # still become a dict so the log_source marker can be attached.
            row['context'] = (json.loads(context) if context else None) or {}
            row['exception'] = json.loads(exception) if exception else None
            row['context']['log_source'] = 'ydb'
        return Response(OrderedDict([
            ('count', paginator.count),
            ('next', paginator.get_next_link()),
            ('previous', paginator.get_previous_link()),
            ('results', data)
        ]))
