# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
import re
from collections import defaultdict
from copy import copy
from itertools import groupby, combinations_with_replacement, combinations, product

from django.db.models.fields.related import ForeignKey
from django.conf import settings

from common.settings.utils import define_setting
from common.models.geo import (
    Station, Settlement, StationType, CityMajority, Country, StationMajority, Region, ExternalDirectionMarker
)
from common.models.transport import TransportType

from travel.rasp.suggests_tasks.suggests.generate.countries import COUNTRY_MAJORITY, OUR_COUNTRIES, ALL_IMPORTANT_COUNTRIES
from travel.rasp.suggests_tasks.suggests.generate import shared_objects
from travel.rasp.suggests_tasks.suggests.generate.utils import generate_parallel
from travel.rasp.suggests_tasks.suggests.objects_utils import get_obj_type
from travel.rasp.suggests_tasks.suggests.text_utils import SEPARATORS, prepare_title_text, TITLE_LANGS, TITLE_NATIONAL_VERSIONS
from travel.rasp.suggests_tasks.suggests.utils import print_run_time, split_values, enumer

# Option for bus suggest.
# When True, stations() and settlements() do not add the hidden=0 filter,
# so hidden objects are unloaded too.
define_setting('IGNORE_HIDDEN_FIELD', default=False)


logger = logging.getLogger('generate')

# Majority ids above these caps are clamped before they are turned into
# weight coefficients (StationWrapper.maj_coeff, SettlementWrapper.geo_weights).
MAX_SETTLEMENT_MAJ = 6
MAX_STATION_MAJ = 4

# Stat weight for objects that are neither settlements nor stations.
UNKNOWN_WEIGHT = 1

# station id -> external suburban direction; rebuilt by save_stations_directions().
stations_directions = {}


class GeoWeights(object):
    """Synthetic weights of one object for the different geo levels.

    ``base`` is the weight used without a geo id (stored under the ``None`` key
    of the synth weights). ``region``, ``country`` and ``settlement`` are the
    weights stored under the geo id of the object's region, country and
    settlement respectively.
    """

    _FIELDS = ('base', 'region', 'country', 'settlement')

    def __init__(self, base, region, country, settlement):
        self.base = base  # weight without geo_id
        self.region = region
        self.country = country
        self.settlement = settlement

    def __mul__(self, coef):
        """Return a new GeoWeights with every component multiplied by ``coef``.

        ``self`` is left untouched, so ``a * k`` never changes ``a``.
        ``a *= k`` still works: it rebinds ``a`` to the scaled copy.
        """
        return self.__class__(**{key: getattr(self, key) * coef for key in self._FIELDS})

    __rmul__ = __mul__

    def __eq__(self, other):
        if not isinstance(other, GeoWeights):
            return NotImplemented
        return all(getattr(self, key) == getattr(other, key) for key in self._FIELDS)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, != compares identity.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __repr__(self):
        return 'base={}, country={}, region={}, settlement={}'.format(
            self.base, self.country, self.region, self.settlement)


def title_variants(s):
    """Yield every word-suffix of ``s`` together with a "starts the title" flag.

    Runs of SEPARATORS characters split ``s`` into words. For each word index
    ``i``, the words from ``i`` onwards are joined with single spaces:
    ростов-на-дону -> ("ростов на дону", True), ("на дону", False), ("дону", False)
    The flag is True only for the variant that begins at the start of the
    title, not in the middle.
    """
    words = [
        u''.join(chars)
        for is_separator, chars in groupby(s, lambda ch: ch in SEPARATORS)
        if not is_separator
    ]

    for start in range(len(words)):
        yield u' '.join(words[start:]), start == 0


class ObjSuggestWrapper(object):
    """Abstract adapter that turns a model object into suggest forms."""

    def __init__(self, obj):
        # the wrapped Django model instance
        self.obj = obj

    def get_object_forms(self):
        """Yield suggest form dicts for ``self.obj``; subclasses must override."""
        raise NotImplementedError()


class SuggestUnload(object):
    """Collect suggest forms for a stream of Station/Settlement/Country objects.

    After ``generate_suggests_data()`` runs, ``suggests_data`` maps
    ``(obj_type, obj_id)`` to the list of form dicts produced for that object.
    """

    def __init__(self, objects):
        self.objects_iter = objects
        self.suggests_data = defaultdict(list)

    def get_obj_wrapper(self, obj):
        """Return the suggest wrapper for ``obj``, or None for unsupported classes."""
        if obj.__class__ is Station:
            return StationWrapper(obj)
        elif obj.__class__ is Settlement:
            return SettlementWrapper(obj)
        elif obj.__class__ is Country:
            return CountryWrapper(obj)

    def additional_obj_titles(self, title):
        """Yield every spelling of ``title`` with each 'ё' independently kept or turned into 'е'.

        A title without 'ё' is yielded unchanged. With k occurrences, 2**k
        titles are yielded: the all-'е' spelling first and the original last.
        """
        # https://st.yandex-team.ru/RASPSUGGESTS-18
        # all spellings of the title with/without ё
        e_positions = [match.start() for match in re.finditer(u'ё', title)]
        if not e_positions:
            yield title
            return

        title_list = list(title)
        # product (not combinations_with_replacement): every position must vary
        # independently, e.g. two ё need ее, её, ёе and ёё.
        for comb in product((u'е', u'ё'), repeat=len(e_positions)):
            for position, letter in zip(e_positions, comb):
                title_list[position] = letter
            yield u''.join(title_list)

    def generate_suggests_data(self):
        """Wrap every object and append its forms, ё-variants included, to ``suggests_data``."""
        for raw_obj in self.objects_iter:
            obj = self.get_obj_wrapper(raw_obj)
            key = (get_obj_type(raw_obj), raw_obj.id)
            for base_obj_form in obj.get_object_forms():
                for title in self.additional_obj_titles(base_obj_form['title']):
                    # shallow copy is enough: only 'title' differs between variants
                    obj_form = copy(base_obj_form)
                    obj_form['title'] = title
                    self.suggests_data[key].append(obj_form)


class BaseWrapper(ObjSuggestWrapper):
    """Common suggest-form generation for settlements, stations and countries.

    Subclasses must implement ``get_synth_weights()``. They may extend
    ``get_base_data()``, ``get_ttypes()`` and ``get_object_forms()``.
    """

    def get_object_forms_by_separators(self, title, base_data):
        """Yield one form per word-suffix of the prepared ``title``.

        Each form is ``base_data`` plus ``title`` (the variant) and ``is_prefix``.
        ``is_prefix`` is True only for the variant starting at the title's beginning.
        """
        title = prepare_title_text(title)
        for title_variant, is_prefix in title_variants(title):
            data = dict(
                title=title_variant,
                is_prefix=is_prefix,
                **base_data)
            yield data

    def get_object_forms(self):
        """Yield forms for each localized title, then one form per synonym.

        Title forms carry a ``lang`` key. Synonym forms carry ``syn=True`` and no ``lang``.
        """
        base_data = self.get_base_data()
        for lang, title in self.get_titles_by_lang().items():
            for obj_form in self.get_object_forms_by_separators(title, base_data):
                obj_form['lang'] = lang
                yield obj_form

        for title in self.get_titles_synonyms():
            title = prepare_title_text(title)
            data = dict(title=title, syn=True, **base_data)
            yield data

    def get_ttypes(self):
        """Transport types of the object; none by default."""
        return []

    def get_base_data(self):
        """Return the fields shared by every suggest form of this object.

        Keys:
        - ``weights``: the stat weights plus a ``synth`` dict of integer synthetic weights.
        - ``obj_type``, ``obj_id``, ``point_key``, ``t_types``.
        - ``titles`` and ``full_titles``, per language.
        - ``country_titles_national`` and ``country_id_national``, only when a
          translocal country is found.
        """
        # Copy: get_stat_weights() may return an entry of the shared 'stat_weights'
        # object, and adding 'synth' below must not modify that shared entry.
        weights = dict(self.get_stat_weights())
        # don't need precision (and floats at all)
        weights['synth'] = {key: int(value) for key, value in self.get_synth_weights().items()}

        object_params = {
            'weights': weights,
            'obj_type': get_obj_type(self.obj),
            'obj_id': self.obj.id,
            'point_key': self.obj.point_key,
            't_types': self.get_ttypes(),
        }

        titles = {lang: self.obj.L_title(lang=lang) for lang in TITLE_LANGS}
        object_params['titles'] = titles

        full_titles = defaultdict(dict)
        country_titles = defaultdict(dict)
        country_id = {}
        # Disputed territories get a variant per national version.
        # Everything else gets only the 'ru' version, stored under the 'default' alias.
        national_versions = (
            [(v, v) for v in TITLE_NATIONAL_VERSIONS]
            if getattr(self.obj, 'disputed_territory', False) else
            [('default', 'ru')]
        )
        for national_alias, national_v in national_versions:
            country = self.obj.translocal_country(national_v) if hasattr(self.obj, 'translocal_country') else None
            if country:
                country_id[national_alias] = country.id
            for lang in TITLE_LANGS:
                full_titles[lang][national_alias] = self.get_omonim_title(lang, national_v)
                if country:
                    country_titles[lang][national_alias] = country.L_title(lang=lang)

        object_params['full_titles'] = full_titles
        if country_titles:
            object_params['country_titles_national'] = country_titles
        if country_id:
            object_params['country_id_national'] = country_id
        return object_params

    def get_omonim_title(self, lang=None, national_version=None):
        """Return the disambiguated title: ``title`` plus ``, add`` when ``add`` is non-empty."""
        omonim_title = self.obj.L_omonim_title(lang=lang, national_version=national_version)
        parts = [omonim_title['title']]
        if omonim_title['add']:
            parts.append(omonim_title['add'])

        return u', '.join(parts)

    def get_synth_weights(self):
        """Return ``{geo_id_or_None: weight}``; must be implemented by subclasses."""
        raise NotImplementedError

    def get_stat_weights(self):
        """Return the statistical weights for the object.

        Settlements are looked up under ``('c', id)`` and stations under
        ``('s', id)``; a missing entry yields ``{}``. Any other object type
        gets ``{None: {'all': UNKNOWN_WEIGHT}}``.
        """
        stat_weights = shared_objects.get_obj('stat_weights')
        if get_obj_type(self.obj) == 'settlement':
            pref = 'c'
        elif get_obj_type(self.obj) == 'station':
            pref = 's'
        else:
            return {None: {'all': UNKNOWN_WEIGHT}}

        return stat_weights.get((pref, self.obj.id), {})

    def get_title(self, **kwargs):
        """Return the object's title in ``kwargs['lang']``; other kwargs are ignored."""
        lang = kwargs.pop('lang', None)
        return self.obj.L_title(lang=lang)

    def get_titles_by_lang(self, **kwargs):
        """Return ``{lang: title}`` for every language in TITLE_LANGS."""
        return {lang: self.get_title(lang=lang, **kwargs) for lang in TITLE_LANGS}

    def get_titles_synonyms(self):
        """Return the titles of the object's synonyms from the shared 'synonyms' object."""
        return [syn.title for syn in shared_objects.get_obj('synonyms').get(self.obj)]


def titles_by_system_codes(base_wrapper, codes):
    """Yield forms that let users find the wrapped object by an external system code.

    :param base_wrapper: wrapper whose ``get_base_data()`` supplies the shared form fields.
    :param codes: iterable of ``(system_code, obj_code)`` pairs.

    Sirena codes are emitted for 'ru' and 'uk' only. Every other system gets
    'ru', 'uk' and 'en'.
    """
    codes = list(codes)
    if not codes:
        # Nothing to emit: skip get_base_data(), whose subclass versions run DB queries.
        return

    base_data = base_wrapper.get_base_data()
    for system_code, obj_code in codes:
        langs = ['ru', 'uk'] if system_code == 'sirena' else ['ru', 'uk', 'en']
        title = prepare_title_text(obj_code)
        comment = 'code system {}'.format(system_code)

        for lang in langs:
            yield dict(
                title=title,
                system=system_code,
                lang=lang,
                comment=comment,
                **base_data)


class SettlementWrapper(BaseWrapper):
    """Suggest wrapper for Settlement objects."""

    def get_base_data(self):
        """Extend the common base data with settlement-specific fields.

        Adds zone/region/majority ids, the disputed-territory flag, the iata
        and sirena codes, the slug and, when the settlement has a region, the
        region's localized titles.
        """
        data = super(SettlementWrapper, self).get_base_data()
        data['zone_id'] = self.obj.suburban_zone_id
        data['region_id'] = self.obj.region_id
        data['majority_id'] = self.obj.majority_id
        data['disputed_territory'] = bool(self.obj.disputed_territory)
        data['codes'] = {'iata': self.obj.iata,
                         'sirena': self.obj.sirena_id}
        data['slug'] = self.obj.slug
        if self.obj.region_id:
            # Use the related object instead of issuing a new Region query.
            # Objects built by ModelFieldsGetter already carry it; otherwise
            # Django loads it on first access.
            region = self.obj.region
            data['region_titles'] = {lang: region.L_title(lang=lang) for lang in TITLE_LANGS}
        return data

    def get_ttypes(self):
        """Return the settlement's transport types from the shared 'settlements_ttypes' map."""
        return shared_objects.get_obj('settlements_ttypes').get(self.obj.id, [])

    def get_synth_weights(self):
        """Return boosted geo weights keyed by geo id.

        ``None`` holds the base weight. Region, country and settlement geo ids
        are added only when present.
        """
        geo_weights = self.geo_weights(boost=True)

        weights = {None: geo_weights.base}
        if self.obj.region and self.obj.region._geo_id:
            weights[self.obj.region._geo_id] = geo_weights.region

        if self.obj.country and self.obj.country._geo_id:
            weights[self.obj.country._geo_id] = geo_weights.country

        if self.obj._geo_id:
            weights[self.obj._geo_id] = geo_weights.settlement

        return weights

    def majority(self):
        """Return the effective majority id.

        Big cities count as capitals, and a missing majority counts as a common city.
        """
        if self.obj.big_city:
            return CityMajority.CAPITAL_ID
        if self.obj.majority_id is None:
            return CityMajority.COMMON_CITY_ID
        return self.obj.majority_id

    def geo_weights(self, boost=True):
        """Return GeoWeights scaled by the settlement's majority, optionally boosted.

        The majority is clamped to MAX_SETTLEMENT_MAJ and turned into the
        coefficient ``MAX_SETTLEMENT_MAJ + 1 - majority``, which applies to
        ``settings.SETT_GEO_WEIGHTS``. With ``boost``, capitals and large
        cities get extra factors from ``settings.BOOST``.
        """
        maj = self.majority()
        if maj > MAX_SETTLEMENT_MAJ:
            maj = MAX_SETTLEMENT_MAJ

        maj_coef = (MAX_SETTLEMENT_MAJ + 1) - maj
        geo_weights = GeoWeights(**settings.SETT_GEO_WEIGHTS) * maj_coef

        if boost:
            if self.obj.country in OUR_COUNTRIES and maj == CityMajority.CAPITAL_ID:
                geo_weights *= settings.BOOST.OUR_COUNTRY_CAPITAL
            else:
                # capitals of important countries
                if maj == CityMajority.CAPITAL_ID and self.obj.country in ALL_IMPORTANT_COUNTRIES:
                    geo_weights.base *= settings.BOOST.IMPORTANT_COUNTRY_CAPITAL

                # large, important cities within the user's location
                large_cities = [CityMajority.REGION_CAPITAL_ID, CityMajority.POPULATION_MILLION_ID, CityMajority.CAPITAL_ID]
                if maj in large_cities:
                    geo_weights.region *= settings.BOOST.LARGE_CITIES
                    geo_weights.country *= settings.BOOST.LARGE_CITIES

        return geo_weights

    def get_object_forms(self):
        """Yield the common forms, then a form per language for the IATA code if set."""
        for data in super(SettlementWrapper, self).get_object_forms():
            yield data

        codes = []
        if self.obj.iata:
            codes.append(['iata', self.obj.iata])

        for data in titles_by_system_codes(self, codes):
            yield data


class StationWrapper(BaseWrapper):
    """Suggest wrapper for Station objects."""

    ST = StationType

    WATER_TYPES = [ST.PORT_ID, ST.PORT_POINT_ID, ST.WHARF_ID]
    TRAIN_TYPES = [ST.STATION_ID, ST.PLATFORM_ID, ST.STOP_ID, ST.CROSSING_ID, ST.POST_ID, ST.TRAIN_STATION_ID]

    # Station types whose names are used as title prefixes in forms_for_same, keyed by transport type.
    # https://st.yandex-team.ru/RASPSUGGESTS-21
    # TODO: add the "остановка" (stop), "автостанция" (bus terminal) and "причал" (pier) types
    STATION_TYPE_BY_TRANSPORT_TYPE = {
        TransportType.TRAIN_ID: TRAIN_TYPES,
        TransportType.PLANE_ID: [ST.AIRPORT_ID],
        TransportType.BUS_ID: [ST.BUS_STATION_ID, ST.BUS_STOP_ID, ],
        TransportType.RIVER_ID: WATER_TYPES,
        TransportType.SEA_ID: WATER_TYPES,
        TransportType.WATER_ID: WATER_TYPES,
        TransportType.SUBURBAN_ID: TRAIN_TYPES,
    }

    def get_synth_weights(self):
        """Return synthetic weights keyed by geo id (``None`` holds the base weight).

        Starts from the settlement's geo weights and scales them by the
        station's own majority coefficient. The settlement weights are boosted
        only for airports. A station without a settlement uses a placeholder
        settlement whose majority id is COMMON_CITY_ID + 1.
        """
        if self.obj.settlement:
            settlement = self.obj.settlement
        else:
            settlement = Settlement(majority_id=CityMajority.COMMON_CITY_ID + 1)

        need_boost = self.obj.station_type_id == StationType.AIRPORT_ID
        sett_geo_weights = SettlementWrapper(settlement).geo_weights(boost=need_boost)

        geo_weights = sett_geo_weights * self.maj_coeff()
        weights = {None: geo_weights.base}
        if self.obj.region and self.obj.region._geo_id:
            weights[self.obj.region._geo_id] = geo_weights.region

        if self.obj.country and self.obj.country._geo_id:
            weights[self.obj.country._geo_id] = geo_weights.country

        if self.obj.settlement and self.obj.settlement._geo_id:
            weights[self.obj.settlement._geo_id] = geo_weights.settlement

        return weights

    def maj_coeff(self):
        """Map the station majority onto the ``settings.STATION_NORM_COEF`` interval.

        The majority is clamped to MAX_STATION_MAJ and inverted to
        ``MAX_STATION_MAJ + 1 - majority``. The result is then rescaled
        linearly from [1, MAX_STATION_MAJ] to [min_norm, max_norm].
        """
        maj = self.obj.majority_id

        if maj > MAX_STATION_MAJ:
            maj = MAX_STATION_MAJ

        coef = (MAX_STATION_MAJ + 1) - maj

        # map coef onto the configured normalization interval
        min_coef, max_coef = 1.0, float(MAX_STATION_MAJ)
        assert min_coef <= coef <= max_coef

        min_coef_norm, max_coef_norm = settings.STATION_NORM_COEF
        coef_norm_len = max_coef_norm - min_coef_norm
        coef_len = max_coef - min_coef

        return min_coef_norm + coef_norm_len * ((coef - min_coef) / coef_len)

    def get_ttypes(self):
        """Return the station's transport types from the shared 'stations_ttypes' map."""
        station_ttypes = shared_objects.get_obj('stations_ttypes')
        return station_ttypes.get(self.obj.id, [])

    def get_base_data(self):
        """Extend the common base data with station-specific fields.

        Adds transport type, ids, codes, slug and localized station-type names.
        Also adds popular titles (only those differing from the title), region
        and settlement titles, and the suburban direction title when one was
        saved by save_stations_directions().
        """
        data = super(StationWrapper, self).get_base_data()
        data['t_type'] = self.obj.t_type_id
        data['t_type_code'] = self.obj.t_type.code
        data['region_id'] = self.obj.region_id
        data['zone_id'] = self.obj.suburban_zone_id
        data['majority_id'] = self.obj.majority_id
        data['disputed_territory'] = bool(self.obj.disputed_territory)
        data['codes'] = {'iata': self.obj.iata,
                         'sirena': self.obj.sirena_id}
        data['slug'] = self.obj.slug
        if self.obj.station_type_id:
            station_type = self.obj.station_type
            data['station_types'] = {lang: station_type.L_name(lang=lang) for lang in TITLE_LANGS}
        data['popular_titles'] = {}
        for lang in TITLE_LANGS:
            popular_title = self.obj.L_popular_title(lang=lang)
            if popular_title and popular_title != self.obj.L_title(lang=lang):
                data['popular_titles'][lang] = popular_title
        # Use the related objects instead of issuing new queries. Objects built by
        # ModelFieldsGetter already carry them; otherwise Django loads them on first access.
        if self.obj.region_id:
            region = self.obj.region
            data['region_titles'] = {lang: region.L_title(lang=lang) for lang in TITLE_LANGS}
        if self.obj.settlement_id:
            settlement = self.obj.settlement
            data['settlement_titles'] = {lang: settlement.L_title(lang=lang) for lang in TITLE_LANGS}

        external_direction = stations_directions.get(self.obj.id)
        if external_direction:
            data['suburban_directions'] = {lang: external_direction.L_full_title(lang=lang) for lang in TITLE_LANGS}

        return data

    def get_object_forms(self):
        """Yield all suggest forms of the station.

        The order is: title forms, popular-title forms, then system-code forms.
        """
        base_data = self.get_base_data()
        sett_one_station = shared_objects.get_obj('sett_one_station')

        # Skip the station's base form when it is the only station in its settlement
        # and has the same name (https://st.yandex-team.ru/RASPSUGGESTS-22).
        # Such stations still get prefixed forms so they don't vanish from search entirely.
        # For ordinary stations, prefixed forms are matched at the worker level.
        same_as_settlement = self.obj.settlement_id in sett_one_station
        if not same_as_settlement:
            for data in self.forms_for_not_same(base_data):
                yield data
        else:
            for data in self.forms_for_same(base_data):
                yield data

        for data in self.forms_with_popular_title(base_data):
            yield data

        for data in self.forms_by_systems_codes():
            yield data

    def forms_for_not_same(self, base_data):
        """Yield the common forms, plus "<station type> <title>" forms for important stations."""
        for data in super(StationWrapper, self).get_object_forms():
            yield data

        # Important stations always get prefixed forms.
        types = [StationType.AIRPORT_ID, StationType.BUS_STATION_ID, StationType.TRAIN_STATION_ID]
        majs = [StationMajority.MAIN_IN_CITY_ID, StationMajority.IN_TABLO_ID]
        if self.obj.majority_id in majs and self.obj.station_type_id in types:
            for lang, title in self.get_titles_by_lang().items():
                title = u'{} {}'.format(self.obj.station_type.L_name(lang=lang), title)
                data = dict(
                    title=prepare_title_text(title),
                    lang=lang,
                    comment='important station',
                    **base_data)
                yield data

    def forms_for_same(self, base_data):
        """Yield "<station type> <title>" forms for every station type of the transport type."""
        st_types = self.STATION_TYPE_BY_TRANSPORT_TYPE.get(self.obj.t_type_id, [])
        for station_type_id in st_types:
            station_type = StationType.objects.get(id=station_type_id)
            for lang, title in self.get_titles_by_lang().items():
                title = u'{} {}'.format(station_type.L_name(lang=lang), title)
                data = dict(
                    title=prepare_title_text(title),
                    lang=lang,
                    comment='same_as_settlement prefix; {}'.format(lang),
                    **base_data)
                yield data

    def forms_with_popular_title(self, base_data):
        """Popular (colloquial) names. https://st.yandex-team.ru/RASPSUGGESTS-25"""
        for lang, popular_title in base_data['popular_titles'].items():
            data = dict(
                title=prepare_title_text(popular_title),
                lang=lang,
                comment='popular name',
                **base_data)
            yield data

    def forms_by_systems_codes(self):
        """Yield forms for searching the station by codes of external systems."""
        station_codes = shared_objects.get_obj('station_codes')
        codes = [(system_code, station_code.code) for system_code, station_code in station_codes.get(self.obj).items()]
        for data in titles_by_system_codes(self, codes):
            yield data


class CountryWrapper(BaseWrapper):
    """Suggest wrapper for Country objects; uses BaseWrapper.get_base_data as is."""

    def get_synth_weights(self):
        """Weight the country by COUNTRY_MAJORITY (default 1) and double it under its own geo id."""
        base_weight = COUNTRY_MAJORITY.get(self.obj.title, 1)
        result = {None: base_weight}

        geo_id = self.obj._geo_id
        if geo_id:
            result[geo_id] = base_weight * 2

        return result


class ModelFieldsGetter(object):
    """Build model instances from a restricted set of DB columns.

    Rows are fetched with ``values(*fields)``. Foreign-key columns come back
    as raw ids and are replaced by the related objects before the model is
    instantiated.
    """

    def __init__(self, model, fields):
        self.model = model
        self.fields = fields

    def get(self, **kwargs):
        """Yield ``model`` instances for rows matching ``filter(**kwargs)``."""
        rows = self.model.objects.filter(**kwargs).values(*self.fields)
        for row in rows.iterator():
            self.foreign_keys_to_objects(row)
            yield self.model(**row)

    def foreign_keys_to_objects(self, obj_data):
        """Replace, in place, every non-None ForeignKey id in ``obj_data`` by its object.

        NOTE(review): this issues one query per foreign key per row.
        """
        meta = self.model._meta
        for name, raw_value in obj_data.items():
            if raw_value is None:
                continue
            field = meta.get_field(name)
            if not isinstance(field, ForeignKey):
                continue
            obj_data[name] = field.related_model.objects.get(id=raw_value)


def stations(**kwargs):
    """Return a generator of Station objects matching ``kwargs``.

    Hidden stations are excluded unless ``settings.IGNORE_HIDDEN_FIELD`` is set.
    """
    if not settings.IGNORE_HIDDEN_FIELD:
        kwargs['hidden'] = 0

    localized_fields = [
        template.format(lang)
        for template in ('title_{}', 'popular_title_{}')
        for lang in TITLE_LANGS
    ]
    plain_fields = [
        'title', 'id', 'type_choices', 'sirena_id',
        'majority', 'station_type', 't_type',
        'settlement', 'country', 'region', 'suburban_zone', 'district', 'slug'
    ]

    return ModelFieldsGetter(Station, fields=localized_fields + plain_fields).get(**kwargs)


def settlements(**kwargs):
    """Return a generator of Settlement objects matching ``kwargs``.

    Hidden settlements are excluded unless ``settings.IGNORE_HIDDEN_FIELD`` is set.
    """
    if not settings.IGNORE_HIDDEN_FIELD:
        kwargs['hidden'] = 0

    fields = ['title_{}'.format(lang) for lang in TITLE_LANGS]
    fields.extend([
        'title', 'majority', 'country', 'region', '_geo_id', 'id', 'district',
        'big_city', 'suburban_zone', '_disputed_territory', 'sirena_id', 'iata', 'slug'
    ])

    return ModelFieldsGetter(Settlement, fields=fields).get(**kwargs)


def countries(**kwargs):
    """Return a generator of Country objects matching ``kwargs`` (no hidden filter)."""
    fields = ['title_{}'.format(lang) for lang in TITLE_LANGS]
    fields.extend(['title', '_geo_id', 'id'])
    return ModelFieldsGetter(Country, fields=fields).get(**kwargs)


# Model class -> function returning its instances for given filters
# (get_titles_data calls it with id__in=...).
model_getters = {
    Station: stations,
    Settlement: settlements,
    Country: countries,
}


def get_titles_data(worker_args):
    """Build suggest forms for one worker's share of objects.

    :param worker_args: ``(worker_id, model_ids)`` pair. ``model_ids`` is an
        iterable of ``(model, ids)`` pairs whose models are keys of ``model_getters``.
    :returns: ``(worker_id, suggests_data)``, where ``suggests_data`` maps
        ``(obj_type, obj_id)`` to that object's list of forms.
    """
    # Accept the pair as one argument: tuple parameters are Python 2-only
    # syntax (removed by PEP 3113). Callers still pass a single tuple.
    worker_id, model_ids = worker_args

    def objects():
        for model, ids in model_ids:
            obj_getter = model_getters[model]
            for obj in obj_getter(id__in=ids):
                yield obj

    unload = SuggestUnload(
        objects=enumer(objects(), each=10000, each_time=True, title=u'{}'.format(worker_id)),
    )
    unload.generate_suggests_data()
    return worker_id, unload.suggests_data


def get_sett_one_station():
    """Return non-hidden settlements that have exactly one non-hidden station with the same title.

    The result is a list of Settlement instances produced by a raw SQL query.
    """
    # NOTE(review): HAVING refers to select-list aliases (cst, station_title),
    # which not every SQL dialect accepts — presumably MySQL; confirm.
    query = """
        select s.id, s.title, count(st.id) cst, st.id station_id, st.title station_title, st.station_type_id
        from www_settlement s
        join www_station st on st.settlement_id = s.id
        where s.hidden = 0 and st.hidden = 0
        group by s.id having cst = 1 and title = station_title
    """
    return list(Settlement.objects.raw(query))


def merge_suggests_data(dict_to, data):
    """Append every form list in ``data`` onto the matching list in ``dict_to``.

    ``dict_to`` is updated in place and returned. Keys new to ``dict_to`` get a
    fresh list, so the lists inside ``data`` are never shared with it.
    """
    for key in data:
        if key not in dict_to:
            dict_to[key] = []
        dict_to[key].extend(data[key])
    return dict_to


def generate_titles_data(ids, pool_size=1):
    """Generate suggest forms for the given objects, optionally across a worker pool.

    :param ids: iterable of ``(model, object_ids)`` pairs; every model must be a
        key of ``model_getters``.
    :param pool_size: number of workers; 1 runs everything in-process.
    :returns: mapping ``(obj_type, obj_id) -> list of forms``.
    """
    shared_objects.set_objs(sett_one_station={s.id for s in get_sett_one_station()})
    with print_run_time('save directions', logger=logger):
        save_stations_directions()
    with print_run_time('titles generation', logger=logger):
        if pool_size == 1:
            _, titles_data = get_titles_data((0, ids))
        else:
            # Each model's id list is split into pool_size slices, and every worker
            # collects its slice of each model. split_values presumably returns
            # {worker_id: slice}; confirm.
            data_by_worker = defaultdict(list)
            # distinct name: the loop must not shadow the `ids` parameter it iterates
            for model, model_ids in ids:
                for worker_id, ids_slice in split_values(model_ids, pool_size).items():
                    data_by_worker[worker_id].append((model, ids_slice))

            titles_data = {}
            for worker_id, result in generate_parallel(get_titles_data, data_by_worker, pool_size):
                with print_run_time('merge for {}'.format(worker_id), logger=logger):
                    titles_data = merge_suggests_data(titles_data, result)

        logger.info('len(titles_data): {}'.format(len(titles_data)))

    return titles_data


def save_stations_directions():
    """Rebuild the module-level ``stations_directions`` map (station id -> external direction).

    Only rail (train/suburban) stations from stations() that share both title
    and region with another such station are considered. A marker is kept only
    when the station's suburban zone equals its direction's zone. If several
    markers qualify for one station, the last one seen wins.
    """
    global stations_directions

    rail_transport = [TransportType.TRAIN_ID, TransportType.SUBURBAN_ID]

    # Group in a single pass instead of comparing every pair of stations
    # (itertools.combinations over all rail stations is O(n^2)).
    station_ids_by_key = defaultdict(list)
    for station in stations(t_type_id__in=rail_transport):
        station_ids_by_key[(station.title, station.region_id)].append(station.id)

    need_directions = set()
    for station_ids in station_ids_by_key.values():
        if len(station_ids) > 1:
            need_directions.update(station_ids)

    # select_related: both relations are read for every marker below.
    markers = (
        ExternalDirectionMarker.objects
        .filter(station__id__in=need_directions)
        .select_related('station', 'external_direction')
    )

    direction_markers = {}
    for marker in markers:
        if marker.station.suburban_zone_id == marker.external_direction.suburban_zone_id:
            direction_markers[marker.station_id] = marker

    # Markers were filtered by need_directions, so every key already belongs to it.
    stations_directions = {
        station_id: marker.external_direction
        for station_id, marker in direction_markers.items()
    }
