# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import gc

import logging
import multiprocessing  # noqa
import os.path
from collections import defaultdict

from django.conf import settings
from django.db import connection

from common.db.mds.clients import mds_s3_common_client
from common.models.geo import Station, Settlement, Country, StationType
from common.models.transport import TransportType
from common.settings.utils import define_setting
from common.utils.lock import lock
from travel.rasp.library.python.common23.logging import create_current_file_run_log

from travel.rasp.suggests_tasks.suggests.generate import shared_objects
from travel.rasp.suggests_tasks.suggests.generate.caches import (
    StationCodePrecache, SettlementCodePrecache, SynonymsPrecache, precache_django_models)
from travel.rasp.suggests_tasks.suggests.generate.db import raise_if_maintenance_in_process
from travel.rasp.suggests_tasks.suggests.generate.titles import StationWrapper, generate_titles_data
from travel.rasp.suggests_tasks.suggests.generate.ttypes import get_ttypes
from travel.rasp.suggests_tasks.suggests.generate.utils import retrieve_ids
from travel.rasp.suggests_tasks.suggests.objects_utils import ObjIdConverter
from travel.rasp.suggests_tasks.suggests.storage import Storage
from travel.rasp.suggests_tasks.suggests.text_utils import prepare_title_text, TITLE_LANGS
from travel.rasp.suggests_tasks.suggests.utils import enumer, print_run_time


logger = logging.getLogger('generate')

# Optional override for the key prefix main() uses when uploading generated files
# to MDS; when left as None, settings.MDS_CONFIG['prefix'] is used instead.
define_setting('SUGGESTS_GENERATE_MDS_CONFIG_PREFIX', default=None)


def precache():
    """Warm up every cache we can before forking.

    Django model caches are filled first. Then each of the shared objects
    ``station_codes``, ``settlement_codes`` and ``synonyms`` is asked to
    precache itself. These must already be registered in ``shared_objects``.
    """
    shared_cache_names = ('station_codes', 'settlement_codes', 'synonyms')

    with print_run_time('Precaching', logger=logger):
        with print_run_time('\tPrecaching models', print_before=False, logger=logger):
            precache_django_models()

        for name in shared_cache_names:
            label = '\tPrecaching {}'.format(name)
            with print_run_time(label, print_before=False, logger=logger):
                cached_obj = shared_objects.get_obj(name)
                cached_obj.precache()


def ttype_id_to_code(ttype_id):
    """Return the ``TransportType`` code for ``ttype_id``.

    River and sea transport are folded into one type: both ids resolve to the
    code of ``TransportType.WATER_ID``.
    """
    water_like_ids = (TransportType.RIVER_ID, TransportType.SEA_ID)
    lookup_id = TransportType.WATER_ID if ttype_id in water_like_ids else ttype_id
    return TransportType.objects.get(id=lookup_id).code


def prepare_titles_data(titles_data, objects_data):
    """
        Build the per-language, per-transport-type title index.

        - converts db ids into local ids (via the shared ``id_converter``);
        - removes duplicates per id (they happen, e.g. sirena_code == synonym == "екб"),
          because the keys are collected into sets;
        - splits titles by language and by transport type; every title is also
          indexed under the ``'all'`` pseudo transport type.

        :param titles_data: mapping ``(obj_type, obj_db_id) -> title forms``. Each
            form is a dict with a ``'title'`` key, plus optional ``'lang'``
            (default ``'ru'``) and ``'is_prefix'`` (default ``True``).
        :param objects_data: mapping ``local_id -> object data`` containing a
            ``'t_types'`` entry (as built by ``prepare_objects_data``).
        :return:
        {
            'ru': {
                'train': {
                    'title1': {(1, True), (2, False)},
                    'title2': {(1, True), (2, True)},
                },
                ...
            },
            ...
        }
    """
    id_converter = shared_objects.get_obj('id_converter')
    result = defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
    # Local memo for ttype_id_to_code. It hits the TransportType manager on
    # every call, and the same few ids repeat for every form of every object.
    ttype_codes = {}

    for (obj_type, obj_id), obj_forms in enumer(titles_data.items(), each=10000):
        if not obj_forms:
            # Nothing to index. Skipping keeps get_local_id from being called for
            # objects the original per-form loop would never have touched.
            continue

        # The local id and the transport types depend only on the object, not on
        # the individual form, so resolve them once per object.
        obj_local_id = id_converter.get_local_id(obj_id, obj_type)
        obj_t_type_codes = []
        for t_type_id in objects_data[obj_local_id]['t_types']:
            if t_type_id not in ttype_codes:
                ttype_codes[t_type_id] = ttype_id_to_code(t_type_id)
            obj_t_type_codes.append(ttype_codes[t_type_id])

        for obj_form in obj_forms:
            lang = obj_form.get('lang', 'ru')
            title = obj_form['title']
            obj_title_key = (obj_local_id, obj_form.get('is_prefix', True))

            by_ttype = result[lang]
            by_ttype['all'][title].add(obj_title_key)
            for t_type in obj_t_type_codes:
                by_ttype[t_type][title].add(obj_title_key)

    return result


def prepare_objects_data(titles_data):
    """Build the ``local_id -> object data`` mapping from raw title forms.

    One of each object's forms is reused as the object's data record. It is
    augmented in place with ``'local_id'`` and with ``'search_titles'``, the
    sorted, de-duplicated list of all the object's titles.

    :param titles_data: mapping ``(obj_type, obj_db_id) -> title forms``.
    :return: dict ``local_id -> form dict``.
    """
    id_converter = shared_objects.get_obj('id_converter')
    objects_by_local_id = {}

    for (obj_type, obj_db_id), forms in enumer(titles_data.items(), each=10000):
        local_id = id_converter.get_local_id(obj_db_id, obj_type)
        unique_titles = sorted({form['title'] for form in forms})

        # Object-level fields are the same in every form of an object, so any
        # form will do; take the first one. Note that it is mutated in place and
        # is therefore also visible through titles_data.
        record = next(iter(forms))
        record['local_id'] = local_id
        record['search_titles'] = unique_titles
        objects_by_local_id[local_id] = record

    return objects_by_local_id


def prepare_station_prefixes():
    """Map normalized station-type names to the transport types they imply.

    Covers every station type listed in
    ``StationWrapper.STATION_TYPE_BY_TRANSPORT_TYPE``, for each language in
    ``TITLE_LANGS``. The station type's localized name is normalized with
    ``prepare_title_text`` and mapped to every transport type that lists a
    station type with that name.

    :return: ``{lang: {prefix: tuple_of_t_types}}`` as nested defaultdicts.
        Each tuple is sorted, so the generated data is identical between runs
        (set iteration order is not stable across processes).
    """
    station_prefixes = defaultdict(lambda: defaultdict(set))
    for t_type, st_type_ids in StationWrapper.STATION_TYPE_BY_TRANSPORT_TYPE.items():
        # Fetch each StationType once per transport type instead of once per language.
        station_types = [StationType.objects.get(id=st_type_id) for st_type_id in st_type_ids]
        for lang in TITLE_LANGS:
            for station_type in station_types:
                prefix = prepare_title_text(station_type.L_name(lang=lang))
                station_prefixes[lang][prefix].add(t_type)

    for prefixes in station_prefixes.values():
        for prefix in list(prefixes):
            prefixes[prefix] = tuple(sorted(prefixes[prefix]))

    return station_prefixes


def get_db_info():
    """Describe the database the export was made from.

    :return: list of ``(name, str(modified_on))`` pairs read from the ``flags`` table.
    """
    # TODO: add the database hostname
    # The cursor is a context manager, so it is closed as soon as the rows are
    # read instead of being left open until garbage collection.
    with connection.cursor() as cursor:
        cursor.execute("select name, modified_on from flags;")
        return [(name, str(modified_on)) for name, modified_on in cursor]


def generate_objs_data(ids, pool_size=1):
    """Generate the complete objects/titles payload for the given ids.

    :param ids: passed straight to ``generate_titles_data``.
    :param pool_size: worker pool size, forwarded to ``generate_titles_data``.
    :return: dict with ``'objects_data'``, ``'titles'``, ``'station_prefixes'``
        and ``'db_info'`` entries.
    """
    raw_titles = generate_titles_data(ids, pool_size)
    objects_data = prepare_objects_data(raw_titles)

    with print_run_time('prepare_titles_data', print_before=False, logger=logger):
        titles_index = prepare_titles_data(raw_titles, objects_data)

    payload = {
        'objects_data': objects_data,
        'titles': titles_index,
    }
    payload['station_prefixes'] = prepare_station_prefixes()
    payload['db_info'] = get_db_info()
    return payload


def convert_stat_routes_ids(stat_routes):
    """Re-key route statistics from ``(short_type, db_id)`` pairs to local ids.

    The short object type ``'c'`` denotes a settlement; any other value is
    treated as a station.

    :param stat_routes: ``{t_type: {(type, id)_from: {(type, id)_to: by_geoid}}}``,
        the route statistics produced by the Yt scripts.
    :return: ``{t_type: {local_id_from: {local_id_to: by_geoid}}}`` as nested
        defaultdicts. The ``by_geoid`` values are stored as-is, not copied.
    """
    id_converter = shared_objects.get_obj('id_converter')

    def local_id_of(short_type, db_id):
        obj_type = 'settlement' if short_type == 'c' else 'station'
        return id_converter.get_local_id(db_id, obj_type)

    converted = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
    for t_type, routes_from in stat_routes.items():
        for (from_short_type, from_db_id), routes_to in routes_from.items():
            for (to_short_type, to_db_id), by_geoid in routes_to.items():
                from_local_id = local_id_of(from_short_type, from_db_id)
                to_local_id = local_id_of(to_short_type, to_db_id)
                converted[t_type][from_local_id][to_local_id] = by_geoid

    return converted


def get_pool_size():
    """Return the number of worker processes to use for generation.

    The value is currently fixed at a single process.
    """
    # TODO: derive this from the available memory; multiprocessing.cpu_count()
    # is not used for now.
    return 1


def generate_all(storage, ids_by_model, pool_size=None,
                 skip_stat=False, skip_objs_data=False, skip_ttypes=False,
                 skip_precache=False, load_id_converter=False):
    """Run the whole generation pipeline and return the objects data.

    Steps, each of which can be skipped:

    1. Transport types of stations and settlements (``get_ttypes``), saved
       with ``storage.save_ttypes``.
    2. Route statistics loaded with ``storage.load_stat``, converted to local
       ids and saved with ``storage.save_stat_converted``.
    3. Objects/titles data built from the ttypes and converted statistics
       (both re-read from ``storage``) and saved with
       ``storage.save_objs_data``. When skipped, the data is loaded with
       ``storage.load_objs_data`` instead.

    The id converter is saved in every case, and the shared objects registry
    is cleared before returning.

    :param storage: storage object used to load/save intermediate results.
    :param ids_by_model: forwarded to ``generate_objs_data`` as ``ids``.
    :param pool_size: worker pool size; ``None`` means ``get_pool_size()``.
    :param skip_stat: do not convert route statistics.
    :param skip_objs_data: do not generate objects data; load the saved copy.
    :param skip_ttypes: do not regenerate transport types.
    :param skip_precache: do not call ``precache()`` before generating.
    :param load_id_converter: start from the saved id converter (loaded with
        ``freeze=False``) instead of an empty ``ObjIdConverter``.
    :return: the objects data dict.
    """

    if pool_size is None:
        pool_size = get_pool_size()

    if load_id_converter:
        id_converter = storage.load_id_converter(freeze=False)
    else:
        id_converter = ObjIdConverter()

    # Register the converter globally: helpers such as prepare_titles_data and
    # convert_stat_routes_ids fetch it via shared_objects.get_obj('id_converter').
    shared_objects.set_objs(id_converter=id_converter)

    if not skip_ttypes:
        with print_run_time('generate ttypes', logger=logger):
            stations_ttypes, settlements_ttypes = get_ttypes(pool_size)

        storage.save_ttypes({
            'stations': stations_ttypes,
            'settlements': settlements_ttypes
        })

    if not skip_stat:
        stat_data = storage.load_stat()  # stat is generated by Yt scripts
        with print_run_time('convert_stat_routes_ids', print_before=False, logger=logger):
            stat_data['routes'] = convert_stat_routes_ids(stat_data['routes'])
            storage.save_stat_converted(stat_data)

    if not skip_objs_data:
        with print_run_time('generate_objs_data', print_before=False, logger=logger):
            # Always read back from storage, so results of steps skipped in this
            # run come from a previous run.
            ttypes = storage.load_ttypes()
            stat_data = storage.load_stat_converted()

            # Must be registered before precache(), which reads station_codes /
            # settlement_codes / synonyms via shared_objects. The other entries
            # are presumably consumed by generate_titles_data (verify there).
            shared_objects.set_objs(
                station_codes=StationCodePrecache(['iata', 'sirena']),
                settlement_codes=SettlementCodePrecache(['iata', 'sirena']),
                synonyms=SynonymsPrecache([Station, Country, Settlement]),
                stat_weights=stat_data['by_obj'],
                stations_ttypes=ttypes['stations'],
                settlements_ttypes=ttypes['settlements'],
            )

            if not skip_precache:
                precache()
            data = generate_objs_data(ids_by_model, pool_size=pool_size)
            data['stat_routes'] = stat_data['routes']

            storage.save_objs_data(data)
    else:
        data = storage.load_objs_data()

    storage.save_id_converter(id_converter)

    # Drop local references and the shared registry before forcing a GC pass,
    # so the large intermediate structures can be freed.
    with print_run_time('clean caches', print_before=False, logger=logger):
        id_converter = None
        stat_data = None
        shared_objects.clear_objs()
        gc.collect()

    return data


@lock('suggests_generate', timeout=3600)
def main(work_dir, skip_gen=False, skip_ttypes=False, skip_stat=False, skip_mds=False, load_id_converter=False, pool_size=None):
    """Generate suggests data into ``work_dir`` and upload it to MDS.

    Runs under the ``suggests_generate`` lock. ``raise_if_maintenance_in_process``
    is called before the ``try``, so anything it raises propagates. Any exception
    raised after that point is logged with its traceback and NOT re-raised.

    :param work_dir: directory used by ``Storage`` and read for the MDS upload.
    :param skip_gen: reuse previously generated objects data from storage.
    :param skip_ttypes: reuse previously generated transport types.
    :param skip_stat: skip route statistics conversion.
    :param skip_mds: do not upload the generated files to MDS.
    :param load_id_converter: start from the saved id converter.
    :param pool_size: worker pool size (``None`` means the default).
    """
    raise_if_maintenance_in_process()

    try:
        with print_run_time('generate', logger=logger):
            # Country is currently excluded, see https://st.yandex-team.ru/RASPSUGGESTS-67
            source_models = [Settlement, Station]
            ids_by_model = [[model, retrieve_ids(model)] for model in source_models]

            generate_all(
                storage=Storage(work_dir),
                ids_by_model=ids_by_model,
                skip_ttypes=skip_ttypes,
                skip_stat=skip_stat,
                skip_objs_data=skip_gen,
                load_id_converter=load_id_converter,
                pool_size=pool_size,
            )

        if not skip_mds:
            with print_run_time('mds', logger=logger):
                prefix = settings.SUGGESTS_GENERATE_MDS_CONFIG_PREFIX or settings.MDS_CONFIG['prefix']
                for filename in settings.MDS_CONFIG['files']:
                    file_path = os.path.join(work_dir, filename)
                    with open(file_path, 'rb') as f:
                        file_content = f.read()
                    mds_s3_common_client.save_data(
                        key='{}/{}'.format(prefix, filename),
                        data=file_content,
                    )

    except Exception as ex:
        logger.exception(repr(ex))


def run(*main_args, **main_kwargs):
    """Initialise run logging, then call ``main`` with the given arguments.

    ``main``'s return value is discarded.
    """
    create_current_file_run_log()
    main(*main_args, **main_kwargs)


if __name__ == '__main__':
    # main() requires ``work_dir``, so the previous bare ``run()`` always failed
    # with a TypeError. Expose main()'s parameters as command-line options instead;
    # argparse derives the keyword names (e.g. --skip-gen -> skip_gen).
    import argparse

    parser = argparse.ArgumentParser(description='Generate suggests data and upload it to MDS.')
    parser.add_argument('work_dir', help='directory for the generated files')
    parser.add_argument('--skip-gen', action='store_true', help='reuse previously generated objects data')
    parser.add_argument('--skip-ttypes', action='store_true', help='reuse previously generated transport types')
    parser.add_argument('--skip-stat', action='store_true', help='skip route statistics conversion')
    parser.add_argument('--skip-mds', action='store_true', help='do not upload the results to MDS')
    parser.add_argument('--load-id-converter', action='store_true', help='start from the saved id converter')
    parser.add_argument('--pool-size', type=int, default=None, help='worker pool size')
    run(**vars(parser.parse_args()))
