# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from airflow.hooks.dbapi_hook import DbApiHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults


class RunDbQuery(BaseOperator):
    """
    Run SQL statements against a database reachable through a DbApiHook.

    The database hook is resolved lazily from ``conn_id`` the first time it is
    needed (normally inside :meth:`execute`). Because templates are rendered
    before execution, a templated ``conn_id`` takes effect. No connection
    lookup happens when the operator is constructed.

    :param conn_id: Airflow connection id of the target database; its hook is
        expected to provide ``run`` and ``get_records`` (a ``DbApiHook``)
    :type conn_id: str
    :param sql: SQL query to execute against the database (may also be a
        template file name ending in ``.sql`` or ``.hql``)
    :type sql: str
    :param query_params: The parameters to render the query with; passed to
        the hook as ``parameters``
    :type query_params: mapping or iterable
    """

    template_fields = ('conn_id', 'sql', 'query_params')
    template_ext = ('.sql', '.hql',)
    ui_color = '#b0f07c'

    @apply_defaults
    def __init__(
            self,
            *args,
            conn_id,
            sql,
            query_params=None,
            **kwargs):
        super(RunDbQuery, self).__init__(*args, **kwargs)
        self.conn_id = conn_id
        self.sql = sql
        self.query_params = query_params
        # The hook is created on first access of ``source_hook``. Creating it
        # here would use the un-rendered ``conn_id`` template and would look up
        # the connection every time the DAG file is parsed.
        self._source_hook = None

    @property
    def source_hook(self):
        """Return the DB hook for ``conn_id``, creating it on first use."""
        if self._source_hook is None:
            self._source_hook = DbApiHook.get_hook(self.conn_id)
        return self._source_hook

    @source_hook.setter
    def source_hook(self, hook):
        # Keep the attribute assignable, as it was when it was a plain
        # instance attribute (e.g. callers injecting their own hook).
        self._source_hook = hook

    def _log_query(self):
        """Log the target connection, the SQL text and its parameters."""
        self.log.info("Running SQL in %s", self.conn_id)
        self.log.info("Executing: \n %s", self.sql)
        self.log.info("query_params: \n %s", self.query_params)

    def execute(self, context):  # pylint: disable=unused-argument
        """Run ``sql`` with autocommit enabled; nothing is returned."""
        self._log_query()
        self.source_hook.run(self.sql, autocommit=True, parameters=self.query_params)

    def get_query_results(self):
        """
        Run ``sql`` and return the fetched rows.

        :return: whatever the hook's ``get_records`` returns for ``sql``
        """
        self._log_query()
        return self.source_hook.get_records(self.sql, parameters=self.query_params)
