# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import six

from cached_property import cached_property

from travel.rasp.bus.api.connectors.client import MetaClient
from travel.rasp.bus.db import session_scope
from travel.rasp.bus.db.models.admin_user import AdminUser
from travel.rasp.bus.db.models.matching import PointMatching, PointType
from travel.rasp.bus.db.models.supplier import Supplier

# Public API of this module: only the updater is exported; the separator
# classes and ``format_coordinate`` are not part of ``__all__``.
__all__ = [
    'EndpointsUpdater',
]


def format_coordinate(value, default=None):
    """Round a coordinate to 7 decimal places.

    Any falsy ``value`` (``None``, ``0``, ``0.0``, an empty string, ...) is
    treated as "no coordinate" and ``default`` is returned instead. Note that
    this makes a coordinate lying exactly on 0 indistinguishable from a
    missing one.
    """
    if not value:
        return default
    return round(value, 7)


class BaseSeparator(object):
    """Split endpoints and point matchings into new / matched / outdated groups.

    Both collections are indexed by a key computed by the subclass
    (:meth:`endpoint_key` for endpoint dicts, :meth:`point_matching_key` for
    point matching objects). Items with equal keys are considered to describe
    the same point.

    If several items of one collection share a key, only the last one is kept
    and a warning is logged. The dropped items appear in none of the result
    lists, so they are neither updated nor outdated by the caller.
    """

    def __init__(self, endpoints, point_matchings):
        self.endpoints = self._index(endpoints, self.endpoint_key, 'endpoint')
        self.point_matchings = self._index(point_matchings, self.point_matching_key, 'point matching')
        self.endpoints_keys = set(self.endpoints)
        self.point_matchings_keys = set(self.point_matchings)

    def _index(self, items, key_func, kind):
        """Return ``{key_func(item): item}``; on duplicate keys the last item wins and a warning is logged."""
        index = {}
        for item in items:
            key = key_func(item)
            if key in index:
                # Previously this was silent: the earlier item just vanished from every result list.
                logging.warning('%s: duplicate %s key %r, keeping the last one',
                                type(self).__name__, kind, key)
            index[key] = item
        return index

    @cached_property
    def new_endpoints(self):
        """Endpoints whose key has no matching point matching."""
        return [self.endpoints[k]
                for k in self.endpoints_keys - self.point_matchings_keys]

    @cached_property
    def intersection(self):
        """``(endpoint, point_matching)`` pairs that share a key."""
        return [(self.endpoints[k], self.point_matchings[k])
                for k in self.endpoints_keys & self.point_matchings_keys]

    @cached_property
    def outdated_point_matchings(self):
        """Point matchings whose key has no matching endpoint."""
        return [self.point_matchings[k]
                for k in self.point_matchings_keys - self.endpoints_keys]

    @staticmethod
    def endpoint_key(endpoint):
        """Return the hashable matching key of an endpoint dict (subclasses must override)."""
        raise NotImplementedError

    @staticmethod
    def point_matching_key(point_matching):
        """Return the hashable matching key of a point matching (subclasses must override)."""
        raise NotImplementedError


class SupplierIdSeparator(BaseSeparator):
    """Match endpoints and point matchings by the supplier's own point id.

    The endpoint's ``supplier_id`` is converted to text before comparison with
    ``point_matching.supplier_point_id``. An endpoint without ``supplier_id``
    therefore gets the key ``six.text_type(None)``.
    """

    @staticmethod
    def endpoint_key(endpoint):
        supplier_point_id = endpoint.get('supplier_id')
        return six.text_type(supplier_point_id)

    @staticmethod
    def point_matching_key(point_matching):
        key = point_matching.supplier_point_id
        return key


class HeuristicsSeparator(BaseSeparator):
    """Match endpoints and point matchings by title and rounded coordinates.

    A missing (falsy) coordinate is replaced with a placeholder object. The
    placeholder differs between the endpoint side and the point matching side,
    so an item missing either coordinate never matches anything on the other
    side.
    """

    # Distinct placeholders for absent coordinates: never equal to each other or to a number.
    DEFAULT_ENDPOINT_COORDINATE = object()
    DEFAULT_POINT_MATCHING_COORDINATE = object()

    @classmethod
    def endpoint_key(cls, endpoint):
        placeholder = cls.DEFAULT_ENDPOINT_COORDINATE
        title = endpoint.get('title')
        lat = format_coordinate(endpoint.get('latitude'), placeholder)
        lon = format_coordinate(endpoint.get('longitude'), placeholder)
        return title, lat, lon

    @classmethod
    def point_matching_key(cls, point_matching):
        placeholder = cls.DEFAULT_POINT_MATCHING_COORDINATE
        title = point_matching.title
        lat = format_coordinate(point_matching.latitude, placeholder)
        lon = format_coordinate(point_matching.longitude, placeholder)
        return title, lat, lon


class EndpointsUpdater(object):
    """Synchronise ``PointMatching`` rows with the endpoints reported by suppliers.

    For every non-hidden supplier, :meth:`run` fetches the supplier's endpoints
    and raw segments through ``MetaClient``. It then reconciles them with the
    supplier's existing point matchings:

    * endpoints matched neither by supplier point id nor by title + coordinates
      become new ``PointMatching`` rows;
    * matched rows are updated field by field from the endpoint;
    * rows that no endpoint matched are marked ``outdated``.

    Statistics about the changes are accumulated in :attr:`stat`.
    """

    # (endpoint key, PointMatching attribute, converter applied to the endpoint value)
    ENDPOINT_POINT_MATCHING_MAP = [
        ('supplier_id', 'supplier_point_id', six.text_type),
        ('type', 'type', lambda x: {'station': PointType.STATION, 'city': PointType.CITY}.get(x, PointType.INVALID)),
        ('title', 'title', lambda x: x),
        ('description', 'description', lambda x: x),
        ('latitude', 'latitude', format_coordinate),
        ('longitude', 'longitude', format_coordinate),
        ('country', 'country', lambda x: x),
        ('country_code', 'country_code', lambda x: x),
        ('city_id', 'city_id', lambda x: x),
        ('city_title', 'city_title', lambda x: x),
        ('region', 'region', lambda x: x),
        ('region_code', 'region_code', lambda x: x),
        ('district', 'district', lambda x: x),
        ('extra_info', 'extra_info', lambda x: x),
        ('timezone_info', 'timezone_info', lambda x: x),
    ]

    class Stat(object):
        """Counters and human-readable change logs collected during a run."""

        def __init__(self):
            self.processed = 0  # number of endpoints received from all processed suppliers
            self.created = []  # 's_id=...,p_id=...' per created point matching
            self.updated = []  # 'id=...|field:old->new|...' per changed point matching
            self.outdated = []  # 'id=...' per point matching marked outdated
            self.suppliers = []  # codes of the suppliers that were processed
            self.suppliers_failed = []  # codes of suppliers that returned no endpoints
            self.outdate_changes = dict()  # point matching id -> new ``outdated`` value
            self.in_segments_changes = dict()  # point matching id -> new ``in_segments`` value

        def report(self, verbose=False):
            """Build a multi-line summary; with ``verbose`` also list every change log entry."""
            data = ['Endpoints update statistics for: {}'.format(', '.join(self.suppliers)),
                    'Processed: {}'.format(self.processed)]
            if self.suppliers_failed:
                data.append('Failed for: {}'.format(', '.join(self.suppliers_failed)))
            for f in ('Created', 'Updated', 'Outdated'):
                data.append('{}: {}'.format(f, len(getattr(self, f.lower()))))
                if verbose:
                    sep = '\n' if f == 'Updated' else '|'
                    data += [sep.join(sorted(getattr(self, f.lower()))), '']
            return '\n'.join(data)

    def __init__(self):
        self.stat = self.Stat()
        # supplier id -> set of supplier point ids (as text) found in the raw segments
        self.raw_segments_points = {}
        # supplier id -> ``fake_segments`` value of the raw segments response
        self.fake_segments = {}
        # supplier id -> ``settlements_only`` value of the raw segments response
        self.settlements_only = {}

    def _update_point_matching(self, point_matching, updated_by, endpoint=None, created=False, outdated=False):
        """Apply an endpoint (or the outdated mark) to a point matching in place.

        :param point_matching: ``PointMatching`` instance to modify.
        :param updated_by: ``AdminUser`` stored in ``point_matching.updated_by``.
        :param endpoint: endpoint dict from ``MetaClient``; required unless ``outdated``.
            It must contain at least ``supplier_id``, ``title`` and ``type``.
        :param created: ``point_matching`` is a freshly constructed object (no id yet).
        :param outdated: mark ``point_matching`` as outdated instead of updating it.
        :return: ``True`` if the point matching was processed, ``False`` if it was
            rejected by validation and left untouched.
        """
        if not outdated and endpoint is None:
            logging.error('Not outdated and no endpoint')
            return False

        if endpoint is not None and any(f not in endpoint for f in ('supplier_id', 'title', 'type')):
            logging.error('Validation error: %r', endpoint)
            return False

        if point_matching.supplier_id is None:
            logging.error('supplier_id is None')
            return False

        point_matching.updated_by = updated_by

        if outdated:
            if point_matching.outdated is False:  # for catching only changes of outdated state
                self.stat.outdate_changes[point_matching.id] = True
                self.stat.updated.append('id={}|outdated:False->True'.format(point_matching.id))

            point_matching.outdated = True
            self.stat.outdated.append('id={}'.format(point_matching.id))
            return True

        update_log = []
        for endpoint_field, point_matching_field, convert in self.ENDPOINT_POINT_MATCHING_MAP:
            cur_value = getattr(point_matching, point_matching_field)
            new_value = convert(endpoint[endpoint_field])
            if cur_value != new_value:
                update_log.append('{}:{!r}->{!r}'.format(point_matching_field, cur_value, new_value))
                setattr(point_matching, point_matching_field, new_value)

        self._update_in_segments(point_matching, created)

        if point_matching.outdated:
            self.stat.outdate_changes[point_matching.id] = False
            point_matching.outdated = False
            update_log.append('outdated:True->False')

        if created:
            self.stat.created.append('s_id={},p_id={}'.format(
                point_matching.supplier_id, endpoint['supplier_id']))
        elif update_log:
            self.stat.updated.append('id={}|{}'.format(point_matching.id, '|'.join(update_log)))
        return True

    def _update_in_segments(self, point_matching, created):
        """Recompute ``point_matching.in_segments`` from the loaded raw segments.

        Does nothing if no raw segments were loaded for the supplier. The value
        is set to ``None`` in two cases: the supplier reported fake segments, or
        it reported settlement-only segments and the point is a station.
        """
        supplier_id = point_matching.supplier_id
        if supplier_id not in self.raw_segments_points:
            return

        if self.fake_segments[supplier_id] or \
                (self.settlements_only[supplier_id] and point_matching.type == PointType.STATION):
            in_segments = None
        else:
            in_segments = point_matching.supplier_point_id in self.raw_segments_points[supplier_id]

        # A new row has no id yet: recording it would add "id=None" entries to the
        # stats and collapse every new row onto the ``None`` key of the changes dict.
        if not created and point_matching.in_segments != in_segments:
            self.stat.in_segments_changes[point_matching.id] = in_segments
            self.stat.updated.append('id={}|in_segments:{}->{}'.format(
                point_matching.id,
                point_matching.in_segments,
                in_segments
            ))
        point_matching.in_segments = in_segments

    def _load_raw_segments(self, supplier):
        """Fetch the supplier's raw segments and cache what ``in_segments`` needs.

        If the response is empty only a warning is logged and nothing is cached
        for the supplier.
        """
        raw_segments_response = MetaClient.raw_segments(supplier.code, raise_on_exception=False)
        if not raw_segments_response:
            logging.warning('can not get segments for %s', supplier.code)
            return

        fake_segments = raw_segments_response.get('fake_segments')
        settlements_only = raw_segments_response.get('settlements_only')
        raw_segments = raw_segments_response['segments']
        logging.info('got raw segments for %s: count %s, fake_segments %s, settlements_only %s',
                     supplier.code, len(raw_segments), fake_segments, settlements_only)

        # These ids are compared with ``supplier_point_id``, which is always stored as
        # text (see ENDPOINT_POINT_MATCHING_MAP), so normalise them the same way.
        points_in_segments = set()
        for from_sid, to_sid in raw_segments:
            points_in_segments.add(six.text_type(from_sid))
            points_in_segments.add(six.text_type(to_sid))

        logging.info('num of points in raw segments for %s: %s', supplier.code, len(points_in_segments))
        self.raw_segments_points[supplier.id] = points_in_segments
        self.fake_segments[supplier.id] = fake_segments
        self.settlements_only[supplier.id] = settlements_only

    def _sync_supplier(self, session, supplier, endpoints, user):
        """Create, update and outdate the supplier's point matchings according to ``endpoints``."""
        # Match by supplier point id first, then by title + coordinates among the leftovers.
        supplier_id_separator = SupplierIdSeparator(endpoints, supplier.point_matchings)
        heuristics_separator = HeuristicsSeparator(supplier_id_separator.new_endpoints,
                                                   supplier_id_separator.outdated_point_matchings)

        for endpoint in heuristics_separator.new_endpoints:
            point_matching = PointMatching(supplier_id=supplier.id)
            # Only persist rows that passed validation; otherwise an empty row would be added.
            if self._update_point_matching(point_matching, user, endpoint=endpoint, created=True):
                session.add(point_matching)

        for endpoint, point_matching in supplier_id_separator.intersection + heuristics_separator.intersection:
            self._update_point_matching(point_matching, user, endpoint=endpoint)

        for point_matching in heuristics_separator.outdated_point_matchings:
            self._update_point_matching(point_matching, user, outdated=True)

    def run(self, login, supplier_codes=None, dry=False):
        """Synchronise point matchings of all non-hidden suppliers (or a subset of them).

        Each supplier's changes are committed (or rolled back when ``dry``)
        separately.

        :param login: primary key of the ``AdminUser`` stored as ``updated_by``.
        :param supplier_codes: optional collection of supplier codes to restrict the
            run to; falsy means all non-hidden suppliers.
        :param dry: roll back each supplier's changes instead of committing them.
        :return: ``False`` if any processed supplier returned no endpoints, else ``True``.
        :raises Exception: if there is no ``AdminUser`` for ``login``.
        """
        if dry:
            logging.info('Dry run')

        all_endpoints_loaded = True
        with session_scope() as session:

            user = session.query(AdminUser).get(login)
            if user is None:
                logging.error('No user %s', login)
                raise Exception('Bad params')

            suppliers = list(session.query(Supplier).filter(Supplier.hidden == False))  # noqa: E712
            for supplier in suppliers:
                if supplier_codes and supplier.code not in supplier_codes:
                    continue

                self.stat.suppliers.append(supplier.code)
                logging.info('Endpoints updating for %s', supplier.code)

                endpoints = list(MetaClient.endpoints(supplier.code, raise_on_exception=False))
                self._load_raw_segments(supplier)

                if not endpoints:
                    logging.error('empty endpoints response for %s', supplier.code)
                    self.stat.suppliers_failed.append(supplier.code)
                    all_endpoints_loaded = False
                    continue
                self.stat.processed += len(endpoints)

                self._sync_supplier(session, supplier, endpoints, user)

                if dry:
                    session.rollback()
                else:
                    session.commit()

        logging.info(self.report(verbose=True))
        return all_endpoints_loaded

    def report(self, verbose=False):
        """Return the textual statistics report of this updater (see ``Stat.report``)."""
        return self.stat.report(verbose=verbose)
