# -*- encoding: utf-8 -*-
import logging
from collections import namedtuple
from contextlib2 import closing

from pathlib2 import Path  # noqa

from travel.avia.dump_data.lib.model_classes import BaseModel
from travel.avia.dump_data.lib.mysql_connector import MysqlConnector  # noqa
from travel.library.python.dicts import file_util


class MysqlModel(BaseModel):
    """Dump the rows of one MySQL table as length-framed protobuf messages.

    Each registered field maps a column of ``db_table`` onto an attribute of
    ``proto_model``. ``dump_into_directory`` selects those columns, builds one
    ``proto_model`` message per row and writes its serialized bytes to the
    output file via ``file_util.write_binary_string``.
    """

    # One column-to-attribute mapping: the value of ``db_column`` in each row
    # is assigned to ``proto_attribute`` on the message.
    FieldClass = namedtuple('FieldClass', ['proto_attribute', 'db_column'])

    def __init__(self, connector, name, db_table, proto_model):
        """
        :param connector: object whose ``get_connection()`` opens the DB connection.
        :param name: reference name, used in log messages and exposed as ``name``.
        :param db_table: table name, inserted verbatim into the SELECT query.
        :param proto_model: protobuf message class, instantiated once per row.
        """
        self.connector = connector  # type: MysqlConnector
        self._name = name
        self.db_table = db_table
        self.proto_model = proto_model
        # FieldClass entries, in the order their columns appear in the SELECT.
        self.fields = []

    def add_fields(self, proto_attribute, db_column):
        """Register a single mapping from ``db_column`` to ``proto_attribute``."""
        self.fields.append(self.FieldClass(proto_attribute, db_column))

    @property
    def name(self):
        """Reference name given to the constructor."""
        return self._name

    def dump_into_directory(self, directory):
        # type: (Path) -> None
        """Write all selected rows into ``directory / self.get_output_file_name()``.

        The connection, cursor and output file are all closed when this
        returns or raises.

        :raises ValueError: if no fields were registered via ``add_fields``.
        """
        file_name = directory / self.get_output_file_name()

        with closing(self.connector.get_connection()) as connection:
            with closing(connection.cursor()) as cursor:
                with open(str(file_name), 'wb') as file:
                    self._dump_from_mysql_into_file(cursor, file)

        logging.info('Write %s', str(file_name))

    def _dump_from_mysql_into_file(self, cursor, file):
        """Serialize every row produced by ``_row_generator`` into ``file``."""
        count_row = 0
        for row in self._row_generator(cursor):
            count_row += 1
            proto = self._as_proto(row)
            file_util.write_binary_string(file, proto.SerializeToString())

        logging.info('Fetch %s rows for %s reference', count_row, self.name)

    def _row_generator(self, cursor):
        """Execute the SELECT on ``cursor`` and yield its rows one at a time."""
        query = self._get_query()
        logging.debug('Execute query: %s', query)

        cursor.execute(query)
        for row in cursor:
            yield row

    def _get_query(self):
        """Build ``select <columns> from <table>`` for the registered fields.

        Column and table names are inserted verbatim. They come from the
        model's configuration, not from user input.

        :raises ValueError: if no fields are registered. Without this check
            the query would be ``select  from <table>``, which MySQL rejects
            with an unhelpful syntax error.
        """
        if not self.fields:
            raise ValueError(
                'No fields registered for {name} reference'.format(name=self.name)
            )
        fields = ', '.join([field.db_column for field in self.fields])
        return "select {fields} from {table}".format(
            fields=fields,
            table=self.db_table,
        )

    def _as_proto(self, row):
        """Build a ``proto_model`` message from one row.

        ``row`` is indexed by column name, so the cursor is expected to
        return mapping-like rows.
        """
        proto = self.proto_model()

        for field in self.fields:
            value = row[field.db_column]
            # Every falsy value (None, 0, '', False) is skipped and the
            # message's default is left in place.
            # NOTE(review): presumably intentional, because None cannot be
            # assigned to scalar proto fields. Confirm that zero/empty values
            # never need to be explicitly present (e.g. proto2 HasField).
            if value:
                setattr(proto, field.proto_attribute, value)

        return proto
