# coding: utf-8

import datetime
import itertools
import time
import inspect
import logging
from abc import ABCMeta, abstractmethod


# Module-level logger named after this module; Driver uses it for progress and error reporting.
logger = logging.getLogger(__name__)


class SchemaColumn(object):
    """
    Describes a single column of a yt table together with the way its value is produced.

    :param schema - yt column definition; it is subscripted with "name" to obtain the column name
    :param transformer - callable taking (batch, creative) and returning the column value. When it is
    omitted (or falsy) the value is read with batch.get(<column name>)
    """

    def __init__(self, schema, transformer=None):
        def read_from_batch(batch, creative):
            # The column name is looked up at call time, not at construction time.
            return batch.get(schema["name"])

        if not transformer:
            transformer = read_from_batch

        self.schema = schema
        self.transformer = transformer

    def transform(self, batch, creative):
        """Apply the configured transformer to (batch, creative) and return its result."""
        convert = self.transformer
        return convert(batch, creative)

    def name(self):
        """Return the column name stored under the "name" key of the schema."""
        column_schema = self.schema
        return column_schema["name"]


class Driver(object):
    """
    Tool working on one yt table. Can fetch data, transform it to the yt representation and dump data to yt.

    Subclasses implement ``fetch`` and ``transform`` and are expected to override ``schema`` and
    ``yt_dump_dir`` (the latter may also be passed to the constructor).

    :param mongo: mongo client.
    :param yt_client: yt client. FYI yt client is not thread safe
    :param lpc_psycopg2_connection: psycopg2 connection to LPC's DB
    :param yt_dump_dir: optional path of the yt directory with dumps; overrides the class attribute
    :param overwrite_existing: when true, an already existing dump table is removed and rebuilt by ``run``
    """
    # NOTE(review): this is the Python 2 spelling. Python 3 ignores ``__metaclass__``, so the
    # @abstractmethod markers below are not enforced there.
    __metaclass__ = ABCMeta

    # Number of fetched entities taken from the cursor per transform/dump iteration
    batch_size = 20000

    # Params below should be overridden in children classes
    schema = None  # Data schema: iterable of objects exposing a ``schema`` attribute (e.g. SchemaColumn)
    yt_dump_dir = None  # Path to dir on YT where to store dumps

    # Format of table names inside yt_dump_dir; also used to parse them back during cleanup
    date_fmt = "%Y-%m-%d"

    # How many of the most recent dumps cleanup() never considers for removal
    cleanup_ignore_x_first_dumps = 7

    def __init__(self, mongo, yt_client, lpc_psycopg2_connection, yt_dump_dir=None, overwrite_existing=False):
        """
        Store the clients and compute the path of the table this driver works on.

        :param mongo: mongo client
        :param yt_client: yt client
        :param lpc_psycopg2_connection: psycopg2 connection to LPC's DB
        :param yt_dump_dir: when truthy, replaces the class-level ``yt_dump_dir``
        :param overwrite_existing: whether ``run`` rebuilds a dump that already exists
        """
        self.mongo = mongo
        self.yt_client = yt_client
        self.lpc_psycopg2_connection = lpc_psycopg2_connection

        if yt_dump_dir:
            self.yt_dump_dir = yt_dump_dir

        # NOTE(review): fails with TypeError if yt_dump_dir was neither overridden nor passed in.
        self.table_path = "/".join([self.yt_dump_dir, self.generate_table_name()])

        self.overwrite_existing = overwrite_existing

    def generate_table_name(self):
        """Return the table name: the current local time formatted with ``date_fmt``."""
        return time.strftime(self.date_fmt)

    def dump_exist(self):
        """Return the result of ``yt_client.exists`` for this driver's table path."""
        return self.yt_client.exists(self.table_path)

    def remove_table(self):
        """Remove this driver's table; ``force=True`` is passed to the yt client."""
        self.yt_client.remove(self.table_path, recursive=False, force=True)

    def create_table(self):
        """
        Creates yt static table
        :return:
        """

        # A real list (not ``map``) so that the attribute value is a concrete sequence on
        # Python 3 as well, where ``map`` returns a lazy iterator.
        table_schema = [column.schema for column in self.schema]

        self.yt_client.create(
            "table",
            self.table_path,
            recursive=True,
            attributes={"schema": table_schema},
        )

    @abstractmethod
    def fetch(self):
        """
        fetches data from mongo db
        :return: <iterable collection>
        """
        return []

    @abstractmethod
    def transform(self, batches):
        """
        transforms mongo objects to yt objects. Each Mongo object may be converted to several yt objects
        :param batches: list of mongo objects
        :return: list of rows
        """
        return []

    def dump(self, rows):
        """
        saves data to yt, retrying on YtHttpResponseError

        Up to 10 attempts are made, sleeping ``attempt ** 2`` seconds between them. An error whose text
        contains "Invalid type of column" is re-raised immediately without further attempts.

        :param rows: list of yt rows; must not be a generator because it is re-sent on every retry
        :return:
        :raises AssertionError: if ``rows`` is a generator
        :raises yt.YtHttpResponseError: if the error is not retriable or all attempts failed
        """
        # import not on the top level to prevent sandbox build issues
        from yt import wrapper as yt

        # avoid issues with retries: a generator would be exhausted by the first attempt.
        # Raised explicitly (instead of ``assert``) so the check survives ``python -O``.
        if inspect.isgenerator(rows):
            raise AssertionError("rows must be a materialized collection, not a generator")

        # yt has internal retries. But sometimes they sucks in case when job could have been done. So...
        max_tries = 10
        for try_ in range(1, max_tries + 1):
            try:
                self.yt_client.write_table(
                    self.yt_client.TablePath(self.table_path, append=True),
                    rows,
                    # TODO: yson?
                    format=yt.JsonFormat(attributes={"encode_utf8": False}),
                    force_create=False,
                )
                return
            except yt.YtHttpResponseError as exc:
                logger.warning(
                    "[%s/%s] %s suffering from unknown YT error trying to save: %s",
                    try_, max_tries, type(self).__name__, exc,
                )

                # Retrying cannot help here, so fail right away with an accurate message.
                if "Invalid type of column" in str(exc):
                    logger.error(
                        "cannot save %s rows to %s, error is not retriable", len(rows), self.table_path,
                    )
                    raise

                if try_ == max_tries:
                    logger.error("cannot save %s rows to %s, all retries exceeded", len(rows), self.table_path)
                    raise

                # timeout between retries
                time.sleep(try_ ** 2)

    @staticmethod
    def get_count(iterable):
        """
        Get num fetched rows to dump.

        Default implementation for pymongo cursor
        """
        return iterable.count()

    def get_next_batch(self, iterable):
        """
        Get next batch to dump.

        Default implementation for pymongo cursor: takes up to ``batch_size`` items from the iterable
        and returns them as a list (empty when the iterable is exhausted).
        """
        return list(itertools.islice(iterable, self.batch_size))

    def run(self):
        """Main method to run all process

        Inside one yt transaction: (re)creates the table, dumps the fetched data batch by batch,
        merges the table chunks and finally runs ``cleanup``. Returns early without doing anything
        when the dump already exists and ``overwrite_existing`` is false.
        """
        with self.yt_client.Transaction():
            if self.dump_exist():
                if self.overwrite_existing:
                    self.remove_table()
                else:
                    return

            self.create_table()

            entities_cursor = self.fetch()

            # dumping data in batches of batch_size
            logger.info("get_count started")
            # count is used for progress logging only
            count = self.get_count(entities_cursor)
            logger.info("get_count finished")
            logger.info("get_next_batch started")
            entities = self.get_next_batch(entities_cursor)
            logger.info("get_next_batch finished")

            dumped = 0
            while entities:
                rows = self.transform(entities)
                # NOTE(review): a batch that transforms into no rows stops the whole dump, so later
                # batches are never written. Confirm this is intended rather than a ``continue``.
                if not rows:
                    break
                logger.info("dump started")
                self.dump(rows)
                logger.info("dump finished")

                dumped += len(rows)
                logger.info("Dumped %s/%s rows...", dumped, count)
                entities = self.get_next_batch(entities_cursor)

            self.yt_client.run_merge(self.table_path, self.table_path, spec={"combine_chunks": True})

            self.cleanup()

    def cleanup(self):
        """Delete obsolete tables.

        The ``cleanup_ignore_x_first_dumps`` most recent dumps are always preserved. Among the older
        dumps only the earliest one of each month is preserved, the others are removed.

        :raises ValueError: if a node name in ``yt_dump_dir`` does not match ``date_fmt``
        """
        # Parse every name once; sort by date only, so that equal dates keep the listing order.
        dated_tables = sorted(
            (
                (datetime.datetime.strptime(table, self.date_fmt).date(), table)
                for table in self.yt_client.list(self.yt_dump_dir)
            ),
            key=lambda dated: dated[0],
        )

        # Drop the newest dumps from consideration. An explicit end index is used because
        # ``[:-0]`` is an empty slice, which would disable cleanup entirely when nothing is protected.
        protected = max(self.cleanup_ignore_x_first_dumps, 0)
        candidates = dated_tables[:max(len(dated_tables) - protected, 0)]

        # iterate over sorted tables, preserve first for each month and delete others
        cleanup_month = None

        for date, table in candidates:
            month = date.replace(day=1)

            if cleanup_month == month:
                yt_path = "{}/{}".format(self.yt_dump_dir, table)
                logger.info("Table to delete during cleanup: %s", yt_path)
                self.yt_client.remove(yt_path)
            else:
                cleanup_month = month
