# coding: utf-8
from __future__ import absolute_import, division, print_function

import logging
from collections import defaultdict
from contextlib import closing
from os.path import join

import MySQLdb

from travel.rasp.mysql_dumper.lib.loaders import BaseDBLoader, MySQLConnector
from travel.rasp.mysql_dumper.lib.protos.rasp_precache_pb2 import TRaspPrecache

# Process-wide precache instance; filled in place by initialize().
CACHE = TRaspPrecache()

# Column expression selected by each generated *TimezoneLoader class, keyed by
# loader class name. ThreadTariff keeps its timezone in ``time_zone_from``;
# every other table falls back to ``time_zone``.
FIELDS = defaultdict(
    lambda: 'distinct(time_zone)',
    ThreadTariffTimezoneLoader='distinct(time_zone_from)',
)


class RaspPrecacher(object):
    """Build a ``TRaspPrecache`` protobuf from the rasp MySQL database.

    The produced message carries two lookup tables:

    * ``timezone_ids`` - every distinct non-empty timezone value found in the
      ``www_<name>`` tables for ``NAMES``, mapped to a sequential integer id;
    * ``uid_to_id`` - ``www_rthread.uid`` mapped to ``www_rthread.id``.
    """

    # Model names whose ``www_<name>`` tables are scanned for timezones.
    NAMES = ('Settlement', 'Station', 'RTStation', 'RThread', 'ThreadTariff')

    TABLE_NAMES = ['www_' + name.lower() for name in NAMES]
    CLASS_NAMES = [name + 'TimezoneLoader' for name in NAMES]

    # One generated loader class per table; the selected column expression
    # comes from FIELDS. Materialised as a tuple on purpose: a bare generator
    # expression here is exhausted by the first instance, leaving every later
    # RaspPrecacher with no timezone loaders at all.
    LOADER_CLASSES = tuple(
        type(class_name, (BaseDBLoader,), {'COLUMNS': [FIELDS[class_name]], 'TABLE_NAME': table_name})
        for class_name, table_name in zip(CLASS_NAMES, TABLE_NAMES)
    )

    THREAD_CACHE_LOADER = type('ThreadUidByIdLoader', (BaseDBLoader,),
                               {'COLUMNS': ['uid', 'id'], 'TABLE_NAME': 'www_rthread'})

    # Argument passed to split_interval() when reading www_rthread in ranges
    # (presumably the number of intervals - see BaseDBLoader to confirm).
    THREAD_INTERVALS = 20

    def __init__(self, connection):
        """Create one loader per timezone table plus the thread uid loader.

        :param connection: an open MySQLdb connection shared by all loaders.
        """
        logging.info('Initializing...')
        self.loaders = [
            self._make_loader(loader_class, connection)
            for loader_class in self.LOADER_CLASSES
        ]
        self.thread_cache_loader = self._make_loader(self.THREAD_CACHE_LOADER, connection)
        logging.info('Successfully initialized')

    @staticmethod
    def _make_loader(loader_class, connection):
        """Instantiate *loader_class* over a MySQLConnector bound to its table and columns."""
        return loader_class(MySQLConnector(connection, loader_class.TABLE_NAME, loader_class.COLUMNS))

    def build_cache(self):
        """Query the database and return a freshly populated ``TRaspPrecache``.

        Timezone ids are assigned in sorted order, so repeated dumps of the
        same data are reproducible (set iteration order is not stable across
        processes).
        """
        logging.info('Getting info about timezones')
        # None / empty values are not usable timezones; dropping them here
        # keeps the ids contiguous.
        timezones = sorted(set(
            row[0]
            for loader in self.loaders
            for row in loader.get_all()
            if row[0]
        ))
        result = TRaspPrecache()
        for num, timezone in enumerate(timezones):
            result.timezone_ids[timezone] = num

        # Read www_rthread range by range rather than in a single query.
        for left, right in self.thread_cache_loader.split_interval(self.THREAD_INTERVALS):
            for thread_uid, thread_id in self.thread_cache_loader.get_range(left, right):
                result.uid_to_id[thread_uid] = thread_id

        logging.info('Completed')
        return result

    def dump(self, output_stream):
        """Build the cache and write its serialized bytes to *output_stream*.

        The stream must accept bytes, e.g. a file opened in ``'wb'`` mode.
        """
        output_stream.write(self.build_cache().SerializeToString())


def initialize(filename):
    """Load a serialized ``TRaspPrecache`` from *filename* into the module-level CACHE.

    CACHE is updated in place rather than rebound, so code that already holds
    a reference to it sees the loaded data.
    """
    with open(filename, 'rb') as precache_file:
        payload = precache_file.read()
    CACHE.ParseFromString(payload)


def run_normalize(host, user, password, db, directory):
    """Connect to MySQL, build the rasp precache and write it to ``<directory>/precache``.

    The database connection is closed on exit, including when building or
    writing the cache raises.
    """
    connection = MySQLdb.connect(
        host=host,
        user=user,
        passwd=password,
        db=db,
        use_unicode=True,
        charset='utf8',
    )
    with closing(connection):
        precacher = RaspPrecacher(connection)
        target_path = join(directory, 'precache')
        with open(target_path, 'wb') as output:
            precacher.dump(output)
