# -*- coding: utf-8 -*-
from collections import defaultdict
from datetime import (
    datetime,
    timedelta,
)
import logging
import json

from passport.backend.core.db.utils import insert_with_on_duplicate_key_update_if_equals
from passport.backend.utils.string import smart_str
from passport.backend.utils.time import (
    datetime_to_string,
    get_unixtime,
    unixtime_to_datetime,
)
from passport.infra.daemons.yasmsapi.db.config import (
    DB_QUEUE,
    DB_VALIDATOR,
)
from passport.infra.daemons.yasmsapi.db.connection import (
    DBError,
    get_db_connection,
)
from passport.infra.daemons.yasmsapi.db.schemas import (
    blocked_phones_table,
    BLOCKING_TYPE_PERMANENT,
    daemon_heartbeat_table,
    queue_metadata,
    sms_gates_table,
    sms_metadata,
    sms_queue_table,
    sms_routes_table,
    STATUS_NOT_SENT,
    STATUS_READY,
)
from sqlalchemy import (
    and_,
    insert,
    join,
    select,
    update,
)


log = logging.getLogger('yasms.db.queries')

# Process-wide cache of joined smsrt/smsgates rows used by get_possible_routes().
# ROUTES_CACHE_UPDATED is the get_unixtime() value of the last successful load;
# its initial 0 makes the first call load the cache. The cache is reloaded once
# it is older than ROUTES_CACHE_LIFETIME (compared against get_unixtime()
# deltas, so presumably seconds).
ROUTES_CACHE = []
ROUTES_CACHE_UPDATED = 0
ROUTES_CACHE_LIFETIME = 600

# Process-wide cache of (gateid, aliase, fromname) rows used by load_gates().
# Same refresh scheme as the routes cache above, with a shorter lifetime; an
# empty GATES_CACHE is also treated as expired.
GATES_CACHE = []
GATES_CACHE_UPDATED = 0
GATES_CACHE_LIFETIME = 60

# Maps each SQLAlchemy MetaData object to the name of the database its tables
# live in; every query below picks its target DB via ``table.metadata``.
METADATA_TO_DB_NAME = {
    queue_metadata: DB_QUEUE,
    sms_metadata: DB_VALIDATOR,
}


def check_phone_blocked(number):
    """
    Tell whether a phone number is under an active permanent block.

    A number counts as blocked when the blocked-phones table has at least one
    row for it with the permanent blocking type and a ``blocktill`` that is
    still in the future.

    :param number: phone number; passed through smart_str before querying.
    :return: True if such a row exists. If the query raises DBError the error
        is logged and False is returned, so the check fails open.
    :rtype: bool
    """
    table = blocked_phones_table
    is_active_permanent_block = and_(
        table.c.phone == smart_str(number),
        table.c.blocktype == BLOCKING_TYPE_PERMANENT,
        table.c.blocktill > datetime.now(),
    )
    statement = select([table]).where(is_active_permanent_block)
    target_db = METADATA_TO_DB_NAME[table.metadata]
    try:
        rows = get_db_connection().execute(target_db, statement).fetchall()
    except DBError as e:
        log.error('Couldn\'t SELECT FROM blockedphones: %s', e)
        return False
    return bool(rows)


def get_all_gates():
    """
    Load every row of the SMS gates table.

    :return: one plain dict per gate row; an empty list if the query raised
        DBError (the error is logged).
    :rtype: list
    """
    gates = sms_gates_table
    statement = select([gates])
    target_db = METADATA_TO_DB_NAME[gates.metadata]
    try:
        rows = get_db_connection().execute(target_db, statement).fetchall()
    except DBError as e:
        log.error('Cannot SELECT FROM smsgates: %s', e)
        return []
    return list(map(dict, rows))


def get_gate_by_id(gate_id):
    """
    Look up a single gate by its ``gateid``.

    :param gate_id: value matched against the gateid column.
    :return: the first matching row as a dict, or ``{}`` when nothing matched
        or the query raised DBError (the error is logged).
    :rtype: dict
    """
    gates = sms_gates_table
    statement = select([gates]).where(gates.c.gateid == gate_id)
    target_db = METADATA_TO_DB_NAME[gates.metadata]
    try:
        row = get_db_connection().execute(target_db, statement).fetchone()
    except DBError as e:
        log.error('Cannot get gate by id: %s', e)
        return {}
    if not row:
        return {}
    return dict(row)


def get_first_gate_by_aliase_and_service(aliase, service):
    """
    Find a gate by its alias and sender name.

    :param aliase: value matched against the ``aliase`` column.
    :param service: value matched against the ``fromname`` column.
    :return: the first row returned by the database, as a dict. No ORDER BY
        is applied, so when several rows match, which one comes first is up
        to the database. Returns ``{}`` when nothing matched or the query
        raised DBError (the error is logged).
    :rtype: dict
    """
    gates = sms_gates_table
    matches_alias_and_sender = and_(
        gates.c.aliase == aliase,
        gates.c.fromname == service,
    )
    statement = select([gates]).where(matches_alias_and_sender)
    target_db = METADATA_TO_DB_NAME[gates.metadata]
    try:
        rows = get_db_connection().execute(target_db, statement).fetchall()
    except DBError as e:
        log.error('Cannot get gate by aliase and service: %s', e)
        return {}
    return dict(rows[0]) if rows else {}


def get_possible_routes(number, rule='default'):
    """
    Pick the routes that should be used to deliver an SMS to ``number``.

    The joined smsrt/smsgates rows are kept in the module-level ROUTES_CACHE,
    which is reloaded once it is older than ROUTES_CACHE_LIFETIME.

    Selection works as follows:
    - Routes whose ``mode`` equals ``rule`` are considered. If there are none,
      routes with mode ``'default'`` are used instead.
    - Among those, the longest ``prefix`` that ``number`` starts with wins.
    - Every candidate route with that prefix is returned.

    :param number: destination phone number (``startswith`` is called on it).
    :param rule: route mode to select; ``'default'`` if omitted.
    :return: list of route dicts with keys mode, prefix, gateid2, gateid3,
        gateid, weight and aliase. The list is empty when reloading the cache
        raised DBError or when no prefix matched; both cases are logged.
    :rtype: list
    """
    global ROUTES_CACHE
    global ROUTES_CACHE_UPDATED

    current_time = get_unixtime()
    cache_expired = current_time - ROUTES_CACHE_UPDATED > ROUTES_CACHE_LIFETIME
    if cache_expired:
        routes_t = sms_routes_table
        gates_t = sms_gates_table

        routes_with_gates = join(
            routes_t,
            gates_t,
            routes_t.c.gateid == gates_t.c.gateid,
        )
        columns = [
            routes_t.c.mode,
            routes_t.c.destination.label('prefix'),
            routes_t.c.gateid2,
            routes_t.c.gateid3,
            routes_t.c.gateid,
            routes_t.c.weight,
            gates_t.c.aliase,
        ]
        statement = select(columns).select_from(routes_with_gates)

        target_db = METADATA_TO_DB_NAME[gates_t.metadata]
        try:
            rows = get_db_connection().execute(target_db, statement).fetchall()
            # Turn ResultProxy rows into plain Python dicts before caching them.
            ROUTES_CACHE = [dict(row) for row in rows]
            ROUTES_CACHE_UPDATED = current_time

            log.debug(
                'routes cache updated at %s',
                datetime_to_string(unixtime_to_datetime(current_time)),
            )
        except DBError as e:
            log.error('Cannot SELECT FROM smsrt: %s', e)
            return []

    candidates = [route for route in ROUTES_CACHE if route['mode'] == rule]
    if not candidates:
        candidates = [route for route in ROUTES_CACHE if route['mode'] == 'default']

    # Group candidates by prefix, keeping their original relative order.
    by_prefix = defaultdict(list)
    for route in candidates:
        by_prefix[route['prefix']].append(route)

    # Try prefixes from the most specific (longest) to the most general.
    # Two distinct prefixes of equal length cannot both prefix the same
    # number, so the first hit is unambiguous.
    for prefix in sorted(by_prefix, key=len, reverse=True):
        if number.startswith(prefix):
            return by_prefix[prefix]

    log.error('Cannot get route for number: %s', number)
    return []


def enqueue_sms(number, gate_id, text, sender, metadata=None):
    """
    Insert a new message into the SMS queue table.

    Both ``create_time`` and ``touch_time`` are set to the current local time.
    A truthy ``metadata`` is stored as compact JSON; otherwise an empty string
    is stored.

    :param number: destination phone number.
    :param gate_id: id of the gate chosen for delivery.
    :param text: message body; passed through smart_str before storing.
    :param sender: sender identifier, stored as-is.
    :param metadata: optional JSON-serializable object.
    :return: the first element of the inserted row's primary key.
    :rtype: int
    :raises DBError: re-raised after logging if the INSERT fails.
    """
    queue = sms_queue_table
    created_at = datetime.now()
    encoded_text = smart_str(text)
    if metadata:
        serialized_metadata = json.dumps(metadata, separators=(',', ':'))
    else:
        serialized_metadata = ''

    row = dict(
        phone=number,
        gateid=gate_id,
        text=encoded_text,
        sender=sender,
        create_time=created_at,
        touch_time=created_at,
        metadata=serialized_metadata,
    )
    statement = insert(queue).values(row)
    target_db = METADATA_TO_DB_NAME[queue.metadata]
    try:
        inserted = get_db_connection().execute(target_db, statement)
        return inserted.inserted_primary_key[0]
    except DBError as e:
        log.error('Cannot INSERT INTO smsqueue_anonym: %s', e)
        raise


def get_all_routes():
    """
    Load every row of the SMS routes (smsrt) table.

    :return: one plain dict per route row; an empty list if the query raised
        DBError (the error is logged).
    :rtype: list
    """
    routes = sms_routes_table
    statement = select([routes])
    target_db = METADATA_TO_DB_NAME[routes.metadata]
    try:
        rows = get_db_connection().execute(target_db, statement).fetchall()
    except DBError as e:
        log.error('Cannot SELECT FROM smsrt: %s', e)
        return []
    return list(map(dict, rows))


def load_gates():
    """
    Return the cached list of gates with their id, alias and sender name.

    The list lives in the module-level GATES_CACHE. It is reloaded from the
    gates table when it is empty or older than GATES_CACHE_LIFETIME.

    :return: list of dicts with keys gateid, aliase and fromname. An empty
        list is returned if a reload raised DBError (the error is logged).
    :rtype: list
    """
    global GATES_CACHE
    global GATES_CACHE_UPDATED

    now = get_unixtime()

    # An empty cache always counts as expired, so an empty gates table means
    # a DB round-trip on every call.
    if not GATES_CACHE or now - GATES_CACHE_UPDATED > GATES_CACHE_LIFETIME:
        db_table = sms_gates_table
        query = select(
            [
                db_table.c.gateid,
                db_table.c.aliase,
                db_table.c.fromname,
            ]
        )
        db_name = METADATA_TO_DB_NAME[db_table.metadata]
        try:
            result = get_db_connection().execute(db_name, query)
            gates = result.fetchall()
            GATES_CACHE = [dict(gate) for gate in gates]
            GATES_CACHE_UPDATED = now
            # Lazy %-args (like the rest of this module): the message is only
            # formatted when DEBUG logging is actually enabled.
            log.debug('Gates cache updated at %s', now)
        except DBError as e:
            # NOTE(review): a previously loaded (stale) GATES_CACHE is not
            # served here; the caller gets [] for this call.
            log.error('Loading gates failed; reason=%s', e)
            return []

    return GATES_CACHE


def search_sms_query(limit=1):
    """
    Fetch queued messages that are due to be sent.

    Selects rows in STATUS_READY whose ``touch_time`` is not in the future.
    Rows are ordered by ``touch_time`` (oldest first), and at most ``limit``
    are returned.

    :param limit: maximum number of rows to return.
    :return: list of dicts with keys smsid, phone, text, gateid, sender and
        errors.
    :raises DBError: re-raised after logging.
    """
    queue = sms_queue_table
    wanted_columns = [
        queue.c.smsid,
        queue.c.phone,
        queue.c.text,
        queue.c.gateid,
        queue.c.sender,
        queue.c.errors,
    ]
    is_due = and_(
        queue.c.status == STATUS_READY,
        queue.c.touch_time <= datetime.now(),
    )
    statement = select(wanted_columns).where(is_due)
    statement = statement.order_by(queue.c.touch_time).limit(limit)
    target_db = METADATA_TO_DB_NAME[queue.metadata]
    try:
        rows = get_db_connection().execute(target_db, statement).fetchall()
    except DBError as e:
        log.error('Loading messages failed; reason=%s', e)
        raise
    return [dict(row) for row in rows]


def spoil_sms_query(sms_id, message):
    """
    Mark a queued message as not sent.

    Sets the status to STATUS_NOT_SENT and increments the ``errors`` counter
    in SQL. Also sets ``touch_time`` to the current time and stores
    ``message`` in ``dlrmessage``.

    :param sms_id: smsid of the row to update.
    :param message: text stored in the dlrmessage column.
    :raises DBError: re-raised after logging.
    """
    db_table = sms_queue_table
    now = datetime.now()

    values = {
        'status': STATUS_NOT_SENT,
        # Incremented server-side, so concurrent updates don't lose counts.
        'errors': db_table.c.errors + 1,
        'touch_time': now,
        'dlrmessage': message,
    }
    query = update(db_table).where(db_table.c.smsid == sms_id).values(values)
    db_name = METADATA_TO_DB_NAME[db_table.metadata]
    try:
        get_db_connection().execute(db_name, query)
    except DBError as e:
        # Lazy %-args, consistent with suspend_sms_query / pull_sms_query.
        log.error('Spoiling sms failed; smsid=%s; reason=%s', sms_id, e)
        raise


def suspend_sms_query(sms_id, interval):
    """
    Postpone a queued message and count the attempt as an error.

    The row stays in STATUS_READY and its ``errors`` counter is incremented.
    ``touch_time`` is pushed ``interval`` seconds into the future, so
    search_sms_query() (which requires ``touch_time <= now``) skips the
    message until then.

    :param sms_id: smsid of the row to update.
    :param interval: delay in seconds.
    :raises DBError: re-raised after logging.
    """
    queue = sms_queue_table
    resume_at = datetime.now() + timedelta(seconds=interval)
    new_values = {
        'status': STATUS_READY,
        'errors': queue.c.errors + 1,
        'touch_time': resume_at,
    }
    statement = (
        update(queue)
        .where(queue.c.smsid == sms_id)
        .values(new_values)
    )
    target_db = METADATA_TO_DB_NAME[queue.metadata]
    try:
        get_db_connection().execute(target_db, statement)
    except DBError as e:
        log.error('Suspending sms failed; smsid=%s; reason=%s', sms_id, e)
        raise


def pull_sms_query(sms_id, status):
    """
    Set the status of a queued message and set its touch_time to now.

    :param sms_id: smsid of the row to update.
    :param status: new value for the status column.
    :raises DBError: re-raised after logging.
    """
    queue = sms_queue_table
    touched_at = datetime.now()
    statement = (
        update(queue)
        .where(queue.c.smsid == sms_id)
        .values({'status': status, 'touch_time': touched_at})
    )
    target_db = METADATA_TO_DB_NAME[queue.metadata]
    try:
        get_db_connection().execute(target_db, statement)
    except DBError as e:
        log.error('Pulling sms failed; smsid=%s; reason=%s', sms_id, e)
        raise


def get_heartbeat(host_name):
    """
    Return the last recorded heartbeat time for a host.

    :param host_name: value matched against the ``hostname`` column.
    :return: the row's ``beat_time``, or None when this host has no
        heartbeat row yet.
    :raises DBError: re-raised after logging.
    """
    db_table = daemon_heartbeat_table
    query = select([db_table]).where(db_table.c.hostname == host_name)
    db_name = METADATA_TO_DB_NAME[db_table.metadata]
    try:
        result = get_db_connection().execute(db_name, query)
        last_heartbeat = result.fetchone()
    except DBError as e:
        log.error('Checking heartbeat failed for host %s; reason=%s', host_name, e)
        raise
    # fetchone() yields None when no row matched. dict(None) would raise
    # TypeError, so report "no heartbeat yet" explicitly instead.
    if last_heartbeat is None:
        return None
    return dict(last_heartbeat).get('beat_time')


def update_heartbeat(host_name):
    """
    Record a heartbeat for ``host_name`` with the current local time.

    The statement comes from insert_with_on_duplicate_key_update_if_equals(),
    keyed on ``hostname`` and writing ``beat_time``/``hostname``. Presumably
    this inserts the row or updates an existing one; see that helper for the
    exact upsert semantics.

    :param host_name: hostname to record the heartbeat for.
    :raises DBError: re-raised after logging.
    """
    db_table = daemon_heartbeat_table
    now = datetime.now()

    values = {
        'hostname': host_name,
        'beat_time': now,
    }
    query = insert_with_on_duplicate_key_update_if_equals(
        table=db_table,
        args=['beat_time', 'hostname'],
        key_name='hostname',
        key_value=host_name,
        else_null=False,
    ).values(values)
    db_name = METADATA_TO_DB_NAME[db_table.metadata]
    try:
        get_db_connection().execute(db_name, query)
    except DBError as e:
        # Lazy %-args, consistent with the other logging calls in this module.
        log.error(
            'Making heartbeat failed; hostname=%s; time=%s; reason=%s',
            host_name,
            now.strftime('%Y-%m-%d %H:%M:%S'),
            e,
        )
        raise
