# coding=utf-8
from __future__ import unicode_literals

import json
import logging
import os
import psycopg2
import six

from collections import defaultdict, namedtuple, OrderedDict
from datetime import datetime, timedelta

import yt.wrapper as yt


logger = logging.getLogger(__name__)

# strptime/strftime patterns for input dates and output values.
DATE_FORMAT = '%Y-%m-%d'
DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%S%z'
TIME_FORMAT = '%H:%M'

# Upper bound on the size of the date range accepted by SchedulesDumper.run.
MAX_DAYS_TO_PROCESS_IN_A_SINGLE_RUN = 183
# Partner id that gives a flights row priority when the same segment appears twice.
TRUSTED_PARTNER = 261
# Default LIMIT applied to the flight_schedule / flights selects.
MAX_ROWS_PER_SELECT = 1000000

# YT settings; the token default is read once, at import time.
OAUTH_TOKEN = os.environ.get('AVIA_OAUTH_TOKEN')
YT_TMP_DIRECTORY = '//home/avia/tmp'
YT_PROXY = 'hahn.yt.yandex.net'

# flight-storage (PostgreSQL) connection settings, each overridable from the environment.
FLIGHTSTORAGE_CLUSTER_ID = os.environ.get('AVIA_FLIGHTSTORAGE_CLUSTER_ID', '1cd61885-7019-4696-bd25-e76651ca5d56')
FLIGHTSTORAGE_DATABASE_NAME = os.environ.get('AVIA_FLIGHTSTORAGE_DATABASE_NAME', 'avia_flight_board')
FLIGHTSTORAGE_USER = os.environ.get('AVIA_FLIGHTSTORAGE_USER', 'avia')
FLIGHTSTORAGE_PORT = os.environ.get('AVIA_FLIGHTSTORAGE_PORT', '6432')


# Reference rows loaded from the flight-storage database.
Company = namedtuple('Company', 'id iata rasp_id')
Airport = namedtuple('Airport', 'id iata rasp_id tz_name')
# One leg of a multi-segment flight parsed from flight_schedule.segments.
# NOTE: as built by SchedulesDumper.parse_segments, the airport fields hold
# Airport tuples (looked up by IATA code), not raw ids.
ScheduleSegment = namedtuple('ScheduleSegment', 'departure_day airport_from_id airport_to_id')
# Identity used to match flights rows against schedule segments; the airport
# fields hold Airport tuples in this module as well.
SegmentKey = namedtuple('SegmentKey', 'company_id flight_number airport_from_id airport_to_id')
# Bookkeeping for a segment already placed in the output: its position in the
# records list and the partner info used to resolve duplicates.
ExistingRecord = namedtuple('ExistingRecord', 'record_index partners_count has_trusted_partner')


class SchedulesDumper(object):
    """Dump flight schedules for a range of flight dates from flight-storage to YT.

    For each flight date the YT table ``<output_path>/<flight_date>`` is recreated.
    Its rows are segments from the ``flights`` table that belong to flights
    operating on that date according to ``flight_schedule``. Legs of
    multi-segment flights may depart on other days; those days are fetched too.
    """

    def __init__(self, output_path, oauth_token, pg_password, max_rows_per_select):
        """Configure the dumper.

        :param output_path: YT directory that receives one table per flight date.
        :param oauth_token: YT OAuth token; if empty, the module-level OAUTH_TOKEN
            (read from AVIA_OAUTH_TOKEN at import time) is used.
        :param pg_password: flight-storage password; if empty, it is read from the
            AVIA_FLIGHTSTORAGE_PASSWORD environment variable.
        :param max_rows_per_select: LIMIT for the schedule and flight selects;
            if empty, MAX_ROWS_PER_SELECT is used.
        :raises Exception: if ``output_path`` is empty.
        """
        self._oauth_token = oauth_token or OAUTH_TOKEN
        self._pg_password = pg_password or os.getenv('AVIA_FLIGHTSTORAGE_PASSWORD')

        if not output_path:
            raise Exception('Output path is not specified')
        self._output_path = output_path

        self._max_rows_per_select = max_rows_per_select or MAX_ROWS_PER_SELECT
        # IATA code -> Airport; refreshed for every processed date, read by parse_segments.
        self._airports_by_iata = {}
        # flight id -> (company_id, flight number); loaded lazily and reused across dates.
        self._flights_by_id = None

    def run(self, start_flight_date, end_flight_date):
        """Dump schedules for every flight date in [start, end], both inclusive.

        Dates are strings in DATE_FORMAT. Missing dates, reversed ranges and ranges
        longer than MAX_DAYS_TO_PROCESS_IN_A_SINGLE_RUN are logged and nothing is done.
        """
        logger.info('Started dumping schedules')

        if not start_flight_date or not end_flight_date:
            logger.error('Either start or end date is not specified: start=%s, end=%s', start_flight_date, end_flight_date)
            return

        start_date = datetime.strptime(start_flight_date, DATE_FORMAT)
        end_date = datetime.strptime(end_flight_date, DATE_FORMAT)

        if end_date < start_date:
            logger.error('Start date is after end date: start=%s, end=%s', start_date, end_date)
            return

        # end >= start here, so the span must be measured as end - start to be non-negative.
        days_span = (end_date - start_date).days
        if days_span > MAX_DAYS_TO_PROCESS_IN_A_SINGLE_RUN:
            logger.error(
                'Too many days to process(%s), max allowed is %s.',
                days_span,
                MAX_DAYS_TO_PROCESS_IN_A_SINGLE_RUN,
            )
            return

        current_date = start_date
        while current_date <= end_date:
            self.run_for_single_flight_date(current_date.strftime(DATE_FORMAT))
            current_date += timedelta(days=1)

        logger.info('Done dumping schedules')

    def run_for_single_flight_date(self, flight_date):
        """Build the records for ``flight_date`` (DATE_FORMAT string) and write them to YT."""
        logger.info('Started processing date %s', flight_date)
        error_counters = defaultdict(int)

        conn = psycopg2.connect(self.get_psycopg2_conn_string())
        try:
            records = self._load_records(conn, flight_date, error_counters)
        finally:
            # Always release the database connection, and do it before the YT write.
            conn.close()

        self._write_records(flight_date, records)

        logger.info('Problems detected: %s', error_counters)
        logger.info('Done processing date %s', flight_date)

    def _load_records(self, conn, flight_date, error_counters):
        """Collect the output records for ``flight_date`` over the open connection ``conn``.

        Problems found while reading are counted into ``error_counters``.
        """
        logger.info('Loading references')
        companies = load_companies(conn.cursor())
        airports_by_id, self._airports_by_iata = load_airports(conn.cursor())

        if not self._flights_by_id:
            self._flights_by_id = load_flights_by_id(conn.cursor())

        dates_to_segments, single_segment_flights, schedules_errors = self.load_flight_schedules(
            conn.cursor(name='flight_schedules'),
            flight_date,
        )
        if schedules_errors:
            merge_errors(error_counters, schedules_errors)

        records = []

        # Single-segment flights are only matched on the flight date itself, so that
        # day must be fetched even when no multi-segment leg departs on it.
        if flight_date not in dates_to_segments:
            logger.info('Adding flight-date %s', flight_date)
            dates_to_segments[flight_date] = {}

        for current_flight_date, current_segments in six.iteritems(dates_to_segments):
            logger.info('Fetching flights for the day %s', current_flight_date)
            # SegmentKey -> ExistingRecord describing the row currently stored in ``records``.
            processed_segments = {}
            cursor = conn.cursor(name='flight_segments')
            cursor.execute(
                '''
                select
                    company_id,
                    number,
                    airport_from_id,
                    airport_to_id,
                    departure_time,
                    arrival_day,
                    arrival_time,
                    departure_utc,
                    arrival_utc,
                    partners
                from
                    flights
                where
                    departure_day = %s
                limit %s
                ''',
                (current_flight_date, self._max_rows_per_select),
            )
            row_index = 0
            row_types = defaultdict(int)
            c = get_column_names(cursor)
            for row in cursor:
                raw_company_id = row[c['company_id']]
                company_id = int(raw_company_id) if raw_company_id else None
                company = companies.get(company_id)

                raw_airport_from_id = row[c['airport_from_id']]
                airport_from = airports_by_id.get(int(raw_airport_from_id)) if raw_airport_from_id else None
                raw_airport_to_id = row[c['airport_to_id']]
                airport_to = airports_by_id.get(int(raw_airport_to_id)) if raw_airport_to_id else None

                arrival_day = format_date(row[c['arrival_day']], DATE_FORMAT)
                flight_number = row[c['number']] or None
                if not flight_number:
                    error_counters['no_flight_number'] += 1
                    continue

                # The company part is the raw id, so rows whose company is missing from
                # the reference table no longer crash the run. The airport parts are
                # Airport tuples, matching the keys load_flight_schedules builds from
                # parse_segments.
                segment_key = SegmentKey(company_id, flight_number, airport_from, airport_to)

                partners = row[c['partners']]
                existing_record = processed_segments.get(segment_key)
                if existing_record:
                    if self._is_duplicate_by_partners(existing_record, partners):
                        row_types['duplicate-by-partners'] += 1
                        continue
                    row_types['overridden-by-partners'] += 1

                departure_time = row[c['departure_time']]
                arrival_time = row[c['arrival_time']]
                record = {
                    'marketing_carrier_id': company.rasp_id if company else None,
                    'marketing_carrier_code': company.iata if company else None,
                    'marketing_flight_number': flight_number,
                    'airport_from_id': airport_from.rasp_id if airport_from else None,
                    'airport_to_id': airport_to.rasp_id if airport_to else None,
                    'segment_number': 1,
                    'segment_count': 1,
                    'flight_date': flight_date,
                    'segment_departure_date': current_flight_date,
                    'segment_arrival_date': arrival_day,
                    'departure_time': departure_time.strftime(TIME_FORMAT) if departure_time else None,
                    'departure_time_utc': format_date(row[c['departure_utc']], DATETIME_FORMAT),
                    'arrival_time': arrival_time.strftime(TIME_FORMAT) if arrival_time else None,
                    'arrival_time_utc': format_date(row[c['arrival_utc']], DATETIME_FORMAT),
                    'data_source': 'flight-storage',
                    'record_source': 'flight-storage',
                }

                segment_data = current_segments.get(segment_key)
                new_segments_count = 0 if existing_record else 1
                if segment_data:
                    record['segment_number'], record['segment_count'] = segment_data
                    row_index += 1
                    row_types['multi-segments'] += new_segments_count
                elif current_flight_date == flight_date and (company_id, flight_number) in single_segment_flights:
                    row_index += 1
                    row_types['single-segments'] += new_segments_count
                else:
                    row_types['skipped_segments'] += 1
                    continue

                if existing_record:
                    record_index = existing_record.record_index
                    records[record_index] = record
                else:
                    record_index = len(records)
                    records.append(record)
                # Store the partner info of the row now at record_index, so that later
                # duplicates are compared against it rather than the first row seen.
                processed_segments[segment_key] = ExistingRecord(
                    record_index,
                    len(partners) if partners else 0,
                    TRUSTED_PARTNER in partners if partners else False,
                )

                if row_index and row_index % 50000 == 0:
                    logger.info('Loaded %d flights so far...', row_index)

            cursor.close()
            logger.info('Loaded row-types: %s', row_types)

        return records

    @staticmethod
    def _is_duplicate_by_partners(existing_record, partners):
        """Return True when a repeated segment row must NOT replace the stored one.

        A row without partners never wins. A row that has TRUSTED_PARTNER wins
        unless the stored row also has it and at least as many partners. A row
        without TRUSTED_PARTNER wins only if the stored row lacks it too and the
        new row has strictly more partners.
        """
        if not partners:
            return True
        if TRUSTED_PARTNER in partners:
            return existing_record.has_trusted_partner and len(partners) <= existing_record.partners_count
        return existing_record.has_trusted_partner or len(partners) <= existing_record.partners_count

    def _write_records(self, flight_date, records):
        """Replace the YT table ``<output_path>/<flight_date>`` with ``records``.

        The remove/create/write sequence runs inside one YT transaction.
        """
        yt_client = self.get_yt_client()
        table_path = '/'.join([self._output_path, flight_date])
        with yt_client.Transaction():
            if yt_client.exists(table_path):
                yt_client.remove(table_path)

            fields = OrderedDict([
                ('marketing_carrier_id', 'int32'),
                ('marketing_carrier_code', 'string'),
                ('marketing_flight_number', 'string'),
                ('segment_number', 'int8'),
                ('segment_count', 'int8'),
                ('airport_from_id', 'int32'),
                ('airport_to_id', 'int32'),
                ('flight_date', 'string'),
                ('segment_departure_date', 'string'),
                ('segment_arrival_date', 'string'),
                ('departure_time', 'string'),
                ('departure_time_utc', 'string'),
                ('arrival_time', 'string'),
                ('arrival_time_utc', 'string'),
                ('departure_terminal', 'string'),
                ('arrival_terminal', 'string'),
                ('aircraft_model', 'string'),
                ('aircraft_model_id', 'int32'),
                ('filing_carrier_id', 'int32'),
                ('is_codeshare', 'boolean'),
                ('domestic', 'string'),
                ('operating_carrier_id', 'int32'),
                ('operating_carrier_code', 'string'),
                ('operating_flight_number', 'string'),
                ('operating_segment_number', 'int8'),
                ('data_source', 'string'),
                ('record_source', 'string'),
            ])

            schema = [
                {
                    'name': name,
                    'type': type_value,
                } for name, type_value in six.iteritems(fields)
            ]
            yt_client.create('table', table_path, recursive=True, attributes={
                'optimize_for': 'scan',
                'schema': schema,
            })
            yt_client.write_table(table_path, records)

    def load_flight_schedules(self, cursor, flight_date):
        """Read ``flight_schedule`` rows departing on ``flight_date`` and classify them.

        :param cursor: cursor used for the select; closed before returning.
        :param flight_date: DATE_FORMAT string of the schedules' departure day.
        :returns: tuple ``(dates_to_segments, single_segment_flights, errors)``:
            dates_to_segments - defaultdict(dict) mapping a leg's ``departure_day`` (as
            found in the segments JSON; presumably DATE_FORMAT - confirm) to
            ``{SegmentKey: (segment_number, segment_count)}`` for multi-segment flights;
            single_segment_flights - set of ``(company_id, flight_number)`` for flights
            that did not yield more than one parsed segment;
            errors - defaultdict(int) of problem description -> count.
        """
        logger.info('Loading flight schedules')
        cursor.execute(
            '''
            select
                flight_id,
                airport_from_id,
                airport_to_id,
                segments
            from
                flight_schedule
            where
                departure_day = %s
            limit %s
            ''',
            (flight_date, self._max_rows_per_select),
        )

        dates_to_segments = defaultdict(dict)
        single_segment_flights = set()
        errors = defaultdict(int)

        for row_index, row in enumerate(cursor, 1):
            if row_index % 50000 == 0:
                logger.info('Loaded %d schedules so far...', row_index)

            flight_id = int(row[0]) if row[0] else None
            flight_value = self._flights_by_id.get(flight_id)
            if not flight_value:
                errors['unknown_flight_id: {}'.format(flight_id)] += 1
                continue
            company_id, flight_number = flight_value

            segments = self.parse_segments(row[3], errors) if row[3] else None
            if segments and len(segments) > 1:
                # NOTE(review): parse_segments drops invalid legs, so numbering below is
                # over the surviving legs only - confirm this is intended.
                segments_count = len(segments)
                for segment_number, segment in enumerate(segments, 1):
                    segment_key = SegmentKey(
                        company_id,
                        flight_number,
                        segment.airport_from_id,
                        segment.airport_to_id,
                    )
                    dates_to_segments[segment.departure_day][segment_key] = (segment_number, segments_count)
            else:
                single_segment_flights.add((company_id, flight_number))

        cursor.close()
        logger.info('single_segment_flights: %d', len(single_segment_flights))
        logger.info('dates: %s', dates_to_segments.keys())
        logger.info('errors: %s', errors)
        return dates_to_segments, single_segment_flights, errors

    def parse_segments(self, segments_text, errors):
        """Parse the ``segments`` value of a flight_schedule row.

        :param segments_text: an already-decoded list or a JSON string of objects with
            ``departure_day``, ``airport_from_code`` and ``airport_to_code`` keys.
        :param errors: defaultdict(int) receiving problem counters.
        :returns: list of ScheduleSegment whose airport fields hold Airport tuples
            looked up by IATA code (legs with missing data are skipped and counted),
            or None when the input is empty or any exception occurs while parsing.
        """
        if not segments_text:
            return None
        try:
            if isinstance(segments_text, list):
                segments_list = segments_text
            else:
                segments_list = json.loads(segments_text)
            segments = []
            for segment_elem in segments_list:
                departure_day = segment_elem.get('departure_day')

                airport_from_code = segment_elem.get('airport_from_code')
                airport_from = self._airports_by_iata.get(airport_from_code) if airport_from_code else None
                if airport_from_code and not airport_from:
                    errors['unknown airport from code {}'.format(airport_from_code)] += 1

                airport_to_code = segment_elem.get('airport_to_code')
                airport_to = self._airports_by_iata.get(airport_to_code) if airport_to_code else None
                if airport_to_code and not airport_to:
                    errors['unknown airport to code {}'.format(airport_to_code)] += 1

                if departure_day and airport_from and airport_to:
                    segments.append(ScheduleSegment(departure_day, airport_from, airport_to))
                else:
                    errors['invalid segment: {}'.format(segment_elem)] += 1
            return segments
        except Exception as e:
            # Deliberate best-effort: any malformed payload is counted, not raised.
            errors['json parsing error: {}'.format(e)] += 1
            return None

    def get_yt_client(self):
        """Create a YT client for YT_PROXY authenticated with the configured token."""
        config = {
            'clear_local_temp_files': False,
            'remote_temp_tables_directory': YT_TMP_DIRECTORY,
        }
        return yt.YtClient(proxy=YT_PROXY, token=self._oauth_token, config=config)

    @staticmethod
    def _quote_conninfo_value(value):
        """Quote ``value`` for a libpq ``keyword=value`` connection string.

        libpq requires single quotes around values that contain spaces or are
        empty, with backslashes and single quotes escaped by a backslash.
        """
        text = '{}'.format(value)
        return "'{}'".format(text.replace('\\', '\\\\').replace("'", "\\'"))

    def get_psycopg2_conn_string(self):
        """Build the libpq connection string for the flight-storage cluster."""
        params = [
            ('dbname', FLIGHTSTORAGE_DATABASE_NAME),
            ('user', FLIGHTSTORAGE_USER),
            ('host', 'c-{cluster_id}.rw.db.yandex.net'.format(cluster_id=FLIGHTSTORAGE_CLUSTER_ID)),
            ('port', FLIGHTSTORAGE_PORT),
            ('sslmode', 'require'),
        ]
        # Without a configured password, omit the keyword instead of sending the
        # literal string 'None'; libpq then uses its own lookup (PGPASSWORD, ~/.pgpass).
        if self._pg_password:
            params.append(('password', self._pg_password))
        return ' '.join(
            '{}={}'.format(key, self._quote_conninfo_value(value)) for key, value in params
        )


def load_companies(cursor):
    """Read the ``companies`` table into a dict keyed by company id.

    :param cursor: DB-API cursor on the flight-storage database; closed after reading.
    :returns: dict mapping int company id -> Company(id, iata, rasp_id).
        Rows whose id is empty/zero are skipped; empty iata/rasp_id become None.
    """
    cursor.execute(
        '''
        select
            id,
            iata,
            rasp_id
        from
            companies
        '''
    )
    companies_by_id = {}
    for raw_id, raw_iata, raw_rasp_id in cursor:
        company_id = int(raw_id) if raw_id else None
        iata = raw_iata or None
        rasp_id = int(raw_rasp_id) if raw_rasp_id else None
        if not company_id:
            continue
        companies_by_id[company_id] = Company(company_id, iata, rasp_id)
    cursor.close()
    return companies_by_id


def load_airports(cursor):
    """Read the ``airports`` table into two lookup dicts.

    :param cursor: DB-API cursor on the flight-storage database; closed after reading.
    :returns: tuple ``(airports_by_id, airports_by_iata)`` mapping the int id and the
        text IATA code, respectively, to the same Airport(id, iata, rasp_id, tz_name)
        object. A row lacking an id is absent from the first dict and a row lacking
        an IATA code is absent from the second; later rows win on key collisions.
    """
    cursor.execute(
        '''
        select
            id,
            iata,
            rasp_id,
            tz_name
        from
            airports
        '''
    )
    by_id = {}
    by_iata = {}
    for raw_id, raw_iata, raw_rasp_id, raw_tz_name in cursor:
        airport = Airport(
            int(raw_id) if raw_id else None,
            six.ensure_text(raw_iata) if raw_iata else None,
            int(raw_rasp_id) if raw_rasp_id else None,
            raw_tz_name or None,
        )
        if airport.id:
            by_id[airport.id] = airport
        if airport.iata:
            by_iata[airport.iata] = airport
    cursor.close()
    return by_id, by_iata


def load_flights_by_id(cursor):
    """Read the ``flight`` table into a dict of flight id -> (company_id, number).

    :param cursor: DB-API cursor on the flight-storage database; closed after reading.
    :returns: dict keyed by int flight id; rows missing any of id, company id or
        number are skipped.
    """
    cursor.execute(
        '''
        select
            id,
            company_id,
            number
        from
            flight
        '''
    )
    flights = {}
    for raw_id, raw_company_id, raw_number in cursor:
        key = int(raw_id) if raw_id else None
        value = (int(raw_company_id) if raw_company_id else None, raw_number or None)
        if key and all(value):
            flights[key] = value
    cursor.close()
    return flights


def merge_errors(errors, segments_errors):
    """Add every counter in ``segments_errors`` into ``errors``, in place.

    ``errors`` must accept increments of missing keys (e.g. ``defaultdict(int)``),
    since new keys are added with ``+=``.
    """
    for key in segments_errors:
        errors[key] += segments_errors[key]


def get_column_names(cursor):
    """Return a mapping of lower-cased column name -> 0-based index for ``cursor``.

    One row is fetched first so that ``cursor.description`` is populated
    (presumably because server-side/named psycopg2 cursors only fill it in after
    the first fetch - confirm). If a row was actually consumed, the cursor is
    scrolled back by one so the caller still iterates from the first row.

    :param cursor: a cursor on which ``execute`` has already been called.
    :returns: dict of column name -> index; empty if the cursor has no description.
        For duplicate names the last column wins.
    """
    first_row = None
    try:
        first_row = cursor.fetchone()
    except Exception:
        # Best-effort: e.g. a statement that produced no result set.
        pass

    colnames = {desc[0].lower(): index for index, desc in enumerate(cursor.description or ())}

    # Only rewind when a row was consumed; scrolling back from the start raises
    # on client-side cursors and is pointless otherwise.
    if first_row is not None:
        try:
            cursor.scroll(-1)
        except Exception:
            # A failed rewind means the caller will miss the first row: make it visible.
            logging.getLogger(__name__).warning(
                'Could not scroll cursor back; the first row will be skipped',
                exc_info=True,
            )
    return colnames


def format_date(date_obj, format):
    """Format ``date_obj`` with the strftime pattern ``format``.

    Returns None when ``date_obj`` is falsy or its year is 2000 or earlier
    (presumably to drop placeholder/sentinel dates - confirm).
    """
    if not date_obj or date_obj.year <= 2000:
        return None
    return date_obj.strftime(format)
