import csv
import logging
from contextlib import closing

import re
from datetime import datetime
from os.path import isfile, basename
from typing import Union

from tornado.options import define, options, print_help, Error

from travel.avia.flight_extras.application.db import db_master
from travel.avia.flight_extras.application.library.flight_passenger_experience_updater import FlightPassengerExperienceUpdater
from travel.avia.flight_extras.application.models import Flight, FlightPassengerExperience, Source
from travel.avia.flight_extras.settings import setup_logging

log = logging.getLogger(__name__)

# Matches any character outside [0-9a-z] (case-insensitive); upload_from_file
# strips these from the 'carrier' and 'flightno' columns before upper-casing them.
re_flight = re.compile(r'[^0-9a-z]', re.IGNORECASE)

# Command-line option (tornado.options): path of the flight stats CSV file to import.
define('filename', help='Source filename')

# CSV header columns that must be present; _check_file_format raises RuntimeError
# before any row is parsed if one of them is missing.
# NOTE(review): 'carrier' and 'flightno' are also read by upload_from_file but are
# not listed here, so a file lacking them fails later with a TypeError from
# re_flight.sub(None) rather than with the format error.
PASSENGER_EXPERIENCE_REQUIRED_FIELDS = [
    'departure_date',
    'time_pass_dep',
    'time_pass_arr',
    'extra_day',
    'code_dep',
    'code_arr',
    'Predicted aircraft sub-type code',
    'Predicted total seats',
    'Predicted first seats',
    'Predicted biz seats',
    'Predicted prem eco seats',
    'Predicted eco seats',
    'Wi-fi Indicator First',
    'Wi-fi Indicator Business',
    'Wi-fi Indicator Premium Economy',
    'Wi-fi Indicator Economy',
    'In-seat power Indicator First',
    'In-seat power Indicator Business',
    'In-seat power Indicator Premium Economy',
    'In-seat power Indicator Economy',
    'IFE Indicator First',
    'IFE Indicator Business',
    'IFE Indicator Premium Economy',
    'IFE Indicator Economy',
    'Seat Pitch inches First',
    'Seat Pitch inches Business',
    'Seat Pitch inches Premium Economy',
    'Seat Pitch inches Economy'
]


def create_flight_passenger_experience(flight, data):
    # type: (Flight, dict) -> FlightPassengerExperience
    """Build a FlightPassengerExperience from one flight stats CSV row.

    Cell normalisation: ``None``, ``''``, ``'X'`` and ``'U'`` become ``None``;
    ``'Y'`` becomes ``True`` and ``'N'`` becomes ``False``; any other value is
    passed through unchanged. Seat pitch columns are converted from inches to
    whole centimetres (rounded); an empty pitch stays ``None``.

    :param flight: flight the row belongs to; its ``id`` is copied to ``flight_id``.
    :param data: one CSV row, mapping column name to raw cell value.
    :return: a new (not yet persisted) FlightPassengerExperience.
    """

    def cell(column):
        # type: (str) -> Union[str, bool, None]
        raw = data.get(column, None)
        if raw in (None, '', 'X', 'U'):
            return None
        if raw == 'Y':
            return True
        if raw == 'N':
            return False
        return raw

    def pitch_cm(column):
        # type: (str) -> Union[int, None]
        inches = cell(column)
        if not inches:
            return None
        return int(round(float(inches) * 2.54))

    # (model attribute, CSV column) pairs copied through cell() as-is.
    plain_columns = (
        ('departure_time', 'time_pass_dep'),
        ('arrival_time', 'time_pass_arr'),
        ('extra_day', 'extra_day'),
        ('airport_from', 'code_dep'),
        ('airport_to', 'code_arr'),
        ('aircraft', 'Predicted aircraft sub-type code'),
        ('seats_total', 'Predicted total seats'),
        ('seats_first_class', 'Predicted first seats'),
        ('seats_business_class', 'Predicted biz seats'),
        ('seats_comfort', 'Predicted prem eco seats'),
        ('seats_economy', 'Predicted eco seats'),
        ('wifi_first_class', 'Wi-fi Indicator First'),
        ('wifi_business', 'Wi-fi Indicator Business'),
        ('wifi_comfort', 'Wi-fi Indicator Premium Economy'),
        ('wifi_economy', 'Wi-fi Indicator Economy'),
        ('power_first_class', 'In-seat power Indicator First'),
        ('power_business', 'In-seat power Indicator Business'),
        ('power_comfort', 'In-seat power Indicator Premium Economy'),
        ('power_economy', 'In-seat power Indicator Economy'),
        ('ife_first_class', 'IFE Indicator First'),
        ('ife_business', 'IFE Indicator Business'),
        ('ife_comfort', 'IFE Indicator Premium Economy'),
        ('ife_economy', 'IFE Indicator Economy'),
    )
    # (model attribute, CSV column) pairs holding a seat pitch in inches.
    pitch_columns = (
        ('seat_pitch_first_class', 'Seat Pitch inches First'),
        ('seat_pitch_business', 'Seat Pitch inches Business'),
        ('seat_pitch_comfort', 'Seat Pitch inches Premium Economy'),
        ('seat_pitch_economy', 'Seat Pitch inches Economy'),
    )

    fields = {}
    fields['flight'] = flight
    fields['flight_id'] = flight.id
    fields['departure_day'] = datetime.strptime(cell('departure_date'), '%Y-%m-%d').date()
    for attribute, column in plain_columns:
        fields[attribute] = cell(column)
    for attribute, column in pitch_columns:
        fields[attribute] = pitch_cm(column)

    return FlightPassengerExperience(**fields)


def get_previous_filename_and_datetime():
    """Return the name and creation time of the newest flight stats Source.

    Only Source rows whose name starts with ``Source.FLIGHT_STATS_PREFIX`` are
    considered; the most recent by ``created_at`` wins.

    :return: ``(filename, created_at)``, or ``('', datetime.fromtimestamp(0))``
        when no matching Source exists.
    """
    with closing(db_master.create_session()) as session:
        query = session.query(Source.name, Source.created_at)
        query = query.filter(Source.name.like(Source.FLIGHT_STATS_PREFIX + '%'))
        query = query.order_by(Source.created_at.desc())
        latest = query.first()
        if latest:
            filename, created_at = latest
            return filename, created_at
    return '', datetime.fromtimestamp(0)


def _check_file_format(data_reader):
    # type: (csv.DictReader) -> None
    if any(f not in data_reader.fieldnames for f in PASSENGER_EXPERIENCE_REQUIRED_FIELDS):
        raise RuntimeError(
            'Flight stats file format changed. '
            'Expected fields: %s. '
            'Actual fields: %s'
        )


def upload_from_file(filename):
    """Import passenger experience rows from a flight stats CSV file.

    A Source named after the file's basename is fetched or created. Every CSV row
    becomes a FlightPassengerExperience tagged with that source's id and is handed
    to FlightPassengerExperienceUpdater. The session is committed once, after the
    whole file has been processed; an exception before that leaves nothing committed.

    :param filename: path to the CSV file.
    :raises tornado.options.Error: if ``filename`` is not an existing regular file.
    :raises RuntimeError: if the CSV header lacks a required column.
    """
    if not isfile(filename):
        raise Error('File "{}" does not exist'.format(filename))

    with closing(db_master.create_session()) as session:
        source = Source.get_or_create(basename(filename), session)

        updater = FlightPassengerExperienceUpdater(session)

        count = 0
        with open(filename, 'r') as csv_file:
            data_reader = csv.DictReader(csv_file)
            _check_file_format(data_reader)

            for count, data in enumerate(data_reader, start=1):
                # Normalise identifiers: drop non-alphanumerics, then upper-case.
                flight = Flight(
                    company_iata=re_flight.sub('', data.get('carrier')).upper(),
                    number=re_flight.sub('', data.get('flightno')).upper(),
                )
                pe = create_flight_passenger_experience(flight, data)
                pe.source_id = source.id

                updater.add_flight_passenger_experience(pe)
                if count % 10000 == 0:
                    log.info('Parsed %d rows', count)

        # presumably pushes any rows the updater still holds to the session — confirm
        updater.flush()
        session.commit()
        log.info('Total parsed %d rows', count)


def main():
    """Script entry point: import ``options.filename``, or print usage if unset."""
    setup_logging()

    if options.filename is None:
        print_help()
        return
    upload_from_file(options.filename)


if __name__ == '__main__':
    # Parse the tornado command-line options (such as --filename) before main()
    # reads options.filename.
    options.parse_command_line()

    main()
