from airflow.hooks.dbapi_hook import DbApiHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults


class GenericDbTransfer(BaseOperator):
    """
    Moves data from one DB connection to another, assuming that both
    are DbApiHooks.

    The complete result of the source query is fetched before anything is
    inserted, so this is meant to be used on small-ish datasets that fit in
    memory. When the source query yields no rows, nothing is written to the
    destination and the preoperator is NOT executed.

    Params:
     * sql: (str) SQL query to execute against the source database.
     * query_params: (dict) parameters handed to the source hook together
       with ``sql`` (default None).
     * destination_table: (str) target table where to insert the results.
     * source_conn_id: (str) source connection.
     * destination_conn_id: (str) destination connection.
     * target_fields: (list) columns to be inserted (default None, does
       positional insert).
     * preoperator: (str or list) sql statement(s) to be executed against the
       destination prior to inserting the data (default None).
     * preoperator_params: (dict) parameters handed to the destination hook
       together with ``preoperator`` (default None).
    """

    template_fields = (
        'sql',
        'destination_table',
        'preoperator',
        'query_params',
        'preoperator_params',
    )
    template_ext = ('.sql', '.hql',)
    ui_color = '#b0f07c'

    @apply_defaults
    def __init__(
            self,
            *args,
            sql,
            destination_table,
            source_conn_id,
            destination_conn_id,
            preoperator=None,
            target_fields=None,
            query_params=None,
            preoperator_params=None,
            **kwargs):
        """Store the transfer settings; no connection is opened here."""
        super().__init__(*args, **kwargs)
        self.sql = sql
        self.destination_table = destination_table
        self.source_conn_id = source_conn_id
        self.destination_conn_id = destination_conn_id
        self.preoperator = preoperator
        self.query_params = query_params
        self.preoperator_params = preoperator_params
        self.target_fields = target_fields

    def execute(self, context):  # pylint: disable=unused-argument
        """
        Run ``sql`` on the source connection and insert the fetched rows into
        ``destination_table`` on the destination connection.

        The destination hook is only looked up, and the preoperator only
        run, when the source query returned at least one row.
        """
        source_hook = DbApiHook.get_hook(self.source_conn_id)

        self.log.info("Extracting data from %s", self.source_conn_id)
        self.log.info("Executing: \n %s", self.sql)
        results = source_hook.get_records(self.sql, parameters=self.query_params)

        # Covers both ``None`` and an empty result set: leave the destination
        # completely untouched (no preoperator, no insert).
        if not results:
            self.log.info("No rows selected from source database.")
            return

        destination_hook = DbApiHook.get_hook(self.destination_conn_id)
        if self.preoperator:
            self.log.info("Running preoperator")
            self.log.info(self.preoperator)
            destination_hook.run(self.preoperator, parameters=self.preoperator_params)

        self.log.info("Inserting rows into %s", self.destination_conn_id)
        destination_hook.insert_rows(
            table=self.destination_table,
            rows=results,
            target_fields=self.target_fields,
        )
