# coding: utf-8
import contextlib
import json
import logging
import os
import tempfile
import urllib
import zlib
from collections import defaultdict

from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment

from sandbox.projects.avia.base import AviaBaseTask


class FlightExtrasClient(object):
    """
    Client for sending aircraft models to Flight Extras.

    Records are buffered in memory and grouped by (company code, flight number).
    The buffer is sent automatically once it holds ``max_capacity`` records;
    ``flush`` must be called after the last ``add`` to send the remainder.
    """
    def __init__(self, base_url, source, logger, max_capacity=10000):
        # type: (str, str, logging.Logger, int) -> None
        """
        :param base_url: Flight Extras API root; requests go to
            ``{base_url}/flight/{code}/{number}/aircraft/``.
        :param source: value sent in the ``source`` field of every request.
        :param logger: object with ``info``/``debug`` methods used for progress logs.
        :param max_capacity: number of buffered records that triggers an automatic flush.
        """
        # Imported lazily: requests is installed by the task's PipEnvironment at run time.
        from requests import Session
        from requests.packages.urllib3.util.retry import Retry
        from requests.adapters import HTTPAdapter

        self._base_url = base_url
        self._source = source
        self._logger = logger
        self._max_capacity = max_capacity
        self._capacity = 0
        self._cache = self._get_empty_cache()  # (Code, Number) -> [(DepartureDay, Aircraft)]

        retry = Retry(
            total=10,
            read=10,
            connect=10,
            method_whitelist=frozenset(['HEAD', 'TRACE', 'GET', 'PUT', 'OPTIONS', 'DELETE', 'POST']),
            backoff_factor=1,
            status_forcelist=(500, 502, 504, 429),
        )
        self._session = Session()
        # Mount the retrying adapter for both schemes so an https base_url gets the same policy.
        for prefix in ('http://', 'https://'):
            self._session.mount(prefix, HTTPAdapter(max_retries=retry))

    @staticmethod
    def _get_empty_cache():
        """Return a fresh buffer mapping (code, number) to a list of (departure_day, aircraft)."""
        return defaultdict(list)

    def add(self, code, number, departure_day, aircraft):
        # type: (str, str, str, str) -> None
        """Buffer one record; sends everything once ``max_capacity`` records are buffered."""
        self._cache[(code, number)].append((departure_day, aircraft))
        self._capacity += 1
        # ">=" keeps the buffer from ever exceeding max_capacity records.
        if self._capacity >= self._max_capacity:
            self.flush()

    def flush(self):
        # type: () -> None
        """Send every buffered group (one request per flight) and reset the buffer."""
        self._logger.info('Start sending %d flights', self._capacity)
        for (code, number), flight_info in self._cache.items():
            self._send(code, number, flight_info)

        self._capacity = 0
        self._cache = self._get_empty_cache()

    def _send(self, code, number, flight_info):
        """
        POST the (departure_day, aircraft) pairs of one flight.

        :raises requests.HTTPError: if the final response has an error status.
        """
        self._logger.info('Sending for code=%s number=%s %d items', code, number, len(flight_info))
        data = {
            'source': self._source,
            'data': [
                {
                    'departure_day': departure_day,
                    'aircraft': aircraft,
                }
                for departure_day, aircraft in flight_info
            ]
        }

        self._logger.debug('Sending %r', data)
        response = self._session.post(
            u'{base_url}/flight/{code}/{number}/aircraft/'.format(
                base_url=self._base_url,
                code=code,
                number=number,
            ),
            json=data,
        )
        response.raise_for_status()
        self._logger.info('Status %r, content: %s', response.status_code, response.content)


class AviaFlightStorageImportRaspFlights(AviaBaseTask):
    """
    Import rasp flights and send their aircraft models to Flight Extras.

    Downloads the zlib-compressed reference and flight dumps, cross-checks the
    company/station references against the YT copy (diagnostic warnings only)
    and, when enabled, queues a (departure day, aircraft model) pair per flight
    into FlightExtrasClient.
    """

    DEPARTURE_DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%S'

    class Requirements(sdk2.Task.Requirements):
        environments = (
            PipEnvironment('requests'),
            PipEnvironment('yandex-yt', version='0.10.8'),
            PipEnvironment('yandex-yt-yson-bindings-skynet', version='0.3.32-0'),
        )

    class Parameters(sdk2.Task.Parameters):

        # Fixed YT settings for the reference cross-check (plain class attributes,
        # not sdk2.parameters declarations).
        reference_yt_user = 'robot-avia-api-pilot'
        reference_yt_proxy = 'hahn'
        reference_yt_path = '//home/rasp/reference'

        flights_url = sdk2.parameters.String(
            'Flights url',
            default='http://rasp-bucket.s3.mds.yandex.net/export/avia/flights.dat',
            required=True,
        )

        references_url = sdk2.parameters.String(
            'References url',
            default='http://rasp-bucket.s3.mds.yandex.net/export/avia/references.dat',
            required=True,
        )

        send_to_flight_extras = sdk2.parameters.Bool(
            'Send aircraft models to Flight-Extras',
            default=True,
        )

        flight_extras_url = sdk2.parameters.String(
            'Flight Extras url',
            default='http://flight-extras.testing.avia.yandex.net/v1',
        )

        partner_code = sdk2.parameters.String('Partner code', default='_travel.rasp_', required=True)

    @staticmethod
    def _make_temp_path():
        # type: () -> str
        """Create an empty temporary file that persists until removed explicitly; return its path."""
        fd, path = tempfile.mkstemp()
        os.close(fd)
        return path

    def _load_references(self):
        # type: () -> dict
        """
        Download and parse the rasp reference dump.

        :return: parsed JSON; this task reads its 'companies', 'stations' and
            'transport_models' sections.
        """
        logging.info('Loading references ...')
        temp = self._make_temp_path()
        try:
            self._download_file(self.Parameters.references_url, temp)
            with open(temp) as fp:
                references = json.load(fp)
        finally:
            os.remove(temp)

        logging.info('Loaded %d companies', len(references['companies']))
        logging.info('Loaded %d stations', len(references['stations']))
        logging.info('Loaded %d transport models', len(references['transport_models']))

        return references

    def _load_yt_references(self):
        # type: () -> dict
        """
        Read company and station codes from the YT reference tables.

        :return: ``{'companies': {id: {...}}, 'stations': {id: {...}}}`` where
            empty code values are normalized to None.
        """
        import yt.wrapper as yt
        yt_client = yt.YtClient(
            proxy=self.Parameters.reference_yt_proxy,
            token=sdk2.Vault.data(self.Parameters.reference_yt_user, 'YT_TOKEN'),
        )

        references = {
            'stations': {},
            'companies': {},
        }

        yt_format = yt.YsonFormat()

        logging.info('Precache companies ...')
        for record in yt_client.read_table(self.Parameters.reference_yt_path + '/company', format=yt_format):
            references['companies'][record['id']] = {
                'iata': record['iata'] or None,
                'sirena': record['sirena'] or None,
            }
        logging.info('Precached %d companies', len(references['companies']))

        logging.info('Precache stations ...')
        for record in yt_client.read_table(self.Parameters.reference_yt_path + '/station', format=yt_format):
            # The first character of the YT station id is dropped to match the rasp key.
            references['stations'][record['id'][1:]] = {
                'iata': record['iata'] or None,
                'sirena': record['sirena'] or None,
                'time_zone': record['time_zone'] or None,
            }
        logging.info('Precached %d stations', len(references['stations']))

        return references

    @staticmethod
    def _download_file(source, destination):
        # type: (sdk2.parameters.String, str) -> None
        """
        Stream a zlib-compressed resource from ``source`` and write it decompressed to ``destination``.

        :raises IOError: if the server answers with a non-200 HTTP status.
        """
        logging.info('Download data from url "%s" and decompressing ...', source)
        decompressor = zlib.decompressobj()
        chunk_size = 50 * 1024 * 1024

        with contextlib.closing(urllib.urlopen(source)) as response:
            status = response.getcode()
            # urllib.urlopen does not raise on HTTP errors; without this check an
            # error page would surface later as an opaque zlib.error.
            if status is not None and status != 200:
                raise IOError('Failed to download {}: HTTP status {}'.format(source, status))

            logging.info('Destination file: %s', destination)
            with open(destination, 'wb') as fp:
                while True:
                    chunk = response.read(chunk_size)
                    if not chunk:
                        break
                    fp.write(decompressor.decompress(chunk))
                # Write out anything the decompressor still holds after the last chunk.
                fp.write(decompressor.flush())

    @staticmethod
    def _get_code(d):
        """Return the IATA code of a company/station, falling back to its Sirena code."""
        return d['iata'] or d['sirena']

    @staticmethod
    def _compare_references(references, yt_references):
        # type: (dict, dict) -> None
        """
        Log warnings for companies/stations that are missing from, or differ in, the YT references.

        Purely diagnostic: nothing is modified or raised.
        """
        # NOTE(review): rasp JSON keys are strings; if YT company ids are ints, every
        # company will be reported as unknown — confirm the YT id type.
        for section, label in (('companies', 'company'), ('stations', 'station')):
            yt_section = yt_references[section]
            for item_id, item in references[section].items():
                if item_id not in yt_section:
                    # %s, not %d: keys loaded from JSON are strings.
                    logging.warning('Unknown %s "%s": %s', label, item_id, item)
                elif item != yt_section[item_id]:
                    logging.warning('Invalid %s (%s) vs (%s)', label, item, yt_section[item_id])

    def _process_flight_line(self, line, references, flight_extras_client):
        # type: (str, dict, FlightExtrasClient) -> None
        """
        Parse one line of the flights dump and queue its aircraft model.

        :param line: a single JSON-encoded flight record.
        :param references: parsed rasp references (see ``_load_references``).
        :param flight_extras_client: client to queue the aircraft into, or None
            when sending to Flight Extras is disabled.
        """
        line = line.strip()
        if not line:
            return

        try:
            f = json.loads(line)
        except ValueError as e:
            logging.exception('Error parsing (%s): %s', e, line)
            return

        company_id = str(f['company_id'])
        try:
            company = references['companies'][company_id]
        except KeyError:
            logging.error('Company_id %s not found for flight %s', company_id, f)
            return

        # For serialization reasons, model_id stored in reference file as string
        model_id = f['model_id']
        aircraft = references['transport_models'].get(str(model_id))
        if aircraft is None:
            logging.warning('No model info for model with id %s', model_id)
            return

        if flight_extras_client is None:
            return

        airline_code = self._get_code(company)
        if not airline_code:
            # Without a code the request URL would contain "None".
            logging.warning('No airline code for company_id %s, skipping flight %s', company_id, f)
            return

        flight_extras_client.add(
            airline_code,
            f['number'].split(' ')[-1],  # keep only the trailing numeric part of the flight number
            f['departure_day'],
            aircraft['code_en'],
        )

    def on_execute(self):
        yt_references = self._load_yt_references()
        references = self._load_references()

        self._compare_references(references, yt_references)

        flight_extras_client = None
        temp = self._make_temp_path()
        try:
            self._download_file(self.Parameters.flights_url, temp)

            logging.info('Start import rasp flight from: %s', self.Parameters.flights_url)

            if self.Parameters.send_to_flight_extras:
                flight_extras_client = FlightExtrasClient(
                    self.Parameters.flight_extras_url,
                    self.Parameters.partner_code,
                    logging,
                )

            with open(temp) as fp:
                for line in fp:
                    self._process_flight_line(line, references, flight_extras_client)
        finally:
            os.remove(temp)

        if flight_extras_client is not None:
            flight_extras_client.flush()
