# coding=utf-8
import io
import logging
import os
import psycopg2
import pytz
from datetime import date, datetime, timedelta
from typing import IO, Optional

from google.protobuf import text_format

from travel.library.python.dicts.file_util import write_binary_string
from travel.proto.shared_flights.flying_time.flying_time_pb2 import TFlyingTime


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# strftime pattern used to render flight dates for the per-day cache keys and log messages.
DATE_FORMAT = '%Y-%m-%d'
# Value of ``output_mode`` that selects text-format protobuf output; any other value yields binary output.
TEXT_OUTPUT_MODE = 'text'

# Names of the environment variables read when building the PostgreSQL connection string.
# All but the password one have a built-in default in ``get_psycopg2_conn_string``.
ENV_PGAAS_CLUSTER_ID = 'AVIA_PGAAS_CLUSTER_ID'
ENV_PGAAS_DATABASE_NAME = 'AVIA_PGAAS_DATABASE_NAME'
ENV_PGAAS_USER = 'AVIA_PGAAS_USER'
ENV_PGAAS_PASSWORD = 'AVIA_PGAAS_PASSWORD'
ENV_PGAAS_PORT = 'AVIA_PGAAS_PORT'


class FlyingTimeTool(object):
    """Calculates minimum flight time between any two settlements for any date"""

    def __init__(
        self,
        environment: str,
        output_mode: str,
        output_file: str,
        records_limit: int,
    ):
        """Stores the run configuration and prepares empty caches.

        :param environment: "testing"/"production"
        :param output_mode: TEXT_OUTPUT_MODE for text output (debugging only), anything else means binary
        :param output_file: path of a file to write to; intended for local debug runs
        :param records_limit: maximum number of database rows to process; a falsy value means no limit
        """
        # Run configuration, kept exactly as passed in.
        self._output_file = output_file
        self._output_mode = output_mode
        self._records_limit = records_limit
        self._environment = environment

        # Only flights inside this window (32 days back, 330 days ahead) are expanded into dates.
        started_at = datetime.now()
        look_back, look_ahead = timedelta(days=32), timedelta(days=330)
        self._min_flight_date = (started_at - look_back).date()
        self._max_flight_date = (started_at + look_ahead).date()

        # Counter of rows rejected by validation, and the two minimum-flying-time caches:
        # (departure, arrival) -> minutes and (departure, arrival, 'YYYY-MM-DD') -> minutes.
        self._skipped_invalid_records = 0
        self._flying_time_min = {}
        self._flying_time_per_day = {}

    def run(self) -> Optional[IO]:
        """Fetches flight records from the database and outputs the collected flying times.

        :return: whatever ``output_flying_times`` returns: an in-memory file, or None when
            the result was written to the configured output file
        """
        logger.info('Started processing flying times')

        conn = psycopg2.connect(self.get_psycopg2_conn_string())
        try:
            self.fetch_records(conn)
        finally:
            # The connection is only needed while fetching; release it even if fetching fails.
            conn.close()
        return self.output_flying_times()

    def process_flight_record(
        self,
        operating_from: date,
        operating_until: date,
        operating_on_days: int,
        arrival_day_shift: int,
        departure_settlement: int,
        arrival_settlement: int,
        departure_time: int,
        arrival_time: int,
        departure_tz_name: str,
        arrival_tz_name: str,
    ) -> None:
        """Processes single record from the database into the local cache.

        Records with missing dates, missing or unknown timezones, or malformed times are
        skipped and counted in ``self._skipped_invalid_records``.

        :param operating_from: the beginning of the dates range for this flight pattern
        :param operating_until: the end of the dates range for this flight pattern
        :param operating_on_days: weekdays this flight operates within the specified dates range,
            as an integer whose decimal digits are the weekday numbers (1 = Monday ... 7 = Sunday)
        :param arrival_day_shift: 1 for overnight flights, 0 otherwise
        :param departure_settlement: departure settlement ID
        :param arrival_settlement: arrival settlement ID
        :param departure_time: local departure time as HHMM integer
        :param arrival_time: local arrival time as HHMM integer
        :param departure_tz_name: pytz name for the departure timezone
        :param arrival_tz_name: pytz name for the arrival timezone
        """
        if (
            not operating_from
            or not operating_until
            or not departure_tz_name
            or not arrival_tz_name
            or not self.valid_time(departure_time)
            or not self.valid_time(arrival_time)
        ):
            self._skipped_invalid_records += 1
            return

        try:
            departure_tz = pytz.timezone(departure_tz_name)
            arrival_tz = pytz.timezone(arrival_tz_name)
        except pytz.UnknownTimeZoneError:
            # One bad timezone code must not abort the whole run; treat it like any other invalid record.
            logger.warning(
                'Unknown timezone in flight record: departure %s, arrival %s',
                departure_tz_name,
                arrival_tz_name,
            )
            self._skipped_invalid_records += 1
            return

        operating_days = str(operating_on_days)
        first_day = self.get_min_flight_date(operating_from)
        last_day = self.get_max_flight_date(operating_until)
        # Both ends are inclusive; an empty or inverted range yields no iterations.
        for day_offset in range((last_day - first_day).days + 1):
            flight_date = first_day + timedelta(days=day_offset)
            if str(flight_date.isoweekday()) not in operating_days:
                continue
            flying_time = self.calculate_flying_time(
                flight_date,
                arrival_day_shift,
                departure_time,
                arrival_time,
                departure_tz,
                arrival_tz,
            )
            self.update_flying_time(departure_settlement, arrival_settlement, flight_date, flying_time)

    def update_flying_time(
        self,
        departure_settlement: int,
        arrival_settlement: int,
        flight_date: date,
        flying_time: int,
    ) -> None:
        """Refreshes the minimum flying time between the two given settlements in the local cache.

        Both the overall minimum for the settlement pair and the minimum for the pair on the
        given date are updated. Non-positive flying times are logged and ignored.

        :param departure_settlement: departure settlement ID
        :param arrival_settlement: arrival settlement ID
        :param flight_date: local departure date
        :param flying_time: the new flying time value (minutes) to compare with the current minimum
        """
        flight_date_str = flight_date.strftime(DATE_FORMAT)
        if flying_time <= 0:
            logger.warning(
                'Invalid flight time between settlements %d and %d on %s: %d',
                departure_settlement,
                arrival_settlement,
                flight_date_str,
                flying_time,
            )
            return

        route = (departure_settlement, arrival_settlement)
        targets = (
            (self._flying_time_min, route),
            (self._flying_time_per_day, route + (flight_date_str,)),
        )
        for cache, key in targets:
            known_minimum = cache.get(key)
            # ``None`` means nothing is cached for this key yet.
            if known_minimum is None or flying_time < known_minimum:
                cache[key] = flying_time

    def get_min_flight_date(self, operating_from: date) -> date:
        """Clamps the start of a flight pattern to the beginning of the processing window.

        :param operating_from: first date of the flight pattern, may be empty
        :return: the later of operating_from and self._min_flight_date
        """
        window_start = self._min_flight_date
        if operating_from and operating_from >= window_start:
            return operating_from
        return window_start

    def get_max_flight_date(self, operating_until: date) -> date:
        """Clamps the end of a flight pattern to the end of the processing window.

        This protects us from flight patterns that last until 2049.

        :param operating_until: last date of the flight pattern, may be empty
        :return: the earlier of operating_until and self._max_flight_date
        """
        window_end = self._max_flight_date
        if operating_until and operating_until <= window_end:
            return operating_until
        return window_end

    @staticmethod
    def valid_time(scheduled_time: int) -> bool:
        """Validates the time value in the HHMM form"""
        return scheduled_time >= 0 and scheduled_time < 2400 and scheduled_time % 100 < 60

    def calculate_flying_time(
        self,
        flight_date: datetime,
        arrival_day_shift: int,
        departure_time: int,
        arrival_time: int,
        departure_tz: pytz.timezone,
        arrival_tz: pytz.timezone,
    ) -> int:
        """Calculates the difference between arrival and departure in whole minutes.

        :param flight_date: local departure date
        :param arrival_day_shift: 1 for overnight flights, 0 otherwise
        :param departure_time: local departure time as HHMM integer
        :param arrival_time: local arrival time as HHMM integer
        :param departure_tz: pytz location for the departure timezone
        :param arrival_tz: pytz location for the arrival timezone
        :return: minutes between the two timezone-aware moments, rounded down
        """
        # Departure always happens on the flight date itself; only arrival may be shifted.
        departed_at = self.parse_time(flight_date, 0, departure_time, departure_tz)
        arrived_at = self.parse_time(flight_date, arrival_day_shift, arrival_time, arrival_tz)
        elapsed = arrived_at - departed_at
        whole_minutes = elapsed.total_seconds() // 60
        return int(whole_minutes)

    def parse_time(
        self,
        flight_date: datetime,
        days_shift: int,
        day_time: int,
        location_tz: pytz.timezone,
    ) -> datetime:
        """Builds a timezone-aware datetime from a date, a day shift and an HHMM time.

        :param flight_date: local departure date
        :param days_shift: 1 for overnight flights, 0 otherwise
        :param day_time: local time as HHMM integer
        :param location_tz: pytz location for the departure or arrival timezone
        :return: the resulting moment localized into location_tz
        """
        # An empty shift (0 or None) leaves the date untouched.
        target_day = flight_date + timedelta(days=days_shift) if days_shift else flight_date
        # HHMM splits into hours and minutes; seconds and microseconds stay at zero.
        hours, minutes = divmod(day_time, 100)
        naive_moment = datetime(target_day.year, target_day.month, target_day.day, hours, minutes)
        return location_tz.localize(naive_moment)

    def output_flying_times(self) -> Optional[IO]:
        """Outputs local flying times cache into a file or into an in-memory file.

        When an output file is configured, the data is written there and None is returned.
        Otherwise an in-memory file holding the data is returned: a text buffer in text
        (debug) mode, a bytes buffer in binary mode.

        :return: None if an output file was written, otherwise the filled in-memory file
        """
        text_mode = self._output_mode == TEXT_OUTPUT_MODE
        mode = 'wt' if text_mode else 'wb'

        if self._output_file:
            with open(self._output_file, mode) as output:
                self._output_flying_times(output, mode)
            return None

        # The buffer type has to match the payload: text-format messages are written as ``str``,
        # serialized messages as ``bytes``. Writing ``str`` into BytesIO raises TypeError.
        mem_file = io.StringIO() if text_mode else io.BytesIO()
        self._output_flying_times(mem_file, mode)
        return mem_file

    def _output_flying_times(self, output: IO, mode: str) -> None:
        """Writes every cached flying time into the given output as TFlyingTime messages.

        Per-day minimums are written first, followed by the overall minimums per settlement
        pair, which carry an empty FlightDate.

        :param output: writable file-like object receiving the records
        :param mode: file mode the output was opened with (not used by this method)
        """
        logger.info('Start saving flying times into an output file: %s', self._output_file)
        logger.info('Skipped invalid records: %s', self._skipped_invalid_records)

        as_text = self._output_mode == TEXT_OUTPUT_MODE

        def emit(departure_id, arrival_id, flight_date, minutes):
            # Builds one message and writes it in the configured representation.
            entry = TFlyingTime()
            entry.DepartureSettlementId = departure_id
            entry.ArrivalSettlementId = arrival_id
            entry.FlightDate = flight_date
            entry.MinFlyingTime = minutes
            if as_text:
                output.write(text_format.MessageToString(entry, as_utf8=True))
                output.write('\n')
            else:
                write_binary_string(output, entry.SerializeToString())

        for per_day_key, minutes in self._flying_time_per_day.items():
            emit(per_day_key[0], per_day_key[1], per_day_key[2], minutes)

        for route_key, minutes in self._flying_time_min.items():
            emit(route_key[0], route_key[1], '', minutes)

        logger.info('Saved min flying time records: %d', len(self._flying_time_min))
        logger.info('Saved min flying time per date records: %d', len(self._flying_time_per_day))
        logger.info('Done saving flying times into an output file')

    def fetch_records(self, conn) -> None:
        """Reads non-codeshare flight patterns from the database and feeds them into the local cache.

        Rows come from three sources (the regular, ``sirena_`` and ``apm_`` flight tables) combined
        with ``union``. Every row is passed to ``process_flight_record``. When
        ``self._records_limit`` is set, at most that many rows are processed.

        :param conn: open psycopg2 connection to the flights database
        """
        logger.info('Start fetching flight records from the database')

        # The three sources share one schema and differ only in the table-name prefix.
        # Column order must match the positional parameters of ``process_flight_record``.
        select_template = '''
            select
                fp.operating_from,
                fp.operating_until,
                fp.operating_on_days,
                fp.arrival_day_shift,
                stdep.settlement_id,
                starr.settlement_id,
                fb.scheduled_departure_time,
                fb.scheduled_arrival_time,
                tzdep.code,
                tzarr.code
            from
                {prefix}flight_pattern fp
            inner join
                {prefix}flight_base fb
                on fp.flight_base_id = fb.id
            inner join
                station_with_codes stdep
                on fb.departure_station = stdep.id
            inner join
                timezone tzdep
                on stdep.time_zone_id=tzdep.id
            inner join
                station_with_codes starr
                on fb.arrival_station = starr.id
            inner join
                timezone tzarr
                on starr.time_zone_id=tzarr.id
            where not fp.is_codeshare
        '''
        table_prefixes = ('', 'sirena_', 'apm_')
        query = 'union'.join(select_template.format(prefix=prefix) for prefix in table_prefixes)

        records_count = 0
        with conn.cursor(name='flight_records') as cursor:
            cursor.execute(query)
            for row in cursor:
                # Check the limit before counting, so the logged total equals the rows actually processed.
                if self._records_limit and records_count >= self._records_limit:
                    break
                records_count += 1
                self.process_flight_record(*row)

        logger.info('Flight records processed: %d', records_count)
        logger.info('Done fetching flight records from the database')

    def get_psycopg2_conn_string(self) -> str:
        """Builds the libpq keyword/value connection string from environment variables.

        Cluster id, database name, user and port fall back to built-in defaults when the
        corresponding variable is not set. The password has no default: when its variable
        is not set, the ``password`` keyword is omitted instead of sending the literal text "None".

        :return: space-separated ``keyword=value`` connection string
        """

        def quote(value) -> str:
            # Plain values are emitted as-is. Empty values and values containing whitespace,
            # single quotes or backslashes need libpq quoting.
            text = str(value)
            if text and not any(ch.isspace() or ch in "'\\" for ch in text):
                return text
            return "'{}'".format(text.replace('\\', '\\\\').replace("'", "\\'"))

        pgaas_cluster_id = os.getenv(ENV_PGAAS_CLUSTER_ID, 'mdb9sssbmtcje8gtvlrc')
        params = [
            ('dbname', os.getenv(ENV_PGAAS_DATABASE_NAME, 'shared-flights')),
            ('user', os.getenv(ENV_PGAAS_USER, 'avia')),
            ('host', 'c-{cluster_id}.rw.db.yandex.net'.format(cluster_id=pgaas_cluster_id)),
            ('port', os.getenv(ENV_PGAAS_PORT, '6432')),
            ('sslmode', 'require'),
        ]
        pgaas_password = os.getenv(ENV_PGAAS_PASSWORD)
        if pgaas_password is not None:
            params.append(('password', pgaas_password))

        return ' '.join('{}={}'.format(keyword, quote(value)) for keyword, value in params)
