# -*- coding: utf-8 -*-

from __future__ import unicode_literals

import logging
import math

from django.db import transaction
from django.utils.functional import cached_property
from django.utils.translation import gettext_noop as N_

from travel.rasp.library.python.common23.date import environment
from common.models.geo import Station
from common.models.schedule import RThread, Route
from common.models.tariffs import ThreadTariff
from cysix.base import CriticalImportError
from cysix.filter_parameters import FilterParameters
from cysix.filters import cysix_xml_validation
from cysix.models import PackageFilter
from travel.rasp.admin.importinfo.models import OriginalThreadData
from travel.rasp.admin.importinfo.models.mappings import StationMapping
from travel.rasp.admin.importinfo.models.two_stage_import import (
    TSISetting, TwoStageImportThread, TwoStageImportStation, TwoStageImportThreadStation
)
from travel.rasp.admin.scripts.utils.file_wrapper.uploaders import upload_tmp_schedule
from travel.rasp.admin.www.utils.mysql import fast_delete_routes
from common.models.geo import Region

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)


class CysixTwoStageImporter(object):
    """Two-stage importer for a cysix package.

    Stage one (``reimport_package_into_middle_base``) validates the package
    XML, loads its paths and stations into the intermediate ``TwoStageImport*``
    tables and resolves station mappings.  Stage two (``reimport_package``)
    imports routes through the factory's route importer.  All collaborators
    (data provider, file provider, station finder, route importer, group
    factories) are obtained from ``factory``.
    """

    # Code systems that ``Station.get_by_code`` is asked to resolve directly
    # when a missing exact mapping has to be created.
    KNOWN_CODE_SYSTEMS = ('iata', 'express', 'esr', 'sirena', 'icao', 'yandex')

    def __init__(self, factory):
        self.package = factory.package
        self.factory = factory

    @cached_property
    def data_provider(self):
        """Data provider from the factory, created once per importer instance."""
        return self.factory.get_data_provider()

    def reimport_package_into_middle_base(self):
        """Rebuild the intermediate (middle) base for the package.

        Validates the XML, logs the filter settings, deletes the package's old
        middle-base data, imports stations and threads, then updates station
        mappings.  A ``CriticalImportError`` aborts the run and is logged, not
        re-raised.
        """
        try:
            self.check_xml()

            self.log_all_filter_parameters()

            self.clean_middle_base()

            log.info(N_('Импортируем пакет в промежуточную базу'))
            supplier_stations_by_tsi_key = self.import_middle_schedule()

            log.info(N_('Обновляем соответствия'))
            self.update_mappings(supplier_stations_by_tsi_key)

            log.info(N_('Импорт завершен успешно'))

        except CriticalImportError as e:
            log.error(N_('В процессе импорта произошла ошибка. %s'), unicode(e))

    def import_middle_schedule(self):
        """Create middle-base stations and threads for every good supplier path.

        Returns:
            dict mapping each created ``TwoStageImportStation.key`` to the
            supplier station it was built from.
        """
        good_paths = self.get_good_paths()

        supplier_stations_by_keys = self.get_supplier_stations_by_keys(good_paths)

        tsi_stations_by_supplier_station_key = \
            self.get_tsi_stations_by_supplier_station_key(supplier_stations_by_keys)

        self.create_supplier_threads(good_paths, tsi_stations_by_supplier_station_key)

        return {
            tsi_station.key: supplier_stations_by_keys[supplier_key]
            for supplier_key, tsi_station in tsi_stations_by_supplier_station_key.items()
        }

    def get_good_paths(self):
        """Return the supplier path groups whose path has more than one station.

        Groups are de-duplicated by ``get_key()``; for equal keys the group
        yielded last by the data provider wins.
        """
        good_paths = {}

        for path_group in self.data_provider.get_supplier_path_iter():
            if len(path_group.get_path()) > 1:
                good_paths[path_group.get_key()] = path_group

        return list(good_paths.values())

    def get_supplier_stations_by_keys(self, good_path_groups):
        """Collect unique supplier stations from all path groups, keyed by ``key``.

        The first station seen for a key is kept.  Legacy stations of later
        stations with the same key are merged into its ``legacy_stations`` list
        (in place), skipping legacy keys that are already present.
        """
        supplier_stations_by_keys = {}

        for path_group in good_path_groups:
            for supplier_station in path_group.get_path():
                key = supplier_station.key

                if key not in supplier_stations_by_keys:
                    supplier_stations_by_keys[key] = supplier_station

                elif supplier_station.legacy_stations:
                    known_legacy_stations = supplier_stations_by_keys[key].legacy_stations
                    known_keys = {s.key for s in known_legacy_stations}

                    for legacy_station in supplier_station.legacy_stations:
                        if legacy_station.key not in known_keys:
                            known_legacy_stations.append(legacy_station)
                            # Track keys as we append so duplicates within one
                            # incoming list are not added twice.
                            known_keys.add(legacy_station.key)

        return supplier_stations_by_keys

    def get_tsi_stations_by_supplier_station_key(self, supplier_stations_by_keys):
        """Create and save a ``TwoStageImportStation`` for each supplier station.

        Returns:
            dict mapping the supplier station key to the saved middle-base station.
        """
        tsi_stations_by_supplier_station_key = {}

        for key, supplier_station in supplier_stations_by_keys.items():
            station = TwoStageImportStation(
                title=supplier_station.title,
                real_title=supplier_station.real_title,
                geocode_title=supplier_station.geocode_title,
                code=supplier_station.code,
                package=self.data_provider.two_stage_package,
                additional_info=supplier_station.additional_info,
                latitude=supplier_station.latitude,
                longitude=supplier_station.longitude
            )

            log.info(
                N_('Создали станцию title=%s, code=%s, supplier_station_key=%s'),
                supplier_station.title, supplier_station.code, key
            )

            station.save()

            tsi_stations_by_supplier_station_key[key] = station

        return tsi_stations_by_supplier_station_key

    def create_supplier_threads(self, good_path_groups, tsi_stations_by_supplier_station_key):
        """Create a ``TwoStageImportThread`` (with its thread stations) per path group.

        The thread title is "<first> - <last>" using station titles, falling
        back to codes.  Afterwards the group's ``set_default_flags`` filter is
        applied to the new thread.
        """
        for path_group in good_path_groups:
            path = path_group.get_path()
            tsi_stations = [tsi_stations_by_supplier_station_key[s.key] for s in path]

            tsi_thread = TwoStageImportThread()
            tsi_thread.package = self.data_provider.two_stage_package
            tsi_thread.title = " - ".join([tsi_stations[0].title or tsi_stations[0].code,
                                           tsi_stations[-1].title or tsi_stations[-1].code])
            tsi_thread.path_key = path_group.get_key()
            tsi_thread.save()

            # Thread stations are created in path order; later code relies on
            # their id order to reconstruct the path.
            for tsi_station in tsi_stations:
                TwoStageImportThreadStation.objects.create(station=tsi_station, thread=tsi_thread)

            group_code = path[0].group_el.get('code')
            filter_ = self.factory.get_group_factory(group_code).get_filter('set_default_flags')
            filter_.apply(self.package, tsi_thread, tsi_stations)

            log.info(N_('Добавили нитку %s'), tsi_thread.title)

    def update_mappings(self, supplier_stations_by_tsi_key):
        """Ensure every middle-base station of the package has one exact mapping.

        For each station: create an exact mapping if possible, delete surplus
        exact mappings, and store the remaining one on the station.
        """
        finder = self.get_finder()

        for tsi_station in TwoStageImportStation.objects.filter(package=self.package):
            supplier_station = supplier_stations_by_tsi_key.get(tsi_station.key)

            if supplier_station is not None:
                self.create_exact_mapping_if_not_exist(supplier_station, finder)
            else:
                # Without the source supplier station there is nothing to
                # build a new mapping from; still clean up existing ones.
                log.warning(N_('Не нашли станцию поставщика для ключа %s'), tsi_station.key)

            self.delete_exes_exact_mapping(tsi_station, finder)

            self.set_mapping(tsi_station, finder)

    def create_exact_mapping_if_not_exist(self, cysix_station, finder):
        """Create a ``StationMapping`` for ``cysix_station`` if it has no exact one.

        If the station's code system is one of ``KNOWN_CODE_SYSTEMS``, the
        station is looked up by that code.  Otherwise, when the group's
        ``allow_station_mapping_by_code`` filter is enabled, a mapping is
        created only if all code-only matches point to a single station.
        """
        exact_mappings = finder.find_exact_mappings(cysix_station)

        if exact_mappings:
            return

        if cysix_station.station_code_system in self.KNOWN_CODE_SYSTEMS:
            try:
                station = Station.get_by_code(cysix_station.station_code_system, cysix_station.station_code)
                station_mapping = StationMapping.objects.create(
                    title=cysix_station.title,
                    code=cysix_station.code,
                    station=station,
                    supplier=self.package.supplier
                )
                log.info(N_('Создали новую привязку <%s> по коду %s в системе кодирования %s'),
                         station_mapping, cysix_station.station_code, cysix_station.station_code_system)
            except Station.DoesNotExist:
                log.error(N_('Не нашли станции в системе %s по коду %s'),
                          cysix_station.station_code_system, cysix_station.station_code)

        elif self.allow_station_mapping_by_code(cysix_station):
            mappings = finder.find_exact_mappings_by_code_only(cysix_station)

            # Only create a mapping when the code is unambiguous.
            if len(set(mapping.station for mapping in mappings)) == 1:
                station = mappings[0].station

                station_mapping = StationMapping.objects.create(
                    title=cysix_station.title,
                    code=cysix_station.code,
                    station=station,
                    supplier=self.package.supplier
                )
                log.info(N_('Создали новую привязку <%s>, используя только код'), station_mapping)

        else:
            log.info(N_('Не нашли привязки для %s'), cysix_station)

    def allow_station_mapping_by_code(self, cysix_station):
        """Return a truthy value if the station's group enables mapping by code only."""
        factory = self.factory.get_group_factory(cysix_station.group_el.get('code'))
        filter_ = factory.get_package_filter_obj('allow_station_mapping_by_code')

        return filter_ and filter_.use

    def delete_exes_exact_mapping(self, tsi_station, finder):
        """Delete every exact mapping of ``tsi_station`` except the first one."""
        exact_mappings = finder.find_exact_mappings(tsi_station)

        if len(exact_mappings) > 1:
            for m in exact_mappings[1:]:
                m.delete()

    def set_mapping(self, tsi_station, finder):
        """Store the first exact mapping (if any) on ``tsi_station`` and save it."""
        exact_mappings = finder.find_exact_mappings(tsi_station)

        if exact_mappings:
            tsi_station.station_mapping = exact_mappings[0]
            tsi_station.save()

    def reimport_package(self):
        """Import the package routes and update the package's import bookkeeping.

        A ``CriticalImportError`` aborts the run and is logged, not re-raised.
        """
        log.info(N_('============== Импортируем пакет %s'), self.package.title)
        try:
            self.check_xml()

            self.log_all_filter_parameters()

            self.import_package_routes()

            self.post_import_routes()

            self.package.last_import_date = environment.today()
            self.package.last_import_datetime = environment.now()
            self.package.update_last_mask_date()
            self.package.save()

            self.package.generate_directions_slices()

            log.info(N_('Импорт завершен успешно'))

        except CriticalImportError as e:
            log.error(N_('В процессе импорта произошла ошибка. %s'), unicode(e))

        log.info(N_('============== Завершен импорт пакета %s'), self.package.title)

    def import_package_routes(self):
        """Run the factory's route importer."""
        route_importer = self.factory.get_route_importer()
        route_importer.do_import()

    def post_import_routes(self):
        """Hook executed after routes are imported."""
        self.change_supplier_if_dm()

    def get_finder(self):
        return self.factory.get_station_finder()

    def get_file_provider(self):
        return self.factory.get_file_provider()

    def get_package_file_provider(self):
        raise NotImplementedError()

    def get_download_file_provider(self):
        raise NotImplementedError()

    def check_xml(self):
        """Validate the package's cysix XML file, then call ``upload_tmp_schedule``."""
        filepath = self.factory.get_file_provider().get_cysix_file()

        log.info(N_("Проверяем файл '%s'"), filepath)
        cysix_xml_validation.validate_cysix_xml(filepath)

        upload_tmp_schedule(self.package)

    def log_all_filter_parameters(self):
        """Log the package's TSISetting values and every package filter with its parameters."""
        log.info('+++++++')

        for field in TSISetting._meta.fields:
            if field.name in ('id', 'package'):
                continue
            log.info('%s: %s', field.verbose_name, getattr(self.package.tsisetting, field.name))

        log.info('+++++++')

        package_filters = PackageFilter.objects.filter(package=self.package).order_by('filter__order')

        for package_filter in package_filters:
            f = package_filter.filter

            log.info(N_("Имя фильтра: '%s'"), f.title)
            log.info(N_("Код фильтра: '%s'"), f.code)
            log.info(N_("Использовать фильтр: '%s'"), package_filter.use)
            log.info(N_('Параметры фильтра:'))

            if package_filter.parameters:
                params = package_filter.parameters.get_parameters()

                for p in params:
                    # Block markers only group parameters; they carry no value.
                    if p['type'] not in ('begin_block', 'end_block'):
                        title = p['title']
                        value = FilterParameters._value_to_type(p['value'], p['type'])
                        log.info('\t%s: %s', title, value)

        log.info('+++++++')

    def clean_middle_base(self):
        """Delete the package's middle-base threads and stations."""
        log.info(N_('Удаляем старые промежуточные данные'))
        TwoStageImportThread.objects.filter(package=self.package).delete()
        TwoStageImportStation.objects.filter(package=self.package).delete()

    def has_unbinded_stations(self):
        return bool(self.get_unbinded_stations())

    def get_unbinded_stations(self):
        """Return ``(tsi_station, mapping_count)`` pairs for stations without exactly one mapping.

        When ``package.create_stations`` is set, missing stations are created
        instead (see ``create_unbinded_stations``) and an empty list is returned.
        """
        bad_stations = []

        if self.package.create_stations:
            self.create_unbinded_stations()
        else:
            finder = self.get_finder()
            for tsi_station in TwoStageImportStation.objects.filter(package=self.package):
                mappings = finder.find_exact_mappings(tsi_station)

                if not mappings:
                    bad_stations.append((tsi_station, 0))
                elif len(mappings) > 1:
                    bad_stations.append((tsi_station, len(mappings)))

        return bad_stations

    def create_unbinded_stations(self, region_id=None, settlement_id=None):
        """Let the finder find or create a station for each middle-base station.

        ``region_id`` falls back to the package region, then to
        ``Region.MOSCOW_REGION_ID``; ``settlement_id`` falls back to the
        package settlement.
        """
        finder = self.get_finder()
        region_id = region_id or self.package.region_id or Region.MOSCOW_REGION_ID
        settlement_id = settlement_id or self.package.settlement_id

        for tsi_station in TwoStageImportStation.objects.filter(package=self.package):
            finder.find_or_create_by_tsi_station(tsi_station, region_id, settlement_id)

    def remove_package_routes(self):
        """Delete tariffs, original thread data and routes of the package."""
        log.info(N_('Удаляем тарифы пакета'))
        affected_thread_uids = (RThread.objects.filter(route__two_stage_package=self.package)
                                               .values_list('uid', flat=True))
        ThreadTariff.objects.filter(thread_uid__in=affected_thread_uids).delete()

        OriginalThreadData.fast_delete_by_package_id(self.package.id, log)

        log.info(N_('Удаляем маршруты пакета'))
        fast_delete_routes(Route.objects.filter(two_stage_package=self.package))

    def search_duplicates(self):
        """Log partial (middle-base) and full (imported route) duplicates of the package."""
        log.info(N_('Ищем дубликаты маршрутов пакета'))
        self.search_path_partial_duplicates()

        self.search_full_duplicates()

    def search_path_partial_duplicates(self):
        """Log groups of middle-base threads whose paths nearly coincide.

        A thread is a partial duplicate of another when both have the same
        number of stops, the same resolved first and last stations, and at
        least one but at most ``ceil(len(path) / 2)`` stop positions resolve to
        different real stations.
        """
        finder = self.get_finder()
        # Each group is a list of (thread, tsi_path) pairs. Paths are kept per
        # group so a thread appearing in several groups cannot clobber them.
        duplicates = []

        threads = list(TwoStageImportThread.objects.filter(package=self.package))

        for t in threads:
            t.path_length = t.threadstations.all().count()

            # Thread stations are created in path order, so id order is path
            # order; use it explicitly for both ends.
            t.tsi_first = t.threadstations.select_related('station').order_by('id')[0].station
            t.tsi_last = t.threadstations.select_related('station').order_by('-id')[0].station

            t.first_station = self._get_station(t.tsi_first, finder)
            t.last_station = self._get_station(t.tsi_last, finder)

        checked = set()

        for index, thread1 in enumerate(threads[:-1]):
            tsi_path1 = self._get_tsi_path(thread1)
            path1 = self._get_real_path_from_tsi_path(tsi_path1, finder)

            group = []
            for thread2 in threads[index + 1:]:
                if thread2 in checked:
                    continue

                if thread1.path_length != thread2.path_length:
                    continue

                if path1[0] != thread2.first_station:
                    continue

                if path1[-1] != thread2.last_station:
                    continue

                tsi_path2 = self._get_tsi_path(thread2)
                path2 = self._get_real_path_from_tsi_path(tsi_path2, finder)

                diff_pairs_count = 0

                for tsi_station1, station1, tsi_station2, station2 in zip(tsi_path1, path1, tsi_path2, path2):
                    # Remember the resolved real station for the log output.
                    tsi_station1.station = station1
                    tsi_station2.station = station2

                    has_diff = station1 != station2
                    tsi_station1.has_diff = tsi_station2.has_diff = has_diff
                    if has_diff:
                        diff_pairs_count += 1

                if not diff_pairs_count or diff_pairs_count > math.ceil(len(path1) / 2.0):
                    continue

                group.append((thread2, tsi_path2))
                checked.add(thread2)

            if group:
                self._unify_diff_flags(tsi_path1, [tsi_path for _, tsi_path in group])
                group.append((thread1, tsi_path1))
                duplicates.append(group)

        self._log_partial_duplicates(duplicates)

    def _get_tsi_path(self, tsi_thread):
        """Return the middle-base stations of ``tsi_thread`` in path (id) order."""
        return [tsi_rts.station
                for tsi_rts in tsi_thread.threadstations.select_related('station').order_by('id')]

    def _unify_diff_flags(self, tsi_path1, duplicate_tsi_paths):
        """Give each stop position of a duplicate group one shared ``has_diff`` flag.

        Every path in ``duplicate_tsi_paths`` was compared against
        ``tsi_path1``, so its flags mark where it differs from it.  The flags
        on ``tsi_path1`` are overwritten by every comparison, including ones
        that did not yield a duplicate, so they are recomputed, not trusted.
        """
        for i, tsi_station1 in enumerate(tsi_path1):
            has_diff = any(tsi_path[i].has_diff for tsi_path in duplicate_tsi_paths)
            tsi_station1.has_diff = has_diff
            for tsi_path in duplicate_tsi_paths:
                tsi_path[i].has_diff = has_diff

    def _log_partial_duplicates(self, duplicates):
        """Log each partial-duplicate group, expanding only stops that differ."""
        if not duplicates:
            return

        log.info(N_('Частичные дубликаты'))
        log.info('------------')

        for group in duplicates:
            log.info(N_('Нитки дубликаты:'))
            for thread, _ in group:
                log.info("\t'%s' - '%s'", thread.id, thread.title)

            tsi_paths = [tsi_path for _, tsi_path in group]
            for i, tsi_stations in enumerate(zip(*tsi_paths)):
                if any(tss.has_diff for tss in tsi_stations):
                    log.info(N_('Остановка номер %s:'), i)

                    for tss in tsi_stations:
                        log.info('\t%s', self._generate_string_for_station(tss))
                else:
                    log.info(N_('Остановка номер %s: %s'), i, self._generate_string_for_station(tsi_stations[0]))

            log.info('------------')

    def _generate_string_for_station(self, tss):
        """Describe a middle-base station as ``'<id>' - '<title>' ('<code>' - '<title>')``.

        The first pair comes from the real station assigned to ``tss.station``
        (``None`` when it could not be resolved), the second from ``tss`` itself.
        """
        real_station = tss.station
        return "'{}' - '{}' ('{}' - '{}')".format(
            real_station and real_station.id, real_station and real_station.title, tss.code, tss.title)

    def _get_real_path_from_tsi_path(self, tsi_path, finder):
        """Resolve every middle-base station of ``tsi_path`` to a real station (or None)."""
        return [self._get_station(tsi_station, finder) for tsi_station in tsi_path]

    def _get_station(self, tsi_station, finder):
        """Return the real station for ``tsi_station``, or None if missing or ambiguous."""
        try:
            return finder.find_by_supplier_station(tsi_station)
        except (Station.DoesNotExist, Station.MultipleObjectsReturned):
            return None

    def search_full_duplicates(self):
        """Log groups of imported threads that are full duplicates.

        Threads are full duplicates when they have the same title, start time,
        final arrival and station sequence, and their day masks intersect.
        """
        duplicates = []

        threads = list(RThread.objects.filter(route__two_stage_package=self.package))

        checked = set()

        for index, thread1 in enumerate(threads[:-1]):
            if thread1 in checked:
                continue

            path1 = list(thread1.path.select_related('station'))
            mask1 = thread1.mask()

            thread1_duplicates = []
            for thread2 in threads[index + 1:]:
                if thread1.title != thread2.title:
                    continue

                if thread1.start_time != thread2.start_time:
                    continue

                if not (mask1 & thread2.mask()):
                    continue

                path2 = list(thread2.path.select_related('station'))

                # Compare lengths first so the arrival check never indexes a
                # path of a different length.
                if len(path1) != len(path2):
                    continue

                if path1[-1].arrival != path2[-1].arrival:
                    continue

                # Skip if any stop is at a different station.
                if any(rts1.station_id != rts2.station_id for rts1, rts2 in zip(path1, path2)):
                    continue

                thread1_duplicates.append(thread2)
                checked.add(thread2)

            if thread1_duplicates:
                thread1_duplicates.append(thread1)

                duplicates.append(thread1_duplicates)

        if duplicates:
            log.info(N_('Полные дубликаты c пересечением по дням хождения'))

            for thread_duplicates in duplicates:
                log.info(", ".join("%s: %s" % (t.uid, t.title) for t in thread_duplicates))

    def clean_legacy_mappings(self):
        # No-op here; presumably a hook for subclasses.
        pass

    def clean_legacy_mappings_from_supplier_station(self, supplier_station):
        # No-op here; presumably a hook for subclasses.
        pass

    @transaction.atomic
    def change_supplier_if_dm(self):
        """Force the MTA supplier onto routes and threads of the 'dm' supplier.

        TODO: Remove this hack as soon as possible.
        RASPFRONT-761 - force the MTA supplier onto threads from Valdin.

        Does nothing for other suppliers.  Note that the UPDATEs are not
        limited to this package: every www_route / www_rthread row with the
        'dm' supplier id is reassigned.
        """
        supplier = self.package.supplier
        if supplier.code != 'dm':
            return

        from django.db import connection

        mta_id = 44  # id of the MTA supplier, hard-coded for RASPFRONT-761

        log.info('RASPFRONT-761 — Прописывать принудительно ниткам от Валдина поставщика МТА')

        cursor = connection.cursor()
        try:
            cursor.execute('UPDATE www_route SET supplier_id = %s WHERE supplier_id = %s',
                           [mta_id, supplier.id])
            cursor.execute('UPDATE www_rthread SET supplier_id = %s WHERE supplier_id = %s',
                           [mta_id, supplier.id])
        finally:
            cursor.close()
