# coding=utf-8
from __future__ import unicode_literals

import io
import logging
from datetime import datetime
from zipfile import ZIP_DEFLATED, ZipFile

import psycopg2
from enum import Enum
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from travel.proto.dicts.rasp.transport_model_pb2 import TTransportModel

from travel.avia.library.python.backend_client import BackendClient
from travel.avia.shared_flights.diff_builder.factory import make_mds_client
from travel.avia.shared_flights.diff_builder.fetchers import (
    TimezoneFetcher,
    DesignatedCarrierFetcher,
    CarrierFetcher,
    PopularScoreFetcher,
    IATACorrectionRuleFetcher,
    StationFetcher,
    StationStatusSourceFetcher,
    FlightStatusSourceFetcher,
    FlightStatusFetcher,
    OverrideFetcher,
    BlackListRuleFetcher,
    FlightMergeRuleFetcher,
    LastImportedFetcher,
    fetch_stations,
)
from travel.avia.shared_flights.diff_builder.flights_builder import FlightsBuilder
from travel.avia.shared_flights.diff_builder.p2p_cache import P2PCache
from travel.avia.shared_flights.diff_builder.utils import ensure_string
from travel.avia.shared_flights.diff_builder.utils import write_binary_string
from travel.avia.shared_flights.lib.python.db.engine import conn_string
from travel.avia.shared_flights.lib.python.db_locks.db_lock_handle import DbLockHandle, DbLockHandleException
from travel.avia.shared_flights.lib.python.db_models.db_lock import DbLockType
from travel.avia.shared_flights.lib.python.settings import (
    PGAAS_CLUSTER_ID,
    PGAAS_DATABASE_NAME,
    PGAAS_USER,
    PGAAS_PASSWORD,
    PGAAS_PORT,
    MDS_ACCESS_KEY_ID,
    MDS_ACCESS_KEY_SECRET,
    MDS_FOLDER,
    BACKEND_HOST,
)

# Value of `last_imported_info.resource_type` for the row that holds the current
# snapshot base key (written by DiffBuilder.update_base_key, read by fetch_base_key).
AVIA_SHARED_FLIGHTS_SNAPSHOT_KEY = 'AVIA_SHARED_FLIGHTS_SNAPSHOT_KEY'


class Mode(Enum):
    """Which archive `DiffBuilder.build_diff` should produce."""

    # Full snapshot archive, stored under a per-day folder.
    snapshot = 'snapshot'
    # Delta archive, stored next to the current base snapshot.
    delta = 'delta'


class DiffBuilder(object):
    """Build shared-flights snapshot and delta zip archives and upload them to MDS.

    Snapshots are stored as ``snapshots/<MDS_FOLDER>/<YYYY-MM-DD>/snapshot.zip``, and their date
    is recorded as the base key in ``last_imported_info``. Deltas are stored next to the current
    base snapshot as ``delta-<HH-MM-SS>.zip`` and announced through ``snapshots/<MDS_FOLDER>/latest.txt``.
    """

    # MDS bucket that receives snapshots, deltas and latest.txt.
    MDS_BUCKET = 'avia-shared-flights-dev'

    def __init__(self, logger: logging.Logger, solomon_push_reporter):
        """
        :param logger: logger used for progress and error messages.
        :param solomon_push_reporter: optional; when truthy, ``set_value`` is called on it with a
            success sensor after each completed build.
        """
        self.logger = logger
        self.solomon_push_reporter = solomon_push_reporter

    def build_diff(self, mode: Mode, output_file_path: str = None, should_force: bool = False) -> None:
        """Build and upload either a snapshot or a delta.

        :param mode: ``Mode.snapshot`` or ``Mode.delta``; any other value only logs a usage hint.
        :param output_file_path: snapshot mode only; if set, the archive is also written to this path.
        :param should_force: forwarded to ``DbLockHandle``.

        A ``DbLockHandleException`` raised by the build is logged as a warning and the build is skipped.
        """
        self.logger.info('Building diff')

        # This engine only backs the DB-lock sessions; the data itself is read through
        # separate psycopg2 connections opened in build_snapshot / build_delta.
        db_engine = create_engine(DiffBuilder.get_conn_string_for_work(), echo=False)
        try:
            session_factory = sessionmaker(bind=db_engine)
            mds_client = make_mds_client(key_id=MDS_ACCESS_KEY_ID, key_secret=MDS_ACCESS_KEY_SECRET)

            if mode is Mode.snapshot:
                # Only the snapshot needs the backend (popular scores).
                backend_client = BackendClient(BACKEND_HOST)
                try:
                    self.build_snapshot(session_factory, mds_client, output_file_path, should_force, backend_client)
                except DbLockHandleException as e:
                    self.logger.warning('Unable to obtain snapshot lock, building snapshot is skipped: %s', e)
                return

            if mode is Mode.delta:
                try:
                    self.build_delta(session_factory, mds_client, should_force)
                except DbLockHandleException as e:
                    self.logger.warning('Unable to obtain delta lock, building delta is skipped: %s', e)
                return

            self.logger.info('Please specify --snapshot or --delta')
        finally:
            # Release pooled connections created for the lock sessions.
            db_engine.dispose()

    @staticmethod
    def _write_fetchers_to_zip(zip_file, name_to_fetcher):
        """Write every record produced by each fetcher into its own zip entry.

        :param zip_file: an open, writable ``ZipFile``.
        :param name_to_fetcher: mapping of zip entry name to an object whose ``fetch()`` returns an
            iterable of records accepted by ``write_binary_string``.
        """
        for name, fetcher in name_to_fetcher.items():
            mem_entry = io.BytesIO()
            for data in fetcher.fetch():
                write_binary_string(mem_entry, data)
            zip_file.writestr(name, mem_entry.getvalue())

    def _create_snapshot_file(self, conn, backend_client: BackendClient) -> io.BytesIO:
        """Read all snapshot data and pack it into an in-memory zip archive.

        :param conn: open psycopg2 connection used by the fetchers and flight builders.
        :param backend_client: source of the popular scores entry.
        :return: a BytesIO holding the complete zip archive.
        """
        mem_file = io.BytesIO()
        with ZipFile(mem_file, mode='w', compression=ZIP_DEFLATED) as zip_file:
            name_to_fetcher = {
                'designated_carriers.pb2.bin': DesignatedCarrierFetcher(conn=conn, logger=self.logger),
                'carriers.pb2.bin': CarrierFetcher(conn=conn, logger=self.logger),
                'popular_scores.pb2.bin': PopularScoreFetcher(backend_client=backend_client, logger=self.logger),
                'iata_correction_rules.pb2.bin': IATACorrectionRuleFetcher(conn=conn, logger=self.logger),
                'timezones.pb2.bin': TimezoneFetcher(conn=conn, logger=self.logger),
                'stations_with_codes.pb2.bin': StationFetcher(conn=conn, logger=self.logger),
                'last_imported.pb2.bin': LastImportedFetcher(conn=conn, logger=self.logger),
            }
            self._write_fetchers_to_zip(zip_file, name_to_fetcher)

            # Each large intermediate buffer is dropped as soon as it has been written to the
            # archive, so that only one of them is alive at a time.
            with conn.cursor() as cursor:
                mem_transport_models_entry, transport_models_dict = self.fetch_transport_models_to_mem_file(cursor)
            zip_file.writestr("transport_models.pb2.bin", mem_transport_models_entry.getvalue())
            del mem_transport_models_entry

            flights_builder = FlightsBuilder(self.logger)
            mem_fb_entry, flight_bases_ids, flight_bases_stations = flights_builder.fetch_flight_bases(
                conn, transport_models_dict
            )
            zip_file.writestr("flight_bases.pb2.bin", mem_fb_entry.getvalue())
            del mem_fb_entry

            # The P2P cache is filled while flight patterns are fetched and written afterwards.
            p2p_cache = P2PCache(self.logger)
            mem_flight_patterns_entry = flights_builder.fetch_flight_patterns(
                conn,
                flight_bases_ids,
                flight_bases_stations,
                p2p_cache,
            )
            zip_file.writestr("flight_patterns.pb2.bin", mem_flight_patterns_entry.getvalue())
            del mem_flight_patterns_entry

            mem_p2p_cache_entry = p2p_cache.write_to_mem_file()
            zip_file.writestr("p2p_cache.pb2.bin", mem_p2p_cache_entry.getvalue())
            del mem_p2p_cache_entry
            del p2p_cache
        return mem_file

    def build_snapshot(
        self,
        session_factory,
        mds_client,
        output_file_path,
        should_force,
        backend_client: BackendClient,
    ):
        """Build the snapshot archive, upload it and record its date as the new base key.

        The archive is uploaded to ``snapshots/<MDS_FOLDER>/<YYYY-MM-DD>/snapshot.zip`` and, if
        *output_file_path* is set, also written to that local file. The success sensor is sent only
        after the base key has been committed.
        """
        self.logger.info('Building snapshot')
        # Guard against a concurrent snapshot build via a DB lock.
        timeout = 5400  # seconds
        retry_timeout = 30  # seconds
        with DbLockHandle(
            session_factory(), DbLockType.DBLOCK_BUILD_SNAPSHOT, self.logger, should_force, timeout, retry_timeout
        ):
            self.logger.info('Connecting to postgresql')
            conn = psycopg2.connect(DiffBuilder.get_psycopg2_conn_string())
            try:
                # TODO(u-jeen): verify if the snapshot exists already, don't rebuild unless --force
                self.logger.info('Start creating zip archive')
                mem_file = self._create_snapshot_file(conn, backend_client=backend_client)
                self.logger.info('Done creating zip archive')
                # Take the archive bytes once and drop the BytesIO to avoid keeping two copies.
                snapshot_bytes = mem_file.getvalue()
                del mem_file

                base_key = datetime.today().strftime('%Y-%m-%d')
                snapshot_key = 'snapshots/{}/{}/snapshot.zip'.format(MDS_FOLDER, base_key)
                self.logger.info('Start putting snapshot to %s', snapshot_key)
                mds_client.put_object(
                    Bucket=self.MDS_BUCKET,
                    Key=snapshot_key,
                    Body=snapshot_bytes,
                )

                if output_file_path:
                    with open(output_file_path, 'wb') as output_file:
                        output_file.write(snapshot_bytes)

                # Cache the new snapshot key in the database so deltas are stored next to it.
                with conn.cursor() as cursor:
                    self.update_base_key(cursor, base_key)
                conn.commit()
            finally:
                conn.close()

            if self.solomon_push_reporter:
                self.solomon_push_reporter.set_value('diff_builder.build_snapshot.success', 1)
                self.logger.info('Sent build snapshot success sensor')

        self.logger.info('Done building snapshot')

    def _create_delta_file(self, conn) -> io.BytesIO:
        """Read all delta data through *conn* and pack it into an in-memory zip archive."""
        stations = fetch_stations(conn.cursor(), self.logger)
        name_to_fetcher = {
            'station_status_source.pb2.bin': StationStatusSourceFetcher(conn=conn, logger=self.logger),
            'status_sources.pb2.bin': FlightStatusSourceFetcher(conn=conn, logger=self.logger),
            'flight_statuses.pb2.bin': FlightStatusFetcher(conn=conn, stations=stations, logger=self.logger),
            'overrides.pb2.bin': OverrideFetcher(conn=conn, logger=self.logger),
            'blacklist_rules.pb2.bin': BlackListRuleFetcher(conn=conn, stations=stations, logger=self.logger),
            'flight_merge_rules.pb2.bin': FlightMergeRuleFetcher(conn=conn, logger=self.logger),
        }

        mem_file = io.BytesIO()
        with ZipFile(mem_file, mode='w', compression=ZIP_DEFLATED) as zip_file:
            self._write_fetchers_to_zip(zip_file, name_to_fetcher)
        return mem_file

    def build_delta(self, session_factory, mds_client, should_force):
        """Build the delta archive, upload it next to the base snapshot and point latest.txt at it.

        If the database holds no valid base snapshot key, an error is logged and nothing is uploaded.
        """
        self.logger.info('Building delta')
        # Guard against a concurrent delta build via a DB lock.
        timeout = 180  # seconds
        retry_timeout = 30  # seconds
        with DbLockHandle(
            session_factory(), DbLockType.DBLOCK_BUILD_DELTA, self.logger, should_force, timeout, retry_timeout
        ):
            self.logger.info('Connecting to postgresql')
            conn = psycopg2.connect(DiffBuilder.get_psycopg2_conn_string())
            try:
                self.logger.info('Start creating zip archive')
                mem_file = self._create_delta_file(conn)
                self.logger.info('Done creating zip archive')

                now = datetime.now()
                with conn.cursor() as cursor:
                    base_key = self.fetch_base_key(cursor)
            finally:
                # Everything below only talks to MDS, so the DB connection can be released here.
                conn.close()

            if not base_key:
                self.logger.error('Unable to write delta: no base snapshot key in the database')
                return

            delta_key = now.strftime('%H-%M-%S')
            snapshot_key = 'snapshots/{}/{}/delta-{}.zip'.format(MDS_FOLDER, base_key, delta_key)
            self.logger.info('Start putting delta to %s', snapshot_key)
            mds_client.put_object(
                Bucket=self.MDS_BUCKET,
                Key=snapshot_key,
                Body=mem_file.getvalue(),
            )

            # Set the new delta as the latest.
            self.logger.info('Start resetting the latest.txt')
            latest_entry = io.BytesIO()
            write_binary_string(latest_entry, '\nbase={} delta={}'.format(base_key, delta_key).encode('utf-8'))
            mds_client.put_object(
                Bucket=self.MDS_BUCKET,
                Key='snapshots/{}/latest.txt'.format(MDS_FOLDER),
                Body=latest_entry.getvalue(),
            )
            if self.solomon_push_reporter:
                self.solomon_push_reporter.set_value('diff_builder.build_delta.success', 1)
                self.logger.info('Sent build delta success sensor')

            self.logger.info('Done resetting the latest.txt')

        self.logger.info('Done building delta')

    def fetch_transport_models_to_mem_file(self, cursor):
        """Serialize every ``transport_model`` row and build a code -> id lookup.

        :param cursor: psycopg2 cursor used to run the query.
        :return: ``(mem_entry, transport_models)``. ``mem_entry`` is a BytesIO holding one
            ``TTransportModel`` record per row. ``transport_models`` maps both ``Code`` and ``CodeEn``
            values (in one shared namespace) to the model id. When a code is already mapped to a
            different id, an error is logged and the existing mapping is kept.
        """
        self.logger.info('Start fetching transport models')
        transport_models = {}
        mem_entry = io.BytesIO()
        cursor.execute(
            '''
            select
                id,
                title,
                title_en,
                code,
                code_en,
                cargo,
                transport_subtype_id,
                propeller_flight,
                plane_body_type
            from
                transport_model
            order by
                id;
            '''
        )
        for row in cursor:
            transport_model = TTransportModel()
            transport_model.Id = row[0]
            transport_model.Title = ensure_string(row[1])
            transport_model.TitleEn = ensure_string(row[2])
            transport_model.Code = ensure_string(row[3])
            transport_model.CodeEn = ensure_string(row[4])
            # NULL columns map to protobuf defaults.
            transport_model.Cargo = row[5] or False
            transport_model.TransportSubtypeId = row[6] or 0
            transport_model.PropellerFlight = row[7] or False
            transport_model.PlaneBodyType = ensure_string(row[8])
            write_binary_string(mem_entry, transport_model.SerializeToString())

            if transport_model.Code:
                existing_id_by_code = transport_models.get(transport_model.Code)
                if existing_id_by_code and existing_id_by_code != transport_model.Id:
                    self.logger.error(
                        'Error: transport model code %s is mapped to multiple ids %s, %s',
                        transport_model.Code,
                        existing_id_by_code,
                        transport_model.Id,
                    )
                    # NOTE(review): this also skips registering this model's CodeEn; confirm intended.
                    continue
                transport_models[transport_model.Code] = transport_model.Id
            if transport_model.CodeEn:
                existing_id_by_code_en = transport_models.get(transport_model.CodeEn)
                if existing_id_by_code_en and existing_id_by_code_en != transport_model.Id:
                    self.logger.error(
                        'Error: transport model code_en %s is mapped to multiple ids %s, %s',
                        transport_model.CodeEn,
                        existing_id_by_code_en,
                        transport_model.Id,
                    )
                    continue
                transport_models[transport_model.CodeEn] = transport_model.Id

        self.logger.info('Done fetching transport models')
        return mem_entry, transport_models

    def update_base_key(self, cursor, new_base_key):
        """Store *new_base_key* (``YYYY-MM-DD``, saved without dashes) as the snapshot base key.

        Only updates an existing ``last_imported_info`` row and does not commit; the caller owns
        the transaction. Logs a warning when no row was updated.
        """
        self.logger.info('Start updating snapshot base key in the database')
        cursor.execute(
            '''
            update last_imported_info
                set imported_resource_id = %s,
                    updated_at = CURRENT_TIMESTAMP
                where resource_type = %s
            ''',
            (new_base_key.replace('-', ''), AVIA_SHARED_FLIGHTS_SNAPSHOT_KEY),
        )
        if cursor.rowcount == 0:
            # Without a matching row the new key is lost and fetch_base_key keeps
            # returning the previous value (or None).
            self.logger.warning(
                'No last_imported_info row with resource_type=%s, snapshot base key %s was not stored',
                AVIA_SHARED_FLIGHTS_SNAPSHOT_KEY,
                new_base_key,
            )

        self.logger.info('Done updating snapshot base key in the database')

    def fetch_base_key(self, cursor):
        """Return the stored snapshot base key as ``YYYY-MM-DD``.

        The key is stored as eight characters (``YYYYMMDD``, as written by update_base_key).
        Returns None when there is no row, the value is NULL, or it does not have eight characters.
        """
        self.logger.info('Fetching snapshot key')
        cursor.execute(
            '''
            select
                imported_resource_id
            from
                last_imported_info
            where resource_type = %s
            limit 1
            ''',
            (AVIA_SHARED_FLIGHTS_SNAPSHOT_KEY,),
        )
        row = cursor.fetchone()
        # Previously a missing row left base_key as None and len(None) raised TypeError.
        base_key = str(row[0]) if row is not None else None
        if base_key is None or len(base_key) != 8:
            return None

        self.logger.info('Done fetching snapshot key')
        return base_key[0:4] + '-' + base_key[4:6] + '-' + base_key[6:8]

    @staticmethod
    def get_conn_string_for_work():
        """Return the SQLAlchemy connection string used for the lock sessions."""
        return conn_string(
            user=PGAAS_USER,
            password=PGAAS_PASSWORD,
            database=PGAAS_DATABASE_NAME,
            cluster_id=PGAAS_CLUSTER_ID,
            port=PGAAS_PORT,
        )

    @staticmethod
    def get_psycopg2_conn_string():
        """Return a libpq connection string for the read-write host of the PGaaS cluster.

        ``make_dsn`` quotes and escapes each value, so passwords containing spaces, quotes or
        backslashes cannot corrupt the connection string. Parameters that are None are omitted.
        """
        return psycopg2.extensions.make_dsn(
            dbname=PGAAS_DATABASE_NAME,
            user=PGAAS_USER,
            host=f'c-{PGAAS_CLUSTER_ID}.rw.db.yandex.net',
            port=PGAAS_PORT,
            sslmode='require',
            password=PGAAS_PASSWORD,
        )
