# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import tempfile
from collections import namedtuple
from contextlib import contextmanager
from datetime import date, datetime

import openpyxl
import paramiko
import six
from dateutil.relativedelta import relativedelta
from yt.wrapper import YtClient

from travel.avia.library.python.references.partner import PartnerCache
from travel.cpa.collectors.lib.collector import Collector
from travel.cpa.lib.common import with_retries
from travel.cpa.lib.lib_datetime import parse_datetime_iso
from travel.cpa.lib.lib_datetime import timestamp
from travel.cpa.lib.lib_logging import get_logger
from travel.cpa.lib.order_snapshot import KiwiAviaOrderSnapshot, OrderStatus

logger = get_logger(__name__)

# Partner identifier used to look up partner ids. SOURCE is not referenced
# in this module's visible code -- presumably read by the collector framework.
PARTNER = 'kiwi'
SOURCE = PARTNER

# Pairs the key an xlsx cell is stored under in the order dict ("destination")
# with the callable that converts the raw cell value ("process").
CellProcessor = namedtuple('CellProcessor', 'destination process')


class KiwiCollector(Collector):
    """Collect Kiwi.com avia order snapshots from daily xlsx reports.

    Reports named after ``DATA_FILE_TEMPLATE`` are listed and downloaded over
    SFTP (a fresh SSH connection per operation, each operation wrapped in
    retries), parsed from the first worksheet and converted into
    ``KiwiAviaOrderSnapshot`` objects. ``--test-only-local-path`` points at a
    single local xlsx file and bypasses SFTP entirely.
    """

    PARTNER_NAME = 'kiwi'
    BASE_URL = 'sftp2.skypicker.com'
    PORT = 22
    SSH_TIMEOUT = 120

    # Daily report file name; the same date is substituted twice
    # (presumably the report period's start and end -- confirm with partner).
    DATA_FILE_TEMPLATE = 'yandexflights_daily_{date}_{date}.xlsx'

    # Report status -> OrderStatus. Statuses not listed here map to None and
    # such rows are skipped by get_snapshots_from_workbook.
    STATUS_MAPPING = {
        'closed': OrderStatus.CONFIRMED,
        'confirmed': OrderStatus.CONFIRMED,
        'refunded': OrderStatus.CANCELLED,
    }

    def __init__(self, options):
        """Read settings from parsed CLI ``options`` (see ``configure``) and resolve partner ids."""
        super(KiwiCollector, self).__init__()

        self.date_from = parse_datetime_iso(options.date_from).date()
        self.date_to = parse_datetime_iso(options.date_to).date()

        self.test_only_local_path = options.test_only_local_path
        self.ftp_host = options.base_url
        self.login = options.username
        self.ssh_key = options.ssh_key
        self.port = options.port
        self.timeout = options.timeout

        yt_client = YtClient(options.yt_proxy, options.yt_token)

        self.partner = PartnerCache(yt_client)
        self.partner_id, self.billing_order_id = self.partner.partner_id_bundle(PARTNER)

        # Both SFTP operations are retried; retries are counted under one metric key.
        self.get_filelist = with_retries(
            func=self._get_filelist_once,
            counter=self.metrics,
            key='collector.events.http_retry'
        )
        self.get_day_report_file = with_retries(
            func=self._get_day_report_file_once,
            counter=self.metrics,
            key='collector.events.http_retry'
        )

        # Report column header -> (key in the order dict, converter for the raw
        # cell value). A converter returning None leaves the key unset.
        self.cell_processors = {
            'bid': CellProcessor('order_id', lambda value: six.text_type(value) if value is not None else None),
            'booking_timestamp': CellProcessor('created_at', self._parse_date),
            'price': CellProcessor('order_amount', self._to_float),
            'external_id': CellProcessor('label', self._as_is),
            'trip_type': CellProcessor('trip_type', self._trip_type),
            'status': CellProcessor('status', lambda value: self.STATUS_MAPPING.get(value)),
            'o_d': CellProcessor('_o_d', self._as_is),
        }

    @classmethod
    def configure(cls, parser):
        """Register this collector's command-line options on ``parser``."""
        parser.add_argument('--base-url', default=cls.BASE_URL)
        parser.add_argument('--port', type=int, default=cls.PORT)
        parser.add_argument('-u', '--username', required=True)
        parser.add_argument('--ssh-key', required=True)
        parser.add_argument('--timeout', type=int, default=cls.SSH_TIMEOUT)
        parser.add_argument('--date-from', default=(date.today() + relativedelta(months=-4)).isoformat())
        parser.add_argument('--date-to', default=date.today().isoformat())
        parser.add_argument('--test-only-local-path')

        parser.add_argument('--yt-proxy', default='hahn')
        parser.add_argument('--yt-token', default=None)

    def _get_snapshots(self):
        """Yield snapshots from every report in the configured date range.

        Presumably the hook the ``Collector`` base class calls -- confirm there.
        """
        for filepath in self._get_data_filepaths():
            for snapshot in self.get_day_snapshots(filepath):
                yield snapshot

    def _get_data_filepaths(self):
        """Yield local paths of the report files to parse.

        Each remote daily report for a date in ``[date_from, date_to]`` is
        downloaded into a temporary file. The file is deleted as soon as the
        consumer advances this generator, so the path is only valid until then.
        """
        if self.test_only_local_path:
            yield self.test_only_local_path
            return

        # Match exact daily report names instead of a lexicographic range, so
        # stray files that merely sort between two report names are ignored.
        wanted = set(self._expected_filenames())
        for filename in sorted(self.get_filelist()):
            if filename not in wanted:
                continue
            logger.info('Start parse %s', filename)
            with tempfile.NamedTemporaryFile(suffix='.xlsx') as temp:
                self.get_day_report_file(filename, temp.name)
                yield temp.name

    def _expected_filenames(self):
        """Yield the daily report file name for every date in [date_from, date_to]."""
        for ordinal in range(self.date_from.toordinal(), self.date_to.toordinal() + 1):
            day = date.fromordinal(ordinal)
            yield self.DATA_FILE_TEMPLATE.format(date=day.strftime('%Y-%m-%d'))

    def _get_filelist_once(self):
        """Return the names in the SFTP login directory (single attempt)."""
        with self.sftp() as sftp:
            return sftp.listdir()

    def _get_day_report_file_once(self, filename, filepath):
        """Download remote ``filename`` to local ``filepath`` (single attempt)."""
        with self.sftp() as sftp:
            sftp.get(filename, filepath)

    def get_day_snapshots(self, filepath):
        """Yield snapshots parsed from the xlsx report stored at ``filepath``."""
        wb = openpyxl.load_workbook(filename=filepath, read_only=True)
        try:
            for snapshot in self.get_snapshots_from_workbook(wb):
                yield snapshot
        finally:
            # A read-only workbook keeps its file handle open until it is
            # closed explicitly; getattr guards older openpyxl without close().
            close = getattr(wb, 'close', None)
            if close is not None:
                close()

    def get_snapshots_from_workbook(self, wb):
        """Yield one snapshot per order row of the first worksheet of ``wb``.

        The first non-empty row is the header row; each later non-empty row is
        converted cell by cell through ``self.cell_processors``. Rows lacking
        created_at or status (including statuses absent from STATUS_MAPPING)
        are skipped; rows lacking an order id are skipped with a warning.
        """
        headers = []
        sheet = wb[wb.sheetnames[0]]
        for row in sheet.rows:
            if all(cell.value is None for cell in row):
                continue

            if not headers:
                headers = [
                    cell.value.strip() if isinstance(cell.value, six.string_types) else cell.value
                    for cell in row
                ]
                continue

            order = {}
            # zip() ignores cells beyond the last header instead of raising IndexError.
            for header, cell in zip(headers, row):
                if cell.value is None:
                    continue
                processor = self.cell_processors.get(header)
                if processor is None:
                    continue
                result = processor.process(cell.value)
                if result is not None:
                    order[processor.destination] = result

            if order.get('created_at') is None or order.get('status') is None:
                continue
            if order.get('order_id') is None:
                logger.warning('Skipping Kiwi report row without bid: %r', order)
                continue

            route = order.get('_o_d')
            if isinstance(route, six.string_types):
                parts = route.split('|')
                # Only an exact "ORIGIN|DESTINATION" pair is used; malformed
                # routes leave both keys unset instead of failing the run.
                if len(parts) == 2:
                    order['origin'], order['destination'] = parts

            order['partner_id'] = self.partner_id
            order['billing_order_id'] = self.billing_order_id
            snapshot = KiwiAviaOrderSnapshot.from_dict(order, ignore_unknown=True)
            snapshot.update_partner_order_id(order['order_id'])
            yield snapshot

    @staticmethod
    def _as_is(value):
        """Return the cell value unchanged."""
        return value

    @staticmethod
    def _parse_date(value):
        """Convert a booking timestamp cell into a timestamp.

        Accepts text such as ``2019-07-16 15:42:26.811175+00`` (fractional
        seconds optional) or a ``datetime`` openpyxl already decoded. A ``+HH``
        offset suffix is dropped, not applied.
        NOTE(review): non-zero or negative offsets are not handled -- confirm
        the report is always ``+00``.

        Returns None for empty values; raises ValueError on unknown formats.
        """
        if not value:
            return None

        if isinstance(value, datetime):
            return timestamp(value)

        text = six.text_type(value).split('+')[0].strip()
        parsed = None
        for fmt in ('%Y-%m-%d %H:%M:%S.%f', '%Y-%m-%d %H:%M:%S'):
            try:
                parsed = datetime.strptime(text, fmt)
            except ValueError:
                continue
            break
        if parsed is None:
            raise ValueError('Unexpected booking timestamp format: {!r}'.format(value))

        return timestamp(parsed)

    @staticmethod
    def _trip_type(value):
        """Return ``value`` if it is a known trip type, else None."""
        return value if value in ('oneway', 'roundtrip') else None

    @staticmethod
    def _to_float(value):
        """Convert a price cell to float; unconvertible values become 0.0 (best effort)."""
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return 0.

    @contextmanager
    def sftp(self):
        """Open an SSH connection with the configured RSA key and yield an SFTP client.

        Errors while connecting or during SFTP operations in the ``with`` body
        are logged and re-raised.
        """
        with paramiko.SSHClient() as ssh_client:
            # NOTE(review): AutoAddPolicy accepts any unknown host key (no MITM
            # protection); consider pinning the partner's host key.
            ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            try:
                pkey = paramiko.RSAKey.from_private_key(six.StringIO(self.ssh_key))
                ssh_client.connect(
                    hostname=self.ftp_host,
                    port=self.port,
                    username=self.login,
                    pkey=pkey,
                    timeout=self.timeout,
                )
                with ssh_client.open_sftp() as sftp:
                    yield sftp
            except Exception:
                logger.exception('SFTP error')
                raise
