import logging

from sandbox import common
from sandbox.projects.metrika.admins.dicts.lib import MetrikaBuildDictionaryCommon
from sandbox.projects.metrika.utils.base_metrika_task import with_parents


@with_parents
class MetrikaBuildDictionary(MetrikaBuildDictionaryCommon):
    """Build a dictionary file by dumping a SQL query result as TSV.

    The query is ``dict_config.sql`` when configured; otherwise it is built
    from ``dict_config.fields``, ``dict_config.table`` and the optional
    ``dict_config.where_clause``. The backend is chosen by
    ``dict_config.db_type`` (``"mysql"`` by default, or ``"postgresql"``).
    """

    # Number of rows requested from the cursor per fetchmany() call.
    _FETCH_BATCH_SIZE = 10000

    @property
    def sql(self):
        """SQL query whose result is written to the dictionary file.

        Returns ``dict_config.sql`` verbatim when present. Otherwise builds
        ``SELECT <fields> FROM <table> [WHERE <where_clause>]``; field names
        are wrapped in backticks for MySQL and used as-is for PostgreSQL.
        """
        if 'sql' in self.dict_config:
            return self.dict_config.sql

        if self.db_type == 'mysql':
            fields = ", ".join("`%s`" % field for field in self.dict_config.fields)
        else:  # 'postgresql' -- db_type rejects every other value
            fields = ", ".join("%s" % field for field in self.dict_config.fields)

        sql = """
            SELECT
                %(fields)s
            FROM %(table)s
            """ % {'fields': fields, 'table': self.dict_config.table}

        # Append the WHERE clause only after the template has been %-formatted,
        # so literal '%' characters in it (e.g. LIKE 'foo%') are not parsed as
        # format directives.
        if 'where_clause' in self.dict_config:
            sql += ' WHERE %(where_clause)s' % {'where_clause': self.dict_config.where_clause}
        return sql

    @property
    def db_type(self):
        """Configured database backend: ``"mysql"`` (default) or ``"postgresql"``.

        Raises:
            ValueError: if ``dict_config.db_type`` is any other value.
        """
        db_type = self.dict_config.get("db_type", "mysql")
        if db_type not in ("mysql", "postgresql"):
            raise ValueError("Unsupported db_type: %s" % db_type)
        return db_type

    def _connect(self):
        """Open a connection to the configured backend.

        Only the driver for the selected ``db_type`` is imported, so the task
        does not need both ``pymysql`` and ``psycopg2`` to be installed.
        """
        if self.db_type == 'mysql':
            import pymysql
            return pymysql.connect(
                db=self.dict_config.database,
                **self.dict_config.mysql
            )
        import psycopg2
        return psycopg2.connect(
            database=self.dict_config.database,
            **self.dict_config.postgresql
        )

    @staticmethod
    def _format_row(row, escape):
        """Render one result row as a newline-terminated, tab-separated line.

        ``None`` becomes an empty field; any other value is ``str()``-converted
        and passed through ``escape``. The result is returned as bytes because
        the output file is opened in binary mode: text is UTF-8 encoded, while
        a native ``str`` that already is bytes (Python 2) is returned unchanged.
        """
        line = "\t".join(
            escape(str(value)) if value is not None else ''
            for value in row
        ) + "\n"
        if not isinstance(line, bytes):
            line = line.encode('utf-8')
        return line

    def make_dict(self):
        """Run :attr:`sql` and write every result row to ``self.dict_file``.

        Raises:
            ValueError: if ``dict_config.db_type`` is unsupported (raised
                before any connection is opened).
            common.errors.TaskFailure: if executing the query or writing the
                file fails; the original error is logged with its traceback.
        """
        import metrika.pylib.escape_utils as mteu

        connection = self._connect()
        try:
            sql = self.sql  # the property rebuilds the query on every access
            logging.debug("Run sql:\n%s", sql)
            with connection.cursor() as cursor:
                cursor.execute(sql)
                records_fetched = 0
                with open(str(self.dict_file), 'wb') as f:
                    while True:
                        data = cursor.fetchmany(self._FETCH_BATCH_SIZE)
                        records_fetched += len(data)
                        logging.debug("Fetched %d records...", records_fetched)
                        if not data:
                            break

                        for row in data:
                            f.write(self._format_row(row, mteu.escape_string))
        except Exception as e:
            logging.exception("Got exception")
            raise common.errors.TaskFailure(str(e))
        else:
            logging.debug("Sql is finished")
        finally:
            connection.close()
