# -*- coding: utf-8 -*-
import logging

import re
import time as os_time
import httplib
import socket
import urllib2
from collections import OrderedDict
from copy import copy, deepcopy
from threading import Thread
from datetime import datetime
from itertools import takewhile

from django.core.cache import cache
from django.conf import settings
from django.db.models import Q

from common.models.geo import Point, StationCode, StationExpressAlias, Station
from common.models.schedule import RThread
from common.models.transport import TransportType, TrainPseudoStationMap
from common.models_utils import fetch_related
from common.utils import tracer
from common.utils.blablacar_utils import BLABLACAR_ALL_DAYS_DATE
from common.utils.caching import global_cache_set, global_cache_add
from common.utils.date import MSK_TZ
from common.utils.http import urlopen
from common.views.tariffs import DisplayInfo
from route_search.models import ZNodeRoute2
from route_search.shortcuts import find


# Transport type ids used by this module.
tt_plane_id, tt_train_id = TransportType.PLANE_ID, TransportType.TRAIN_ID

# Date format for query dates (used e.g. by Result.get_key).
JSON_DATE_FMT = '%Y-%m-%d'

# Statistics line layout (tab separated): type (seat, tariff, ufs),
# result (success, timeout, error), taken from cache (true, false),
# elapsed time, description.
STAT_INFO = u'%(type)s\t%(result)s\tcached=%(cached)s\t%(time)s\t%(description)s'

# TODO: ROUNDTRIP_PRICE_FORMAT if needed

# Module logger, and the logger that receives the TSKV price dump lines.
log = logging.getLogger('rasp.seat_price')
dump_prices_log = logging.getLogger('rasp.order.dump_seat_price')


class Query(object):
    """Bag of parameters describing one tariff/seat lookup.

    Every keyword passed to the constructor becomes an instance attribute and
    overrides the class-level default of the same name.
    """

    # NOTE(review): class-level list shared by every instance that does not
    # pass its own ``segments``; never mutate it in place.
    segments = []

    is_order = False
    socket_timeout = settings.TARIFF_SOCKET_TIMEOUT
    initial_point_from = None
    initial_point_to = None
    # Defaults to one hour. A trailing comma used to follow this expression,
    # which silently made the value the tuple ``(3600,)`` instead of an int.
    cache_timeout = 1 * 60 * 60

    def __init__(self, **kwargs):
        for k, v in kwargs.iteritems():
            setattr(self, k, v)

    def copy(self):
        """Return a shallow copy (attribute values are shared, not duplicated)."""
        return copy(self)


class ReplyInfo(object):
    """Supplier reply bookkeeping grouped by request key.

    ``info`` maps a request key to a ``{supplier: reply}`` dict; every value
    is wrapped into a ``KeyReplyInfo``.
    """

    def __init__(self, info):
        self.by_key = {}
        for key, value in info.items():
            self.by_key[key] = KeyReplyInfo(value)

    def get_reply_timestamp(self):
        """Return the greatest seats/tariffs/roundtrip time over all replies.

        Returns None when there are no replies. Missing (None) times rely on
        Python 2 ordering, where None compares below any number.
        """
        stamps = [None]
        for key_info in self.by_key.values():
            for reply in key_info.by_supplier.values():
                stamps.extend((reply.seats, reply.tariffs, reply.roundtrip))
        return max(stamps)

    def get_incomplete_keys(self):
        """Return the request keys whose replies are not all complete."""
        return [key for key, key_info in self.by_key.items()
                if not key_info.complete]

    def __repr__(self):
        return "<ReplyInfo by_key=%r>" % (self.by_key,)


class KeyReplyInfo(object):
    """Reply bookkeeping for a single request key, grouped by supplier."""

    def __init__(self, info):
        # Maps supplier -> reply object exposing ``complete`` and
        # ``seats_success`` (presumably SupplierReplyTime).
        self.by_supplier = info

    @property
    def complete(self):
        """True unless some supplier reply is incomplete (True when empty)."""
        return not any(not reply.complete for reply in self.by_supplier.values())

    @property
    def seats_success(self):
        """True if at least one supplier reply reports successful seats."""
        for reply in self.by_supplier.values():
            if reply.seats_success:
                return True
        return False


def segment_data(segment):
    """Return ``segment.data``, or a departure+arrival "HHMMHHMM" string.

    The fallback is used when the segment has no ``data`` attribute.
    """
    missing = object()
    data = getattr(segment, 'data', missing)
    if data is not missing:
        return data

    departure_hm = segment.departure.strftime("%H%M")
    arrival_hm = segment.arrival.strftime("%H%M")
    return "%s%s" % (departure_hm, arrival_hm)


class DirectionAjaxInfo(object):
    """JSON-serializable summary of one search direction (see ``__json__``).

    Holds the encoded from/to key, the latest supplier reply timestamp, the
    request keys that are still incomplete, a route_key -> segment data map
    and an optional (early, late) departure range.
    """

    def __init__(self, point_from, point_to, suppliers_reply_info, segments, early_border, late_border):
        self.direction_key = self.encode_direction_key(point_from, point_to)
        self.timestamp = suppliers_reply_info.get_reply_timestamp()
        self.request_keys = suppliers_reply_info.get_incomplete_keys()
        self.no_data_no_seats = []

        # Only segments that carry tariff info contribute a route entry.
        routes = {}
        for seg in segments:
            if seg.info:
                route_key = seg.info.route_key
                routes[route_key] = segment_data(seg)
        self.routes = routes

        # A falsy early border means "no range"; it is stored unchanged.
        self.range = (early_border, late_border) if early_border else early_border

    def __json__(self):
        payload = dict(
            d=self.direction_key,
            k=self.request_keys,
            ts=self.timestamp,
            ndns=self.no_data_no_seats,
            r=self.routes,
        )

        if self.range:
            payload['b'] = [border.strftime("%Y-%m-%d %H:%M") for border in self.range]

        return payload

    @staticmethod
    def encode_direction_key(point_from, point_to):
        """Join the two points' ``point_key`` values with '-'."""
        return '-'.join([point.point_key for point in (point_from, point_to)])

    @classmethod
    def decode_direction_key(cls, key):
        """Inverse of ``encode_direction_key``: return (point_from, point_to)."""
        parts = key.split('-')
        return Point.get_by_key(parts[0]), Point.get_by_key(parts[1])


class AllSuppliersTariffInfo(object):
    """Tariff info of one segment on one departure date, merged over suppliers.

    ``by_supplier`` maps a supplier code to the tariff info object it reported.
    """
    info_template = {}

    def __init__(self, departure_date, segment=None, segment_key=None):
        # The key is derived from the segment only when not passed explicitly.
        if segment and segment_key is None:
            segment_key = self.get_segment_key(segment)

        self.segment_key = segment_key
        self.departure_date = departure_date
        self.by_supplier = {}

    @property
    def has_info(self):
        """True once at least one supplier has contributed info."""
        return len(self.by_supplier) > 0

    @classmethod
    def get_segment_key(cls, segment):
        """Build the key used to match a segment with supplier data.

        Non-bus segments use their number. Buses fall back to the thread's
        hidden number, then (for 'swdfactory' only) to the departure time.
        Spaces are replaced with dashes.
        """
        if segment.t_type.code != 'bus':
            return segment.number.replace(' ', '-')

        key = segment.number
        if not key:
            key = segment.thread and segment.thread.hidden_number
        if not key:
            key = ''

        # Matching by departure time is only safe for selected partners.
        if not key and hasattr(segment, 'supplier_code') \
                and segment.supplier_code == 'swdfactory':
            key = segment.departure.strftime('time%H%M')

        return key.replace(' ', '-')

    @property
    def route_key(self):
        """Segment key plus the departure date as MMDD."""
        month_day = self.departure_date.strftime("%m%d")
        return "%s-%s" % (self.segment_key, month_day)

    def add_ti_info(self, supplier, ti_info):
        """Register ``ti_info`` for ``supplier``, merging into existing info."""
        assert ti_info is not None

        existing = self.by_supplier.get(supplier)
        if not existing:
            self.by_supplier[supplier] = ti_info
        else:
            existing.update(ti_info)


class SupplierReplyTime(object):
    """Time and outcome of a supplier's last reply for each data kind.

    For seats, tariffs and roundtrip there is a timestamp attribute and a
    matching ``*_reason`` attribute (e.g. 'success', 'timeout').
    """
    seats = seats_reason = None
    tariffs = tariffs_reason = None
    roundtrip = roundtrip_reason = None
    display_fields = ['seats', 'seats_reason', 'tariffs', 'tariffs_reason',
                      'roundtrip', 'roundtrip_reason']

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def __repr__(self):
        header = u"%s:\n" % type(self).__name__
        body = u"\n".join([u"%s: %r" % (name, getattr(self, name))
                           for name in self.display_fields])
        return (header + body).encode('utf8')

    @property
    def complete(self):
        """False if any of the three requests ended with a timeout."""
        reasons = (self.seats_reason, self.tariffs_reason, self.roundtrip_reason)
        return 'timeout' not in reasons

    @property
    def seats_success(self):
        """Seats timestamp when the seats request succeeded, else None."""
        if self.seats_reason == 'success':
            return self.seats
        return None


class TariffInfo(object):
    """Seats/tariffs (and related data) one supplier reported for one route.

    Every entry of ``data_fields`` has a companion ``<field>_time`` entry in
    ``time_fields`` recording when the value was obtained (see ``set_time``).
    """
    supplier = None
    route_number = None
    seats = None
    tariffs = None
    roundtrip_tariffs = None
    roundtrip_seats = None
    seats_time = None
    tariffs_time = None
    roundtrip_tariffs_time = None
    roundtrip_seats_time = None
    train_info = None
    train_info_time = None
    et_possible = False
    data_fields = ['seats', 'tariffs', 'roundtrip_tariffs', 'roundtrip_seats', 'train_info']
    time_fields = ['seats_time', 'tariffs_time',
                   'roundtrip_tariffs_time', 'roundtrip_seats_time', 'train_info_time']
    # Fields shown by __repr__. NOTE: this is the same list object as data_fields.
    out_fields = data_fields

    def __init__(self, route_number, seats=None, tariffs=None,
                 roundtrip_tariffs=None, roundtrip_seats=None):
        self.route_number = route_number
        self.seats = seats
        self.tariffs = tariffs
        self.roundtrip_tariffs = roundtrip_tariffs
        self.roundtrip_seats = roundtrip_seats
        self.order_info = {}

    def update(self, ti_info):
        """Fill data/time fields that are None here with values from ``ti_info``.

        Values already present on ``self`` win. Only ``data_fields`` and
        ``time_fields`` are merged.
        """
        for field_name in self.data_fields + self.time_fields:
            self_attr = getattr(self, field_name)
            ti_info_attr = getattr(ti_info, field_name)
            if self_attr is None and ti_info_attr is not None:
                setattr(self, field_name, ti_info_attr)

    def set_time(self, lmt):
        """Set ``<field>_time = lmt`` for every data field that has a value."""
        for field_name in self.data_fields:
            if getattr(self, field_name) is not None:
                setattr(self, field_name + '_time', lmt)

    def __repr__(self):
        template = u"""<%(class)s:
supplier: %(supplier)s
route_number: %(route_number)s
et_possible: %(et_possible)s
""" + u"\n".join('%s: %%(%s)s' % (f, f) for f in self.out_fields + self.time_fields) + \
"\n>"
        params = {
            'class': self.__class__.__name__,
            'supplier': self.supplier,
            'route_number': self.route_number,
            'et_possible': self.et_possible,
        }
        # Decode every field repr leniently. Previously tariffs/roundtrip_*
        # were re-assigned as raw byte reprs afterwards, so a non-ASCII nested
        # repr made the unicode() call below raise UnicodeDecodeError.
        for field_name in self.out_fields + self.time_fields:
            params[field_name] = repr(getattr(self, field_name)).decode('utf8', 'ignore')

        for k in params:
            params[k] = unicode(params[k])

        return (template % params).encode('utf8')


class Result(object):
    """Outcome of one supplier request, plus caching and logging helpers.

    ``data`` maps a route number / segment key to a ``tariffinfo_class``
    instance. ``reason`` is one of 'success', 'timeout' or 'error'.
    Subclasses set ``type``, ``supplier``, ``key_of_delay`` and ``data_types``.
    """
    type = None
    supplier = None
    tariffinfo_class = TariffInfo
    data_types = []
    key_of_delay = None
    query = None

    def __init__(self, query, data, reason, error_text=None, cached=False):
        self.query = query
        self.data = data
        self.reason = reason
        # A timeout has no meaningful fetch time.
        if reason != 'timeout':
            self.lmt = os_time.time()
        else:
            self.lmt = None

        # Stamp the seats/prices with the supplier and the fetch time.
        if reason == 'success' and data:
            for tariff_info in data.values():
                tariff_info.supplier = self.supplier
                tariff_info.set_time(self.lmt)

        self.cached = cached
        self.error_text = error_text

        if reason == 'success':
            self.on_success()

    def __repr__(self):
        template = u"""<Result:
class: %(class_name)s
data: %(data)s
reason: %(reason)s
cached: %(cached)s
error_text: %(error_text)s>"""
        params = {}
        params['class_name'] = self.__class__.__name__
        params.update(self.__dict__)
        params['data'] = repr(self.data).decode("utf8")
        return (template % params).encode('utf8')

    def get_key(self):
        """Return ``key_of_delay`` followed by the query date (YYYY-MM-DD)."""
        return self.key_of_delay + self.query.date.strftime(JSON_DATE_FMT)

    def log_me(self, start, message):
        """Log a STAT_INFO line; ``start`` is an ``os_time.time()`` value."""
        log.info(STAT_INFO, {
            'type': self.type,
            'result': self.reason,
            'cached': self.cached,
            'time': os_time.time() - start,
            'description': message
        })

    def update_segments(self, segments, reply_time):
        """Attach this result's tariff info to matching segments.

        On success, each segment whose ``info.segment_key`` is in ``data``
        gets that entry under this supplier. In all cases the ``data_types``
        fields of ``reply_time`` (and their ``*_reason``) are set from
        ``lmt``/``reason``.
        """
        if self.reason == 'success':
            for s in segments:
                ti_info = self.data.get(s.info.segment_key, None)

                if ti_info:
                    s.info.add_ti_info(self.supplier, ti_info)

        for field in self.data_types:
            setattr(reply_time, field, self.lmt)
            setattr(reply_time, field + '_reason', self.reason)

    def on_success(self):
        """Hook called from __init__ when ``reason == 'success'``."""
        pass

    def on_error(self, segment):
        pass

    @classmethod
    def get_info_class(cls):
        return cls.tariffinfo_class

    def timeout_or_none_in_cache(self, key):
        """True if ``key`` is not cached or holds a 'timeout' result."""
        result = cache.get(key)
        return result is None or result.reason == 'timeout'

    def set_cache_if_empty(self, key, timeout):
        global_cache_add(key, self, timeout)

    def update_cache(self, key, timeout):
        global_cache_set(key, self, timeout)

    def cache_me(self, key, timeout):
        """Store this result under ``key`` unless real data is already cached.

        Non-timeout results overwrite an empty/timeout slot. Timeout results
        are only added when the slot is empty. ``timeout`` is presumably in
        seconds (it is logged as ``timeout / 3600`` hours).
        """
        # The cached copy must carry cached=True.
        self.cached = True
        if self.timeout_or_none_in_cache(key):
            if self.reason != 'timeout':
                self.update_cache(key, timeout)
            else:
                self.set_cache_if_empty(key, timeout)

            log.info(u'Положили в кэш %s %s на %.3f часа %s', self.type,
                     self.reason, timeout / 3600.0, key)

            if self.reason == 'success' and self.data:
                self.dump_prices()

        else:
            log.warning(u"В кеше %s уже лежат ранее полученные данные",
                        key)
        # This live object is the first extraction, not a cached one.
        self.cached = False

    def dump_prices(self):
        """Log one TSKV line per entry of ``self.data`` to ``dump_prices_log``.

        Besides query/direction columns, each line gets ``class_<klass>_seats``
        and ``class_<klass>_tariff`` columns built from the entry's ``seats``
        and ``tariffs`` dicts (the tariff column takes the ``'price'`` item).
        """
        # Read the clock once, so the 'timestamp' and 'timezone' columns describe
        # the same instant and all lines of one dump share a timestamp.
        # NOTE(review): the naive local time is localized as MSK, which assumes
        # the server clock runs in MSK — confirm.
        now = datetime.today()
        timezone = MSK_TZ.localize(now).strftime('%z')
        timestamp = now.strftime("%Y-%m-%d %H:%M:%S")
        point_from = self.query.point_from
        point_to = self.query.point_to
        date_forward = self.query.date.strftime('%Y-%m-%d')
        date_backward = None
        for route_number, tariffinfo in self.data.items():
            params = OrderedDict((
                ('tskv_format', 'rasp-tariffs-log'),
                ('timestamp', timestamp),
                ('timezone', timezone),
                ('partner', self.supplier),
                ('type', self.key_of_delay.split(u'_')[0]),
                ('date_forward', date_forward),
                ('date_backward', date_backward),
                ('object_from_type', point_from.__class__.__name__),
                ('object_from_id', point_from.id),
                ('object_from_title', point_from.title),
                ('object_to_type', point_to.__class__.__name__),
                ('object_to_id', point_to.id),
                ('object_to_title', point_to.title),
                ('route_number', route_number)
            ))

            # Merge seats and prices per class: {klass: {'seats': .., 'tariff': ..}}.
            by_class = {}
            if tariffinfo.seats:
                for klass, seats in tariffinfo.seats.items():
                    by_class.setdefault(klass, {})['seats'] = seats

            if tariffinfo.tariffs:
                for klass, tariff in tariffinfo.tariffs.items():
                    by_class.setdefault(klass, {})['tariff'] = tariff.get('price')

            for klass, info in by_class.items():
                if 'seats' in info:
                    params[u"class_" + klass + u"_seats"] = info['seats']
                if 'tariff' in info:
                    params[u"class_" + klass + u"_tariff"] = info['tariff']

            message = u"tskv\t" + u"\t".join(u"%s=%s" % (key, value) for key, value in params.items())

            dump_prices_log.info(message)

class Fetcher(object):
    """Base class for cache-backed, thread-based supplier fetchers.

    ``fetch``/``compute`` store a placeholder 'timeout' result under
    ``cache_key``, start ``thread_class(query, cache_key)`` and return the
    placeholder. ``fetch`` first returns any cached result instead.
    Subclasses must implement ``get_cache_key``.
    """
    ask_me = False
    thread_class = None

    result_class = None
    roundtrip_result_class = None

    supplier = None

    def __init__(self, query):
        self.query = query
        self.cache_key = self.get_cache_key(query)

    @classmethod
    def can_be_asked(cls):
        """Allowed when the global ALWAYS_ASK_ALL setting or ``ask_me`` is set."""
        return settings.ALWAYS_ASK_ALL or cls.ask_me

    @classmethod
    def is_suitable(cls, query):
        return True

    @classmethod
    def make_fake(cls, query):
        """Cache an empty 'success' result for ``query`` (SP_FAKE_TIMEOUT)."""
        key = cls.get_cache_key(query)
        fake = cls.result_class(query, {}, 'success')
        fake.cache_me(key, settings.SP_FAKE_TIMEOUT)

    @classmethod
    def get_cache_key(cls, query):
        raise NotImplementedError

    @classmethod
    def get_cache(cls, query):
        return cache.get(cls.get_cache_key(query))

    def put_timeout_result(self):
        """Store a 'timeout' placeholder result in ``self.result`` and the cache."""
        self.result = self.result_class(self.query, None, 'timeout')
        self.result.cache_me(self.cache_key, settings.TARIFF_SUPPLIERWAIT_TIMEOUT)

    def fetch(self):
        """Return the cached result, or start a background fetch."""
        cached = cache.get(self.cache_key)
        self.result = cached

        if cached is not None:
            log.debug(u"Get cached: %s", self.cache_key)
            return cached

        self.put_timeout_result()
        self.thread_class(self.query, self.cache_key).start()
        return self.result

    def compute(self):
        """Always start a background fetch, ignoring any cached result."""
        self.put_timeout_result()
        self.thread_class(self.query, self.cache_key).start()
        return self.result


class BaseFetcherThread(Thread):
    """Thread that fetches and parses one supplier reply.

    Subclasses implement ``get_data`` and ``parse_data``. Any exception is
    logged and turned into an 'error' result cached under ``cache_key``.
    """
    result_class = None
    roundtrip_result_class = None

    supplier = None

    def __init__(self, query, cache_key):
        self.query = query
        self.cache_key = cache_key
        self.ti_class = self.result_class.get_info_class()
        Thread.__init__(self)

    def get_data(self):
        raise NotImplementedError

    def parse_data(self):
        raise NotImplementedError

    def error(self, message, timeout=settings.SP_ERROR_TIMEOUT):
        """Cache an 'error' result carrying ``message`` for ``timeout``."""
        self.result = self.result_class(self.query, None, 'error', message)
        self.result.cache_me(self.cache_key, timeout)

    def run(self, *args, **kwargs):
        try:
            start = os_time.time()
            self.get_data()
            self.parse_data()
            log.info(u'%s fetch time %.3f', self.cache_key, os_time.time() - start)
        except (httplib.HTTPException, socket.gaierror, urllib2.URLError) as e:
            log.exception(u'Ошибка протокола при получении данных от %s: %s',
                          self.supplier, unicode(e))
            self.error(unicode(e))

        except socket.timeout as e:
            log.error(u"Не дождались ответа от %s", self.supplier)
            self.error(unicode(e))

        except Exception as e:
            # Last-resort boundary: the thread must always leave a result.
            log.exception(u'Неизвестная ошибка при получении или разборе ответа от %s: %s',
                          self.supplier, unicode(e))
            self.error(unicode(e))


def http_retrieve(request):
    """Open ``request`` via ``urlopen`` with TARIFF_SOCKET_TIMEOUT.

    SSL errors are retried up to 3 times (4 attempts in total). Every other
    error is logged and re-raised.
    """
    sslerror_count = 0

    while True:
        try:
            return urlopen(request, timeout=settings.TARIFF_SOCKET_TIMEOUT)

        except (httplib.HTTPException, socket.gaierror) as e:
            # Lazy args: a formatting failure (e.g. non-ASCII exception text)
            # is handled by logging instead of masking the original error.
            log.exception(u'Ошибка при получении данных: %s', e)
            raise

        except socket.sslerror:
            if sslerror_count < 3:
                sslerror_count += 1
                log.exception(u'Получили ошибку SSL пробуем еще раз')
                continue
            else:
                log.exception(u'Получили больше 3-х ошибок SSL')
                raise

        except socket.timeout:
            # NOTE(review): the message names UFS although this helper is generic.
            log.error(u"Не дождались ответа от UFS")
            raise

        except Exception:
            # Was a bare ``except:``; KeyboardInterrupt/SystemExit now pass
            # through without being logged as an unknown error.
            log.exception(u"Неизвестная ошибка")
            raise


class ExtraTrainSegment(object):
    """Train segment built solely from a supplier reply (``ti_info.train_info``).

    TrainResult.on_success creates these for trains the supplier returned that
    do not coincide with a segment of the regular search.
    """

    # Train numbers 8xx or 7xxx, optionally followed by non-digits.
    EXPRESS_RE = re.compile(r'^(8\d\d|7\d\d\d)\D*$')

    def __init__(self, t_type, number, supplier, ti_info, date_):
        """Populate the segment from ``ti_info.train_info``.

        :param t_type: transport type, stored as-is.
        :param number: train number; also used as the ticket number.
        :param supplier: supplier code under which ``ti_info`` is registered.
        :param ti_info: tariff info whose ``train_info`` dict provides titles,
            times and stations (keys read with ``[]`` are required).
        :param date_: departure date used to build the route key.
        """
        self.t_type = t_type

        self.number = self.ticket_number = number

        self.is_express = bool(self.EXPRESS_RE.match(self.number))

        self.thread = None

        train_info = ti_info.train_info

        # Supplier-provided titles; fill_titles may replace them later.
        self.first_station = train_info['start_title_from']
        self.last_station = train_info['end_title_to']

        self.first_station_code = train_info.get('first_station_code')
        self.last_station_code = train_info.get('last_station_code')

        self.msk_departure = train_info['msk_departure']
        self.msk_arrival = train_info['msk_arrival']
        self.departure = train_info['departure']
        self.arrival = train_info['arrival']

        self.duration = train_info['duration']

        self.station_from = train_info['station_from']
        self.station_to = train_info['station_to']

        self.rtstation_from = None
        self.rtstation_to = None

        self.company = None

        self.t_model = None

        self.gone = False

        self.display_info = DisplayInfo()

        self.display_t_code = 'train'

        self.info = AllSuppliersTariffInfo(date_, self)
        self.info.add_ti_info(supplier, ti_info)

        self.is_deluxe = train_info.get('is_deluxe')

    def L_title(self):
        # ``title`` is only set by fill_titles.
        return self.title

    def link(self, point_from, point_to):
        return None

    @classmethod
    def limit(cls, segments, borders):
        """Return segments departing within ``borders`` = (early, late).

        A segment is kept when ``early <= departure < late`` or its departure
        equals BLABLACAR_ALL_DAYS_DATE. Falsy ``borders`` yields ``{}``.
        """
        if borders:
            early_border, late_border = borders

            return dict((k, s) for k, s in segments.iteritems()
                        if early_border <= s.departure < late_border or s.departure == BLABLACAR_ALL_DAYS_DATE)

        return {}

    @classmethod
    @tracer.wrap
    def fill_titles(cls, segments):
        """Set first/last station titles and ``title`` on this class's segments.

        Stations are resolved by express code first, then by express alias.
        The settlement title is used unless the station has no settlement or
        is marked ``not_generalize``.
        """
        cls_segments = [s for s in segments.values() if isinstance(s, cls)]

        short_names = set()
        express_codes = set()

        for s in cls_segments:
            short_names.add(s.first_station)
            short_names.add(s.last_station)

            if s.first_station_code:
                express_codes.add(s.first_station_code)

            if s.last_station_code:
                express_codes.add(s.last_station_code)

        name_to_station = dict(
            (ea.alias, ea.station)
            for ea in StationExpressAlias.objects.filter(alias__in=short_names).select_related('station')
        )

        code_to_station = dict(
            (sc.code, sc.station)
            for sc in StationCode.objects.filter(
                system__code='express',
                code__in=express_codes,
            ).select_related('station')
        )

        for s in cls_segments:
            first = s.first_station_code and code_to_station.get(s.first_station_code) or name_to_station.get(s.first_station)
            last = s.last_station_code and code_to_station.get(s.last_station_code) or name_to_station.get(s.last_station)

            if first:
                if first.settlement and not first.not_generalize:
                    s.first_station = first.settlement.title
                else:
                    s.first_station = first.title

            if last:
                if last.settlement and not last.not_generalize:
                    s.last_station = last.settlement.title
                else:
                    s.last_station = last.title

            s.title = "%s - %s" % (s.first_station, s.last_station)

    @classmethod
    @tracer.wrap
    def correct_stations(cls, segments, point_from, point_to, uri=None):
        """Replace 'express_fake' pseudo stations with real ones.

        First tries ZNodeRoute2 records of known threads, then
        TrainPseudoStationMap. Segments that cannot be corrected or no longer
        match ``point_from``/``point_to`` are deleted from ``segments``
        (in place, by route key). ``uri`` is only used for logging.
        """
        cls_segments = [s for s in segments.values() if isinstance(s, cls)]

        need_correction = set()
        bad = set()

        for s in cls_segments:
            if s.station_from.majority.code == 'express_fake' or \
               s.station_to.majority.code == 'express_fake':
                need_correction.add(s)

        known_threads = RThread.objects.filter(number__in=[s.number for s in need_correction])

        known_numbers = set(t.number for t in known_threads)

        records = ZNodeRoute2.objects.filter(**{
            '%s_from' % point_from.type: point_from,
            '%s_to' % point_to.type: point_to,
            'thread__in': known_threads,
        }).select_related('thread')

        station_ids = set()

        for r in records:
            station_ids.add(r.station_from_id)
            station_ids.add(r.station_to_id)

        stations = Station.objects.in_bulk(station_ids)

        for r in records:
            r.station_from = stations.get(r.station_from_id)
            r.station_to = stations.get(r.station_to_id)

        number_to_stations = dict((r.thread.number, (r.station_from, r.station_to)) for r in records)

        for s in list(need_correction):
            try:
                s.station_from, s.station_to = number_to_stations[s.number]

                need_correction.remove(s)
            except KeyError:
                if s.number in known_numbers:
                    # Trains present in the DB that were not corrected via the
                    # node route must not appear in the results.
                    bad.add(s)

                    # Remove from the set so they are not processed further.
                    need_correction.remove(s)

        pseudo_stations = set()

        for s in need_correction:
            if s.station_from.majority.code == 'express_fake':
                pseudo_stations.add(s.station_from)

            if s.station_to.majority.code == 'express_fake':
                pseudo_stations.add(s.station_to)

        mappings = TrainPseudoStationMap.objects.filter(
            Q(number__in=[s.number for s in need_correction]) | Q(pseudo_station__in=pseudo_stations)
        )

        fetch_related(mappings, 'station')

        corrections = {}

        for mapping in mappings:
            corrections[mapping.number, mapping.pseudo_station_id] = mapping.station

        unknown_pseudo_mappings = set()

        for s in need_correction:
            if s.station_from.majority.code == 'express_fake':
                try:
                    s.station_from = corrections[s.number, s.station_from.id]
                except KeyError:
                    unknown_pseudo_mappings.add((s.number, s.station_from))
                    bad.add(s)

            if s.station_to.majority.code == 'express_fake':
                try:
                    s.station_to = corrections[s.number, s.station_to.id]
                except KeyError:
                    unknown_pseudo_mappings.add((s.number, s.station_to))
                    bad.add(s)

            if point_from not in [s.station_from, s.station_from.settlement]:
                bad.add(s)

            if point_to not in [s.station_to, s.station_to.settlement]:
                bad.add(s)

        if unknown_pseudo_mappings:
            # Dedicated logger; a distinct name avoids shadowing the module ``log``.
            subst_log = logging.getLogger('rasp.tariffs.express_subst')

            subst_log.info(u'Ссылка: %s', uri or u'неизвестно')

            for number, station in unknown_pseudo_mappings:
                subst_log.warning(u'Не найдена замена псевдостанции: %s %s (%s)',
                                  number, station.id, station.title)

        for s in bad:
            del segments[s.info.route_key]

    def copy(self):
        """Shallow copy with independent deep copies of ``info`` and ``display_info``."""
        copy_of_self = copy(self)

        copy_of_self.info = deepcopy(self.info)
        copy_of_self.display_info = deepcopy(self.display_info)

        return copy_of_self


class TrainResult(Result):
    """UFS train seats/tariffs result.

    On success, trains from the reply that do not coincide with a regular
    search segment (same number, departure and arrival) are collected into
    ``extra_segments``, keyed by route key.
    """
    type = "ufs_seat_tariffs"
    supplier = "ufs"
    key_of_delay = "train_"
    data_types = ['seats', 'tariffs']

    @tracer.wrap
    def on_success(self):
        self.extra_segments = {}

        # One candidate segment per reply entry that carries train info.
        candidates = {}
        for number, ti_info in self.data.items():
            if not ti_info.train_info:
                continue
            segment = ExtraTrainSegment(self.query.t_type, number, self.supplier, ti_info, self.query.date)
            candidates[segment.number] = segment

        if not candidates:
            return

        departures = [candidate.departure for candidate in candidates.values()]
        latest = max(departures)
        earliest = min(departures)

        # Regular search segments departing within [earliest, latest].
        scheduled = find(self.query.initial_point_from, self.query.initial_point_to,
                         earliest.date(), self.query.t_type)
        window = [base for base in takewhile(lambda seg: seg.departure <= latest, scheduled)
                  if base.departure >= earliest]

        # Drop candidates that duplicate a scheduled segment.
        for base in window:
            if not candidates:
                break

            match = candidates.get(base.number)

            if match and match.departure == base.departure and match.arrival == base.arrival:
                del candidates[base.number]

        for number, candidate in candidates.items():
            log.debug(u"%s %s %s", number, candidate.departure, candidate.arrival)

        self.extra_segments = dict((candidate.info.route_key, candidate)
                                   for candidate in candidates.values())
