# coding=utf-8
from __future__ import unicode_literals

import logging
import os
import psycopg2
import pytz
import requests
import retrying
import six
from datetime import datetime, timedelta
from dateutil import parser


logger = logging.getLogger(__name__)

# Handles under comparison: the rasp (morda) point-to-point search and the
# shared-flights p2p handle, whose host is templated by environment name.
rasp_p2p_prefix = 'https://morda-backend.rasp.yandex.net/ru/search/search/'
sf_p2p_template = 'http://shared-flights.{}.avia.yandex.net/api/v1/flight-p2p'
# Flight timestamps are rendered with this format before being compared.
date_format = '%Y-%m-%dT%H:%M:%S%z'

# Report layout: a title row padded to eight comma-separated columns, then a
# header row naming the day-offset columns.
block_header = '{},,,,,,,'
group_header = ','.join(['from_to', 't_m2', 't_m1', 't', 't_1', 't_2', 't_7', 't_30'])

# PostgreSQL (PGaaS) connection settings, each overridable from the environment.
PGAAS_CLUSTER_ID = os.environ.get('AVIA_PGAAS_CLUSTER_ID', 'mdb9sssbmtcje8gtvlrc')
PGAAS_DATABASE_NAME = os.environ.get('AVIA_PGAAS_DATABASE_NAME', 'shared-flights')
PGAAS_USER = os.environ.get('AVIA_PGAAS_USER', 'avia')
PGAAS_PASSWORD = os.environ.get('AVIA_PGAAS_PASSWORD')
PGAAS_PORT = os.environ.get('AVIA_PGAAS_PORT', '6432')
SSL_ROOT_CERT = os.environ.get('AVIA_PGAAS_SSL_ROOT_CERT', '/etc/ssl/certs/ca-certificates.crt')

# Stations compared when none are given explicitly. IATA codes are resolved
# to station ids at run time; non-string entries are used as ids directly.
default_stations_from = 'SVO VKO DME LED ZIA AER'.split()
default_stations_to = (
    'KRR SVX UFA TAS AER EVN OVB FRU KUF KZN SIP IKT '
    'ROV VOG HKT KJA CEK ALA AYT VVO NQZ SGC TJM MSQ '
    'GOJ VOZ GSV PEE MCX MRV REN YKS NUX KGD IST MMK '
    'KHV DXB OSS ASF GUW SCW OMS USK KEJ OGZ NJC LED '
    'AAQ IJK EGO CXR BAX KIV MUC STW DYU BKK NOJ NBC '
    'GOI RIX FRA CSY AMS PKC CIT CEE TLV SCO PRG ARH '
    'CDG KBP AKX HTA TOF UUS SAW MJZ UTP BQS SKD LBD '
    'ULV BZK KVX NYM NOZ KLF URA NNM SLY PEZ JFK LHR'
).split() + [
    9600377,  # given as a numeric station id, not an IATA code
]


class Flight(object):
    """One flight, normalised from either a rasp or a shared-flights response.

    Flights are compared through their ``repr`` (the instance attribute dict),
    so every instance must carry the same attributes created in the same order.
    """

    def __init__(self):
        # Tuple targets are assigned left to right, which fixes the attribute
        # (and therefore repr) order.
        (
            self.number,
            self.company_id,
            self.departure_utc,
            self.arrival_utc,
            self.station_to,
            self.station_from,
            self.transport_model,
            self.is_codeshare,
        ) = ('', 0, '', '', 0, 0, '', False)

    def __repr__(self):
        attributes = vars(self)
        return str(attributes)


class ComparisonResult(object):
    """Counters from comparing one rasp response with one shared-flights response.

    ``repr`` renders the counters as one comma-separated row:
    same_flights, rasp_but_not_sf, sf_but_not_rasp, unequal_flights.
    """

    def __init__(self):
        self.same_flights = self.rasp_but_not_sf = self.sf_but_not_rasp = self.unequal_flights = 0

    def __repr__(self):
        counters = (self.same_flights, self.rasp_but_not_sf, self.sf_but_not_rasp, self.unequal_flights)
        return ','.join(str(counter) for counter in counters)


class QualityTool(object):
    """Compares results from the rasp and shared-flights p2p handles.

    Both handles are queried for every (from, to) pair of the configured
    stations, in both directions, for a set of days around today. ``run``
    renders the per-day counters as comma-separated text blocks.
    """

    # Day offsets relative to today. They are listed in the same order as the
    # ``group_header`` columns t_m2, t_m1, t, t_1, t_2, t_7, t_30.
    _DAY_OFFSETS = (-2, -1, 0, 1, 2, 7, 30)

    # Upper bound for a single HTTP request. Without it ``requests.get`` may
    # block forever. A timeout is logged and retried like any other failure.
    _REQUEST_TIMEOUT_SECONDS = 60

    def __init__(self, environment, stations_from, stations_to, pg_password, debug):
        """
        :param environment: shared-flights environment name, substituted into ``sf_p2p_template``.
        :param stations_from: comma-separated IATA codes or numeric station ids;
            ``default_stations_from`` is used when empty.
        :param stations_to: same format as ``stations_from``; ``default_stations_to`` is used when empty.
        :param pg_password: database password; ``PGAAS_PASSWORD`` is used when empty.
        :param debug: when truthy, every request and every mismatching flight is logged.
        """
        self._environment = environment
        self._debug = debug
        self._pg_password = pg_password or PGAAS_PASSWORD

        # Copy the defaults so the module-level lists are never shared with an instance.
        self._stations_from = list(default_stations_from)
        if stations_from:
            self._stations_from = self._parse_stations(stations_from)

        self._stations_to = list(default_stations_to)
        if stations_to:
            self._stations_to = self._parse_stations(stations_to)

        self._transport_models = {}
        self._stations = {}
        # One list per report block. Each row is [from_to, <one value per day offset>].
        self._lines_deltas = []
        self._lines_counts = []
        self._lines_rasp_but_not_sf = []
        self._lines_sf_but_not_rasp = []
        self._lines_unequal = []

    @staticmethod
    def _parse_stations(stations):
        """Split a comma-separated station list, stripping whitespace and dropping empty items."""
        return [station.strip() for station in stations.split(',') if station.strip()]

    def run(self):
        """Run every comparison and return the report as newline-separated text."""
        logger.info('Started comparing handles')
        logger.info('Debug %s', self._debug)

        # The database is only needed for the lookup tables. Close the
        # connection before the long series of HTTP comparisons starts.
        conn = psycopg2.connect(self.get_psycopg2_conn_string())
        try:
            self._transport_models = self.fetch_transport_models_to_dict(conn.cursor())
            self._stations = self.fetch_stations_to_dict(conn.cursor())
        finally:
            conn.close()

        self.run_p2p_one_day()
        logger.info('Done comparing handles')

        lines = []
        self.extend_lines(lines)
        return '\n'.join(lines)

    def _resolve_station_id(self, station):
        """Map a configured station to its id.

        Non-strings and digit-only strings are treated as ids. Other strings
        are looked up as IATA codes. Returns None, and logs a warning, for an
        unknown code.
        """
        if not isinstance(station, six.string_types):
            return station
        if station.isdigit():
            return int(station)
        station_id = self._stations.get(station)
        if station_id is None:
            logger.warning('Unknown station code %s, skipping it', station)
        return station_id

    def run_p2p_one_day(self):
        """Compare both handles for every configured station pair, in both directions."""
        total = len(self._stations_from) * len(self._stations_to)
        # Stations present in both lists (e.g. LED, AER) would otherwise yield
        # the same directed pair twice: duplicate requests and duplicate rows.
        processed_pairs = set()
        cur_index = 0
        for station_from in self._stations_from:
            for station_to in self._stations_to:
                cur_index += 1
                logger.info('Current progress %d of %d', cur_index, total)
                station_from_id = self._resolve_station_id(station_from)
                station_to_id = self._resolve_station_id(station_to)
                if station_from_id is None or station_to_id is None:
                    continue
                if station_from_id == station_to_id:
                    continue
                directions = (
                    (station_from_id, station_to_id, station_from, station_to),
                    (station_to_id, station_from_id, station_to, station_from),
                )
                for src_id, dst_id, src_label, dst_label in directions:
                    if (src_id, dst_id) in processed_pairs:
                        continue
                    processed_pairs.add((src_id, dst_id))
                    self.add_lines(self.run_p2p_one_day_for_stations(src_id, dst_id), src_label, dst_label)

    def extend_lines(self, lines):
        """Append all report blocks to ``lines``, in a fixed order."""
        blocks = (
            (self._lines_deltas, 'deltas'),
            (self._lines_counts, 'counts'),
            (self._lines_rasp_but_not_sf, 'rasp_but_not_sf'),
            (self._lines_sf_but_not_rasp, 'sf_but_not_rasp'),
            (self._lines_unequal, 'unequal_flights'),
        )
        for numeric_lines, block_title in blocks:
            self.print_lines_block(lines, numeric_lines, block_title)

    def print_lines_block(self, text_lines, numeric_lines, block_title):
        """Append one block to ``text_lines``: separator, title row, column header, data rows."""
        text_lines.append(block_header.format(''))  # separator
        text_lines.append(block_header.format(block_title))
        text_lines.append(group_header)
        text_lines.extend(','.join(str(cell) for cell in row) for row in numeric_lines)

    def add_lines(self, cmp_results, station_from_str, station_to_str):
        """Add one row per report block for a directed pair, unless every counter is zero.

        :param cmp_results: one ComparisonResult per day offset, in ``_DAY_OFFSETS`` order.
        :param station_from_str: label of the departure station used in the row key.
        :param station_to_str: label of the arrival station used in the row key.
        """
        if not any(
            res.same_flights or res.rasp_but_not_sf or res.sf_but_not_rasp or res.unequal_flights
            for res in cmp_results
        ):
            return

        from_to = '{}_{}'.format(station_from_str, station_to_str)
        width = len(self._DAY_OFFSETS)

        line_deltas = [from_to] + [0] * width
        line_counts = [from_to] + [0] * width
        line_rasp_but_not_sf = [from_to] + [0] * width
        line_sf_but_not_rasp = [from_to] + [0] * width
        line_unequal = [from_to] + [0] * width

        for index, res in enumerate(cmp_results, start=1):
            line_deltas[index] = res.rasp_but_not_sf + res.sf_but_not_rasp + res.unequal_flights
            # Total rasp flights: compare_responses puts every rasp flight into
            # exactly one of same / unequal / rasp_but_not_sf. Unequal flights
            # must be counted too, or the count shrinks as mismatches grow.
            line_counts[index] = res.same_flights + res.unequal_flights + res.rasp_but_not_sf
            line_rasp_but_not_sf[index] = res.rasp_but_not_sf
            line_sf_but_not_rasp[index] = res.sf_but_not_rasp
            line_unequal[index] = res.unequal_flights

        self._lines_deltas.append(line_deltas)
        self._lines_counts.append(line_counts)
        self._lines_rasp_but_not_sf.append(line_rasp_but_not_sf)
        self._lines_sf_but_not_rasp.append(line_sf_but_not_rasp)
        self._lines_unequal.append(line_unequal)

    def run_p2p_one_day_for_stations(self, station_from, station_to):
        """Return one ComparisonResult per entry of ``_DAY_OFFSETS`` for a directed station pair."""
        today = datetime.today()
        return [
            self.run_p2p_one_day_for_date_and_stations(today + timedelta(days=offset), station_from, station_to)
            for offset in self._DAY_OFFSETS
        ]

    def run_p2p_one_day_for_date_and_stations(self, date_when, station_from, station_to):
        """Query both handles for one day and one directed pair, and compare the flights found."""
        after = date_when.strftime('%Y-%m-%d')
        before = (date_when + timedelta(days=1)).strftime('%Y-%m-%d')
        rasp_result = self.parse_rasp_p2p_result(
            self.make_request(self.get_rasp_p2p_url(station_from, station_to, after)),
            station_from,
            station_to,
        )
        sf_result = self.parse_sf_p2p_result(self.make_request(self.get_sf_p2p_url(station_from, station_to, after, before)))
        return self.compare_responses(rasp_result, sf_result)

    def get_rasp_p2p_url(self, station_from, station_to, date_when):
        """Build the rasp search URL for plane segments between two station ids on ``date_when``."""
        return '{}?pointFrom=s{}&pointTo=s{}&transportType=plane&when={}'.format(
            rasp_p2p_prefix,
            station_from,
            station_to,
            date_when,
        )

    def get_sf_p2p_url(self, station_from, station_to, date_when, date_when_next_day):
        """Build the shared-flights p2p URL for a window from 00:00 of ``date_when`` to 04:00 of the next day."""
        return '{}?from={}&to={}&after={}T00:00:00&before={}T04:00:00'.format(
            sf_p2p_template.format(self._environment),
            station_from,
            station_to,
            date_when,
            date_when_next_day,
        )

    @retrying.retry(
        # Any falsy result (failed request or empty JSON body) triggers a retry.
        retry_on_result=lambda response: not response,
        stop_max_attempt_number=12,
        wait_fixed=1000,
    )
    def make_request(self, url):
        """GET ``url`` and return its decoded JSON body, or None on failure.

        Failures are logged rather than raised, so the decorator retries them.
        NOTE(review): when the last attempt still returns a falsy value, the
        retrying decorator raises ``retrying.RetryError`` instead of returning
        None. A single persistently failing URL therefore aborts the whole
        run. Confirm this is intended.
        """
        try:
            if self._debug:
                logger.info('Request: %s', url)
            response = requests.get(url, timeout=self._REQUEST_TIMEOUT_SECONDS)
            if response.status_code != requests.codes.ok:
                logger.info('Unable to fetch from url: %s, result: %s', url, response.status_code)
                return None
            return response.json()
        except Exception as e:
            logger.info('Request error %r', e)
            return None

    @staticmethod
    def _format_utc(raw_datetime):
        """Parse a timestamp string and render it with ``date_format``.

        Timezone-aware values are converted to UTC first, so that equal
        instants render identically regardless of the offset used by the
        source. Naive values are rendered as-is.
        """
        parsed = parser.parse(raw_datetime)
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone(pytz.utc)
        return parsed.strftime(date_format)

    def parse_rasp_p2p_result(self, json, station_from, station_to):
        """Convert a rasp search response into ``{flight number: Flight}``.

        Segments without a number are skipped. So are segments whose reported
        departure/arrival station id differs from the requested one.
        """
        if not json or not json.get('result'):
            return {}
        segments = json.get('result').get('segments')
        if not segments:
            return {}
        result = {}
        for segment in segments:
            number = segment.get('number')
            if not number:
                continue
            flight = Flight()
            flight.number = number
            company = segment.get('company')
            if company:
                flight.company_id = company.get('id')
            if segment.get('departure'):
                try:
                    flight.departure_utc = self._format_utc(segment.get('departure'))
                except Exception as e:
                    logger.info('Rasp date conversion error %r', e)
            if segment.get('arrival'):
                try:
                    flight.arrival_utc = self._format_utc(segment.get('arrival'))
                except Exception as e:
                    logger.info('Rasp date conversion error %r', e)
            station_from_json = segment.get('stationFrom')
            if station_from_json:
                flight.station_from = station_from_json.get('id')
                if not flight.station_from or flight.station_from != station_from:
                    continue
            station_to_json = segment.get('stationTo')
            if station_to_json:
                flight.station_to = station_to_json.get('id')
                if not flight.station_to or flight.station_to != station_to:
                    continue
            # TODO: compare aircraft models (segment['transport']['model']['title'],
            # lower-cased, Cyrillic 'а' replaced by Latin 'a') once that is reliable.
            result[flight.number] = flight
        return result

    def parse_sf_p2p_result(self, json):
        """Convert a shared-flights p2p response into ``{flight title: Flight}``.

        Operating flights are keyed by ``title``. Each codeshare gets its own
        entry, copied from the operating flight and marked ``is_codeshare``.
        """
        if not json or not json.get('flights'):
            return {}
        result = {}
        for json_flight in json.get('flights'):
            title = json_flight.get('title')
            # Entries are keyed by ``title``, so it must be present too.
            # Otherwise a flight would be stored under the key None.
            if not json_flight.get('number') or not title:
                continue
            flight = Flight()
            flight.number = title
            if json_flight.get('airlineID'):
                flight.company_id = json_flight.get('airlineID')
            if json_flight.get('departureDatetime'):
                try:
                    flight.departure_utc = self._format_utc(json_flight.get('departureDatetime'))
                except Exception as e:
                    logger.info('Sf date conversion error %r', e)
            if json_flight.get('arrivalDatetime'):
                try:
                    flight.arrival_utc = self._format_utc(json_flight.get('arrivalDatetime'))
                except Exception as e:
                    logger.info('Sf date conversion error %r', e)
            # TODO: use the route endpoints (json_flight['route'][0] / [-1]) as
            # stations once route validation is possible.
            if json_flight.get('departureStation'):
                flight.station_from = json_flight.get('departureStation')
            if json_flight.get('arrivalStation'):
                flight.station_to = json_flight.get('arrivalStation')
            # TODO: compare aircraft types (self._transport_models[int(aircraftType)],
            # lower-cased) once aircraft type validation is possible.
            result[flight.number] = flight
            for elem in json_flight.get('codeshares') or []:
                codeshare_title = elem.get('title')
                if not codeshare_title:
                    continue
                codeshare = Flight()
                codeshare.is_codeshare = True
                codeshare.number = codeshare_title
                codeshare.station_from = flight.station_from
                codeshare.station_to = flight.station_to
                codeshare.transport_model = flight.transport_model
                codeshare.departure_utc = flight.departure_utc
                codeshare.arrival_utc = flight.arrival_utc
                # A codeshare must never replace an operating flight that
                # shares its title, whatever the order of the response.
                result.setdefault(codeshare.number, codeshare)
        return result

    def compare_responses(self, rasp_result, sf_result):
        """Count matching and mismatching flights between the two parsed responses.

        Flights are matched by number/title. Matched flights are equal when
        their ``repr`` is identical or when the sf side is a codeshare.
        Unmatched sf codeshares are not counted. Neither input is modified.
        """
        cmp_result = ComparisonResult()
        remaining_sf = dict(sf_result)
        for rasp_flight_title, rasp_flight in six.iteritems(rasp_result):
            sf_flight = remaining_sf.pop(rasp_flight_title, None)
            if sf_flight is None:
                cmp_result.rasp_but_not_sf += 1
                if self._debug:
                    logger.info('Rasp but not sf: %s', rasp_flight)
                continue
            if sf_flight.is_codeshare or str(rasp_flight) == str(sf_flight):
                cmp_result.same_flights += 1
            else:
                cmp_result.unequal_flights += 1
                if self._debug:
                    logger.info('Unequal flight: \n%s != \n%s', rasp_flight, sf_flight)
        operating_sf = {key: value for (key, value) in six.iteritems(remaining_sf) if not value.is_codeshare}
        cmp_result.sf_but_not_rasp = len(operating_sf)
        if self._debug and operating_sf:
            logger.info('Sf but not rasp: %s', operating_sf)
        return cmp_result

    def fetch_transport_models_to_dict(self, cursor):
        """Return ``{transport_model.id: transport_model.title_en}``, read through ``cursor``."""
        logger.info('Start fetching transport models')
        cursor.execute(
            '''
            select
                id,
                title_en
            from
                transport_model
            ''')
        transport_models = {row[0]: row[1] for row in cursor}
        logger.info('Done fetching transport models')
        return transport_models

    def fetch_stations_to_dict(self, cursor):
        """Return ``{iata: station id}`` for stations with a non-empty IATA code, read through ``cursor``."""
        logger.info('Start fetching stations')
        cursor.execute(
            '''
            select
                id,
                iata
            from
                station_with_codes
            where
                iata != ''
            ''')
        stations = {row[1]: row[0] for row in cursor}
        logger.info('Done fetching stations')
        return stations

    @staticmethod
    def _quote_conninfo_value(value):
        """Quote a libpq keyword/value parameter, escaping backslashes and single quotes."""
        escaped = '{}'.format(value).replace('\\', '\\\\').replace("'", "\\'")
        return "'{}'".format(escaped)

    def get_psycopg2_conn_string(self):
        """Build the libpq connection string for the PGaaS cluster.

        Values are quoted, so passwords containing spaces or quotes survive.
        The password is omitted when unset, leaving libpq to its defaults
        instead of sending the literal text "None".
        NOTE(review): SSL_ROOT_CERT is defined but not used here, and
        sslmode=require does not verify the server certificate. Confirm this
        is intended.
        """
        params = [
            ('dbname', PGAAS_DATABASE_NAME),
            ('user', PGAAS_USER),
            ('host', 'c-{cluster_id}.rw.db.yandex.net'.format(cluster_id=PGAAS_CLUSTER_ID)),
            ('port', PGAAS_PORT),
            ('sslmode', 'require'),
        ]
        if self._pg_password:
            params.append(('password', self._pg_password))
        return ' '.join('{}={}'.format(key, self._quote_conninfo_value(value)) for key, value in params)
