# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import enum

from sqlalchemy import inspect, event, insert
from sqlalchemy import Column, func, Index, PrimaryKeyConstraint, UniqueConstraint, \
    Text, DateTime, ForeignKey, Enum, Integer, Float, Boolean
from sqlalchemy.orm import relationship
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, JSONB

from travel.rasp.bus.db.models.shared import Base


class PointType(enum.Enum):
    """Kind of point referenced by a point key.

    Members are declared as ``(value, point_key_prefix)`` pairs: the first item
    becomes the enum ``value``, the second is exposed as ``point_key_prefix``
    (``None`` for ``INVALID``, which has no prefix).
    """

    CITY = (0, 'c')
    STATION = (1, 's')
    INVALID = (2, None)

    def __new__(cls, value, point_key_prefix):
        member = object.__new__(cls)
        member._value_, member.point_key_prefix = value, point_key_prefix
        return member


def parse_point_key(point_key):
    """Split a point key such as ``'c213'`` into ``(PointType, numeric_id)``.

    Only the CITY (``'c'``) and STATION (``'s'``) prefixes are recognized.
    An empty/``None`` key, an unknown prefix, or a remainder that ``int()``
    cannot parse all yield ``(PointType.INVALID, None)``.
    """
    if not point_key:
        return PointType.INVALID, None

    for candidate in (PointType.CITY, PointType.STATION):
        if not point_key.startswith(candidate.point_key_prefix):
            continue
        # Prefixes are a single character, so the id starts at index 1.
        try:
            numeric_id = int(point_key[1:])
        except ValueError:
            return PointType.INVALID, None
        return candidate, numeric_id

    return PointType.INVALID, None


class PointMatching(Base):
    """A supplier's point and the ``point_key`` it is matched to.

    Rows are unique per ``(supplier_id, supplier_point_id)`` and indexed by
    ``point_key``. ``point_key`` presumably uses the ``'c<id>'`` / ``'s<id>'``
    form understood by ``parse_point_key`` -- confirm against writers.
    Inserts and updates are audited into ``PointMatchingLog`` by the mapper
    event listeners in this module.
    """
    __tablename__ = 'point_matching'

    id = Column(Integer, primary_key=True)

    supplier_id = Column(ForeignKey('buses.supplier.id'), nullable=False)
    supplier = relationship('Supplier', foreign_keys=[supplier_id], backref='point_matchings')

    # Point identifier on the supplier side; unique per supplier (point_matching_ukey).
    supplier_point_id = Column(Text, nullable=False)

    # Optional parent row in this same table. ``remote_side=[id]`` makes
    # ``parent`` a scalar many-to-one relationship: without it SQLAlchemy treats a
    # self-referential relationship as one-to-many, so ``parent`` would return the
    # list of rows whose parent_id points at *this* row (its children).
    parent_id = Column(ForeignKey('point_matching.point_matching.id'))
    parent = relationship('PointMatching', foreign_keys=[parent_id], remote_side=[id])

    type = Column(Enum(PointType))
    title = Column(Text, nullable=False)
    description = Column(Text)
    latitude = Column(Float)
    longitude = Column(Float)
    country = Column(Text)
    point_key = Column(Text)

    city_id = Column(Text)
    country_code = Column(Text)
    city_title = Column(Text)
    region = Column(Text)
    region_code = Column(Text)
    district = Column(Text)
    extra_info = Column(Text)
    timezone_info = Column(Text)

    disabled = Column(Boolean, default=False)
    outdated = Column(Boolean, default=False)
    # Nullable flag defaulting to None, i.e. "not determined" until set.
    in_segments = Column(Boolean, default=None)

    # Both filled by the database on INSERT; updated_at is also set to now()
    # in every UPDATE the ORM emits for the row.
    created_at = Column(DateTime, nullable=False, server_default=func.now())
    updated_at = Column(DateTime, nullable=False, server_default=func.now(), onupdate=func.now())

    # Admin user login; also copied into PointMatchingLog.author by the audit hook.
    updated_by_login = Column(ForeignKey('buses.admin_users.login'), nullable=False)
    updated_by = relationship('AdminUser', foreign_keys=[updated_by_login])

    __table_args__ = (
        UniqueConstraint('supplier_id', 'supplier_point_id', name='point_matching_ukey'),
        Index('point_matching_point_key_idx', point_key),
        {
            'schema': 'point_matching',
        },
    )


class Matching(Base):
    """
    Model Matching is deprecated. Use PointMatching instead.

    Legacy table (schema ``matching``) mapping ``(supplier, supplier_id)``
    to ``rasp_id``.
    """

    __tablename__ = 'matching'

    supplier = Column(Text, primary_key=True)
    supplier_id = Column(Text, primary_key=True)
    rasp_id = Column(Text)

    __table_args__ = (
        # Explicit constraint so the composite primary key has a fixed name.
        PrimaryKeyConstraint(supplier, supplier_id, name='matching_pkey'),
        # Index on rasp_id ordered DESC NULLS LAST.
        Index('matching_rasp_id_idx', rasp_id.desc().nullslast()),
        {
            'schema': 'matching',
        },
    )


class Endpoint(Base):
    """
    Model Endpoint is deprecated. Use PointMatching instead.

    Legacy table (schema ``matching``) of supplier endpoints keyed by
    ``(supplier, supplier_id)``, with type, title, description, coordinates
    and country.
    """

    __tablename__ = 'endpoints'

    supplier = Column(Text, nullable=False, primary_key=True)
    supplier_id = Column(Text, primary_key=True)
    # Free text here, unlike the PointType enum column used by PointMatching.
    type = Column(Text, nullable=False)
    parent_id = Column(Text)
    title = Column(Text, nullable=False)
    description = Column(Text)
    latitude = Column(DOUBLE_PRECISION)
    longitude = Column(DOUBLE_PRECISION)
    country = Column(Text)

    __table_args__ = (
        # Explicit constraint so the composite primary key has a fixed name.
        PrimaryKeyConstraint(supplier, supplier_id, name='endpoints_pkey'),
        # Expression index on lower(description); presumably serves
        # case-insensitive description lookups -- confirm with the queries.
        Index('description_lower_supplier_supplier_id', func.lower(description), supplier, supplier_id),
        Index('endpoints_supplier_id_type_idx', type, supplier_id),
        {
            'schema': 'matching',
        },
    )


class MatchingChange(Base):
    """Row of ``point_matching.matching_changes``: a supplier and a ``point_key``.

    ``updated_at`` defaults to the database's now() on insert.
    NOTE(review): presumably records point keys whose matching changed so that
    consumers can react to them -- confirm against the code that writes/reads it.
    """
    __tablename__ = 'matching_changes'

    id = Column(Integer, primary_key=True)
    supplier_id = Column(ForeignKey('buses.supplier.id'), nullable=False)
    supplier = relationship('Supplier', foreign_keys=[supplier_id])

    point_key = Column(Text, nullable=False)
    updated_at = Column(DateTime, nullable=False, server_default=func.now())

    __table_args__ = (
        # Gives the primary key constraint an explicit name.
        PrimaryKeyConstraint(id, name='matching_changes_pkey'),
        {
            'schema': 'point_matching',
        },
    )


class PointMatchingLog(Base):
    """Audit record of an insert into, or update of, a ``PointMatching`` row.

    Rows are written by ``handle_changes`` in this module, which is called from
    the ``PointMatching`` mapper event listeners.
    """
    __tablename__ = 'point_matching_logs'

    id = Column(Integer, primary_key=True)
    point_matching_id = Column(ForeignKey('point_matching.point_matching.id'), nullable=False)
    # The backref adds a ``point_matching_logs`` collection to PointMatching.
    point_matching = relationship('PointMatching', foreign_keys=[point_matching_id], backref='point_matching_logs')
    # 'insert' or 'update', as passed by the event listeners.
    action = Column(Text)
    # {'new': {...}} for inserts, {'old': {...}, 'new': {...}} for updates.
    changes = Column(JSONB, nullable=False)
    # Timezone-aware, unlike the naive DateTime columns of PointMatching.
    log_time = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
    # Copied from PointMatching.updated_by_login at the time of the change.
    author = Column(Text, nullable=False)

    __table_args__ = (
        Index('point_matching_logs_pm_id_idx', point_matching_id),
        {
            'schema': 'point_matching',
        },
    )


# PointMatching attributes that handle_changes leaves out of PointMatchingLog.changes:
# the bookkeeping timestamp plus the relationship attributes.
LOG_FIELDS_TO_SKIP = (
    'updated_at',
    'updated_by',
    'supplier',
    'parent',
)


@event.listens_for(PointMatching, 'before_update')
def before_update(mapper, connection, target):
    """Mapper ``before_update`` hook: audit the pending changes of ``target``."""
    handle_changes(connection, target, action='update')


@event.listens_for(PointMatching, 'after_insert')
def after_insert(mapper, connection, target):
    """Mapper ``after_insert`` hook: audit the freshly inserted ``target``."""
    handle_changes(connection, target, action='insert')


def _serialize_log_value(value):
    """Return a JSON-friendly form of a column value for the change log.

    Enum members (such as ``PointType``) are stored by name; any other value is
    returned unchanged.
    """
    if isinstance(value, enum.Enum):
        return value.name
    return value


def handle_changes(connection, target, action):
    """Insert a ``PointMatchingLog`` row describing the pending changes of ``target``.

    Only mapped column attributes are inspected. Relationships, including backrefs
    such as ``point_matching_logs``, are ignored: inspecting them could trigger
    lazy loads mid-flush and would put ORM objects into the JSON payload. Keys
    listed in ``LOG_FIELDS_TO_SKIP`` are ignored as well.

    The logged ``changes`` document is ``{'old': {...}, 'new': {...}}``. For
    inserts the ``'old'`` part is dropped and ``'new'`` also carries ``id``.
    Nothing is written when no inspected attribute has changes.

    :param connection: connection of the ongoing flush, used to insert the log row
    :param target: the ``PointMatching`` instance being flushed
    :param action: ``'insert'`` or ``'update'``; stored in ``PointMatchingLog.action``
    """
    state = inspect(target)

    old_values = {}
    new_values = {}
    for column_attr in state.mapper.column_attrs:
        key = column_attr.key
        if key in LOG_FIELDS_TO_SKIP:
            continue

        # load_history() emits a load first if the attribute is unloaded.
        hist = state.attrs[key].load_history()
        if not hist.has_changes():
            continue

        old_values[key] = _serialize_log_value(hist.deleted[0] if hist.deleted else None)
        new_values[key] = _serialize_log_value(hist.added[0] if hist.added else None)

    if not old_values and not new_values:
        return

    if action == 'insert':
        # Make sure the generated primary key is part of the logged new values.
        new_values['id'] = target.id
        changes = {'new': new_values}
    else:
        changes = {'old': old_values, 'new': new_values}

    # Use .values() rather than insert()'s positional ``values`` argument, which
    # is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    connection.execute(insert(PointMatchingLog).values(
        point_matching_id=target.id,
        action=action,
        changes=changes,
        author=target.updated_by_login,
    ))
