import re
from datetime import datetime, date

from typing import List

from sqlalchemy import (
    Column, Boolean, Integer, BigInteger, SmallInteger, String, Date, DateTime, ForeignKey,
    UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship
from sqlalchemy.sql import func

from travel.avia.flight_extras.settings import REST_DATETIME_FORMAT, REST_DATE_FORMAT

# Matches any character that is not a digit or a letter a-z (either case).
# Used with .sub('', ...) to reduce airport codes to alphanumerics before
# they are interpolated into the hand-built upsert SQL in this module.
re_iata = re.compile(r'[^0-9a-z]', re.IGNORECASE)
# Matches any character that is not a digit or ':'; used the same way to
# sanitize departure/arrival time strings (presumably "HH:MM" -- confirm).
re_time = re.compile(r'[^0-9:]')

# Declarative base shared by every model defined in this module.
Base = declarative_base()


class ModelAsDictMixin(object):
    """Mixin that adds ``as_dict()`` to declarative models.

    Date and datetime column values are rendered as strings using the REST
    formats from settings; every other column value is returned unchanged.
    """

    def as_dict(self):
        # type: () -> dict
        """Return ``{column_name: value}`` for every column of the model's table.

        For FlightPassengerExperience instances the related ``flight`` and
        ``source`` objects are embedded via their own ``as_dict()``, or set to
        None when the relation is empty.
        """
        def serialize(value):
            # datetime is a subclass of date, so it has to be checked first.
            if isinstance(value, datetime):
                return value.strftime(REST_DATETIME_FORMAT)
            if isinstance(value, date):
                return value.strftime(REST_DATE_FORMAT)
            return value

        result = {
            column.name: serialize(getattr(self, column.name))
            for column in self.__table__.columns
        }

        # NOTE(review): the mixin hard-codes knowledge of one subclass; a
        # per-model hook would be cleaner if more models need nested output.
        if isinstance(self, FlightPassengerExperience):
            for relation_name in ('flight', 'source'):
                related = getattr(self, relation_name)
                result[relation_name] = related.as_dict() if related else None

        return result


class Source(ModelAsDictMixin, Base):
    """A named data source; FlightPassengerExperience rows reference it by id."""

    # NOTE(review): not used in this module; presumably a prefix for source
    # names of FlightGlobal prediction imports -- confirm at the call sites.
    FLIGHT_STATS_PREFIX = 'FlightGlobal_flight_predictions_'

    __tablename__ = 'source'

    id = Column(BigInteger, primary_key=True)
    name = Column(String, nullable=False, unique=True)
    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())

    @staticmethod
    def get_or_create(name, session):
        # type: (str, Session) -> Source
        """Return the Source called *name*, inserting it first if it is missing.

        The insert runs inside a SAVEPOINT. If it fails with a unique-constraint
        violation (e.g. a concurrent writer inserted the same name first), only
        the savepoint is rolled back. The caller's enclosing transaction stays
        usable, and the existing row is re-queried and returned.

        :param name: unique source name.
        :param session: session whose current transaction is used; it is not
            committed by this method.
        :return: the Source instance. NOTE(review): the re-query after a
            conflict may return None if the conflicting row is not visible to
            this transaction (e.g. under REPEATABLE READ) -- confirm isolation.
        """
        query = session.query(Source).filter_by(name=name)
        existing = query.first()
        if existing is not None:
            return existing
        try:
            # Used as a context manager, begin_nested() releases the savepoint
            # on success and rolls it back on *any* exception. No savepoint is
            # left dangling, and unlike Session.commit() it never commits the
            # caller's outer transaction (SQLAlchemy 1.4 "future" / 2.0).
            # The flush, and therefore any IntegrityError, happens on exit.
            with session.begin_nested():
                source = Source(name=name)
                session.add(source)
            return source
        except IntegrityError:
            return query.first()


class Flight(ModelAsDictMixin, Base):
    """A flight identified by its airline IATA code and flight number."""
    __tablename__ = 'flight'
    __table_args__ = (
        # At most one row per (company_iata, number).
        UniqueConstraint('company_iata', 'number', name='ui_flight'),
    )

    id = Column(BigInteger, primary_key=True)

    company_iata = Column(String, nullable=False)
    number = Column(String, nullable=False)

    passenger_experiences = relationship('FlightPassengerExperience', back_populates='flight')
    infos = relationship('FlightInfo', back_populates='flight')

    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    @property
    def key(self):
        """Identifier of the form '<company_iata> <number>'."""
        return '{iata} {number}'.format(iata=self.company_iata, number=self.number)

    def __repr__(self):
        # Same text as embedding the two columns directly: '<Flight: XX 123>'.
        return '<Flight: {}>'.format(self.key)


class FlightPassengerExperience(ModelAsDictMixin, Base):
    """Passenger-experience data for one flight on one departure day.

    Holds schedule details (airports, times, aircraft) plus per-cabin seat
    counts, wifi/power/ife flags and seat pitch. The pair
    (flight_id, departure_day) is unique; the bulk upsert built by
    get_multi_insert_sql() uses it as its conflict target.
    """
    __tablename__ = 'flight_passenger_experience'
    __table_args__ = (
        UniqueConstraint('flight_id', 'departure_day', name='ui_flight_departure_day'),
    )

    id = Column(BigInteger, primary_key=True)

    source_id = Column(BigInteger, ForeignKey(Source.id), nullable=False)
    source = relationship('Source')

    flight_id = Column(BigInteger, ForeignKey(Flight.id), nullable=False)
    flight = relationship('Flight', back_populates='passenger_experiences')
    departure_day = Column(Date, nullable=False)

    airport_from = Column(String)
    airport_to = Column(String)
    departure_time = Column(String)
    arrival_time = Column(String)
    extra_day = Column(SmallInteger)
    aircraft = Column(String)
    seats_total = Column(SmallInteger)
    seats_first_class = Column(SmallInteger)
    seats_business_class = Column(SmallInteger)
    seats_comfort = Column(SmallInteger)
    seats_economy = Column(SmallInteger)
    wifi_first_class = Column(Boolean)
    wifi_business = Column(Boolean)
    wifi_comfort = Column(Boolean)
    wifi_economy = Column(Boolean)
    power_first_class = Column(Boolean)
    power_business = Column(Boolean)
    power_comfort = Column(Boolean)
    power_economy = Column(Boolean)
    ife_first_class = Column(Boolean)
    ife_business = Column(Boolean)
    ife_comfort = Column(Boolean)
    ife_economy = Column(Boolean)
    seat_pitch_first_class = Column(SmallInteger)
    seat_pitch_business = Column(SmallInteger)
    seat_pitch_comfort = Column(SmallInteger)
    seat_pitch_economy = Column(SmallInteger)

    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    @property
    def key(self):
        # Formats the related Flight via str(), which falls back to its
        # __repr__ ('<Flight: XX 123>') -- not Flight.key.
        return '{} {}'.format(self.flight, self.departure_day)

    @classmethod
    def get_insert_fields(cls):
        """Return the column names written by the raw multi-row upsert.

        The order must match the literals produced by
        get_insert_values_string().
        """
        return (
            'flight_id',
            'source_id',
            'departure_day',
            'airport_from',
            'airport_to',
            'departure_time',
            'arrival_time',
            'extra_day',
            'aircraft',
            'seats_total',
            'seats_first_class',
            'seats_business_class',
            'seats_comfort',
            'seats_economy',
            'wifi_first_class',
            'wifi_business',
            'wifi_comfort',
            'wifi_economy',
            'power_first_class',
            'power_business',
            'power_comfort',
            'power_economy',
            'ife_first_class',
            'ife_business',
            'ife_comfort',
            'ife_economy',
            'seat_pitch_first_class',
            'seat_pitch_business',
            'seat_pitch_comfort',
            'seat_pitch_economy',
        )

    @staticmethod
    def _sql_text(value, strip_pattern=None):
        """Render *value* as a quoted SQL string literal, or ``null`` if falsy.

        If *strip_pattern* is given, all of its matches are removed first.
        Embedded single quotes are doubled so the value cannot terminate the
        literal early. This assumes PostgreSQL's default
        standard_conforming_strings=on, where backslash is not an escape.
        """
        if not value:
            return 'null'
        if strip_pattern is not None:
            value = strip_pattern.sub('', value)
        return "'{}'".format(value.replace("'", "''"))

    @staticmethod
    def _sql_scalar(value, null_if_falsy=False):
        """Render a numeric/boolean value with str(), or ``null``.

        ``null`` is produced for None, and also for any falsy value (0, False)
        when *null_if_falsy* is set. Booleans render as True/False, which
        PostgreSQL accepts as boolean literals.
        """
        if value is None or (null_if_falsy and not value):
            return 'null'
        return str(value)

    def get_insert_values_string(self):
        """Render this row as a parenthesised SQL ``VALUES`` tuple.

        Literals appear in get_insert_fields() order. Airport codes are
        reduced to letters/digits and times to digits/colons. Every quoted
        string, including ``aircraft``, has embedded single quotes doubled,
        so external data cannot break out of the literal.
        """
        text = self._sql_text
        scalar = self._sql_scalar
        literals = [
            str(self.flight_id),
            str(self.source_id),
            "'{}'".format(self.departure_day.strftime('%Y-%m-%d')),
            text(self.airport_from, re_iata),
            text(self.airport_to, re_iata),
            text(self.departure_time, re_time),
            text(self.arrival_time, re_time),
            # NOTE(review): extra_day and seats_* use a truthiness test, so 0 is
            # written as null, unlike the amenity and seat-pitch columns below,
            # which keep 0/False. Preserved as-is; confirm this is intended.
            scalar(self.extra_day, null_if_falsy=True),
            text(self.aircraft),
            scalar(self.seats_total, null_if_falsy=True),
            scalar(self.seats_first_class, null_if_falsy=True),
            scalar(self.seats_business_class, null_if_falsy=True),
            scalar(self.seats_comfort, null_if_falsy=True),
            scalar(self.seats_economy, null_if_falsy=True),
            scalar(self.wifi_first_class),
            scalar(self.wifi_business),
            scalar(self.wifi_comfort),
            scalar(self.wifi_economy),
            scalar(self.power_first_class),
            scalar(self.power_business),
            scalar(self.power_comfort),
            scalar(self.power_economy),
            scalar(self.ife_first_class),
            scalar(self.ife_business),
            scalar(self.ife_comfort),
            scalar(self.ife_economy),
            scalar(self.seat_pitch_first_class),
            scalar(self.seat_pitch_business),
            scalar(self.seat_pitch_comfort),
            scalar(self.seat_pitch_economy),
        ]
        return '({})'.format(', '.join(literals))

    @classmethod
    def get_multi_insert_sql(cls, exps, where):
        # type: (List[FlightPassengerExperience], str) -> str
        """Build one ``INSERT ... ON CONFLICT DO UPDATE`` statement for *exps*.

        Rows that conflict on (flight_id, departure_day) get every insert
        field overwritten with the incoming value and ``updated_at = now()``,
        but only where the *where* condition holds.

        :param exps: rows to upsert. Must be non-empty; otherwise the
            ``VALUES`` clause is empty and the statement is invalid SQL.
        :param where: raw SQL condition interpolated verbatim. It must come
            from trusted code, never from external input.
        :return: the SQL statement text.
        """
        fields = cls.get_insert_fields()
        sets = ['{field} = EXCLUDED.{field}'.format(field=field) for field in fields]
        sets.append('updated_at = now()')

        sql = """
              insert into {table} (
                  {fields}
              )
              values {values}
              on conflict (
                  flight_id, departure_day
              )
              do update set {sets}
              where
                  {where}
              """.format(
            table=cls.__tablename__,
            fields=', '.join(fields),
            values=', '.join(exp.get_insert_values_string() for exp in exps),
            sets=',\n'.join(sets),
            where=where,
        )

        return sql

    def __repr__(self):
        return '<FlightPassengerExperience: {} at {}>'.format(self.flight, self.departure_day)


class FlightInfo(ModelAsDictMixin, Base):
    """Per-departure-day flight information stored mostly as JSONB documents."""
    __tablename__ = 'flight_info'
    __table_args__ = (
        # At most one row per flight per departure day.
        UniqueConstraint('flight_id', 'departure_day', name='ui_flight_info'),
    )

    id = Column(BigInteger, primary_key=True)

    flight_id = Column(BigInteger, ForeignKey(Flight.id), nullable=False)
    flight = relationship('Flight', back_populates='infos')
    departure_day = Column(Date, nullable=False)

    turbulence_index = Column(SmallInteger)
    direct_distance = Column(Integer)
    distance = Column(Integer)
    average_speed = Column(SmallInteger)
    duration = Column(Integer)

    # NOTE(review): the JSON schema of these columns is defined by whatever
    # writes them; it is not visible in this module.
    turbulence_zones = Column(JSONB)
    route = Column(JSONB)
    sun = Column(JSONB)
    wind = Column(JSONB)
    weather_from = Column(JSONB)
    weather_to = Column(JSONB)
    sights = Column(JSONB)
    aircraft = Column(JSONB)
    airport_from = Column(JSONB)
    airport_to = Column(JSONB)

    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    @property
    def modified_at(self):
        """Time of the last change: updated_at when set, otherwise created_at."""
        if self.updated_at:
            return self.updated_at
        return self.created_at
