import os
import sys
import logging

from crypta.lib.python.native_yt import run_native_map_reduce
import library.python.resource as rs
import yt.wrapper as yt

# Size unit and the operation spec used by the native map-reduce calls below.
MB = 2 ** 20
SPEC = {'data_size_per_job': 128 * MB}

# Module logger with a dedicated stderr handler that accepts INFO and above.
# NOTE(review): the logger's own level is not set here, so whether INFO records
# reach this handler depends on logging configuration done elsewhere -- confirm.
logger = logging.getLogger(__name__)
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.INFO)
logger.addHandler(handler)


def sandbox_mark(title):
    """Append the value of the ``YQL_MARK`` environment variable to a title.

    When ``YQL_MARK`` is unset or empty, ``title`` is returned untouched
    (``None`` stays ``None``). Otherwise the result is ``'<title> <mark>'``,
    and a missing or empty title is replaced by ``'Crypta MRCC'``.
    """
    mark = os.getenv('YQL_MARK')
    if not mark:
        return title
    label = title if title else 'Crypta MRCC'
    return '{title} {mark}'.format(title=label, mark=mark)


class MRConnectedComponentsYT(object):
    """ Helper to find connected components on undirected graph in MR.

    The graph table is rewritten in place by repeated "Large Star" /
    "Small Star" native map-reduce operations until one iteration reports
    no changes.

    Verbose:
    - True: report the count of changes (sum of the `cnt` column of the
      changes table). Has to run an extra MR operation.
    - False: report the count of rows in the changes table.
    """

    def __init__(self, yt_client, verbose=False):
        """
        :param yt_client: YT client; this class reads its `proxy`, `pool`,
            `token` and `transaction_id` attributes and calls `get`,
            `TempTable`, `run_map_reduce` and `read_table` on it.
        :param verbose: how changes are counted, see the class docstring.
        """
        self.yt = yt_client
        self.verbose = verbose

    def count_changes(self, changes_path):
        """Return the number of changes recorded in `changes_path` as an int."""
        if self.verbose:
            return self.verbose_count_changes(changes_path)
        return int(self.yt.get("{path}/@row_count".format(path=changes_path)))

    def verbose_count_changes(self, changes_path):
        """Sum the `cnt` column of `changes_path` with a map-reduce operation.

        Rows without a `cnt` key contribute 0.
        """

        @yt.aggregator
        def mapper(rows):
            # Receives an iterator of rows and emits a single partial sum;
            # every partial sum shares the same "type" key.
            change_type = "aggregate"
            change_count = 0
            for row in rows:
                change_count += int(row.get("cnt", 0))
            yield {"type": change_type, "cnt": change_count}

        def reducer(key, rows):
            # Folds the partial sums of one key into one row.
            change_count = 0
            for row in rows:
                change_count += int(row.get("cnt", 0))
            yield {"cnt": change_count}

        with self.yt.TempTable() as agg:
            self.yt.run_map_reduce(mapper, reducer,
                                   changes_path, agg,
                                   reduce_by=["type"])
            # The result is read before leaving the TempTable context.
            return sum(int(row.get("cnt", 0)) for row in self.yt.read_table(agg))

    def find_connected_components(self, graph_path, max_iter=15):
        """Run Large Star / Small Star iterations over `graph_path`.

        The table at `graph_path` is both the source and the first destination
        of every operation; the second destination is a temporary changes
        table whose content is counted after each operation.

        :param graph_path: graph table; operations reduce by `u` and sort by
            `u`, `v`.
        :param max_iter: maximum number of iterations.
        :return: True as soon as one iteration produces no changes, False when
            `max_iter` iterations pass without that.
        """
        logger.info("Yt config proxy <%s> pool <%s>", self.yt.proxy, self.yt.pool)
        with self.yt.TempTable() as changes_path:
            # `range` rather than the Python 2 only `xrange`: identical
            # iteration here, and it works on both major versions.
            for iteration in range(max_iter):
                logger.info("Starting iteration %s", iteration + 1)
                run_native_map_reduce(
                    mapper_name="NStarOperations::TLargeStarMapper",
                    reducer_name="NStarOperations::TLargeStarReducer",
                    source=graph_path,
                    destination=[graph_path, changes_path],
                    reduce_by=["u"],
                    sort_by=["u", "v"],
                    proxy=self.yt.proxy,
                    transaction=self.yt.transaction_id,
                    token=self.yt.token,
                    pool=self.yt.pool,
                    title=sandbox_mark('Large Star'),
                    spec=SPEC)
                large_star_changes = self.count_changes(changes_path)
                # NOTE(review): unlike Large Star, this operation is started
                # without `spec=SPEC` -- confirm that is intentional.
                run_native_map_reduce(
                    "NStarOperations::TSmallStarMapper",
                    "NStarOperations::TSmallStarReducer",
                    graph_path,
                    [graph_path, changes_path],
                    reduce_by=["u"],
                    sort_by=["u", "v"],
                    proxy=self.yt.proxy,
                    transaction=self.yt.transaction_id,
                    token=self.yt.token,
                    pool=self.yt.pool,
                    title=sandbox_mark('Small Star'))
                small_star_changes = self.count_changes(changes_path)
                total_changes = large_star_changes + small_star_changes
                converged = (total_changes == 0)
                logger.info("Iteration %s complete. "
                            "Large star changes: %d. Small star changes: %d",
                            iteration + 1,
                            large_star_changes,
                            small_star_changes)
                if converged:
                    logger.info("Converged")
                    return True
        logger.error("Not converged")
        return False


class MRConnectedComponentsYQL(object):
    """ Helper to find connected components on undirected graph with YQL queries."""

    def __init__(self, yql):
        """
        :param yql: YQL client; this class sets its `title` attribute and
            calls `execute(query, syntax_version=1)` on it.
        """
        self.yql = yql

    @staticmethod
    def get_query(input_path, output_path):
        """Render the `/yql/iteration.sql` resource for the given tables.

        The template receives `input_path` as `source` and `output_path` as
        `destination`.
        """
        boilerplate = rs.find("/yql/iteration.sql")
        return MRConnectedComponentsYQL._render_query(
            boilerplate, dict(source=input_path, destination=output_path))

    def find_connected_components(self, graph_path, max_iter=15):
        """Execute the iteration query until it reports no changes.

        The same query (reading from and writing to `graph_path`) is executed
        on every iteration. The first row of the first result table is
        unpacked as three integers: total, large star and small star changes.

        :param graph_path: graph table, used as both query source and destination.
        :param max_iter: maximum number of iterations.
        :return: True as soon as an iteration reports zero total changes,
            False when `max_iter` iterations pass without that.
        """
        query = self.get_query(graph_path, graph_path)
        # `range` rather than the Python 2 only `xrange`: identical iteration
        # here, and it works on both major versions.
        for iteration in range(max_iter):
            logger.info("Starting iteration %s", iteration + 1)
            self.yql.title = 'YQL MRCC iteration {}'.format(iteration + 1)
            table = self.yql.execute(query, syntax_version=1)[0]
            row = table.rows[0]
            total_changes, large_star_changes, small_star_changes = \
                [int(item) for item in row]
            # Statistics are logged for every iteration, including the
            # converging one, matching the YT implementation.
            logger.info("Iteration %s complete. "
                        "Large star changes: %d. Small star changes: %d",
                        iteration + 1, large_star_changes, small_star_changes)
            if total_changes == 0:
                logger.info("Converged")
                return True
        logger.error("Not converged")
        return False

    @staticmethod
    def _render_query(template, context):
        """ Simple render query template: `template.format(**context)` """
        return template.format(**context)


class ExtendedMRConnectedComponentsYT(MRConnectedComponentsYT):

    """ Connected components on YT with custom column names and an optional separate output table """

    def find_connected_components(
            self, source, destination=None, u_name='u', v_name='v', component_id='component_id', max_iter=15):
        """ Prepare a temporary `u`/`v` table, run the base algorithm on it, then write the result.

        :param source: input edge table.
        :param destination: output table; defaults to `source`.
        :param u_name: name of the first edge column in `source` and in the output.
        :param v_name: name of the second edge column in `source` and in the output.
        :param component_id: name of the output column holding the component label.
        :param max_iter: forwarded to the base implementation.
        :return: the base implementation's convergence flag. The finishing step
            runs whether or not the algorithm converged.
        """
        target = source if destination is None else destination

        with self.yt.TempTable() as intermediate:
            stage_kwargs = dict(
                source=source, intermediate=intermediate, destination=target,
                u_name=u_name, v_name=v_name, component_id=component_id)

            self._starting_components(**stage_kwargs)
            converged = super(ExtendedMRConnectedComponentsYT, self).find_connected_components(
                graph_path=intermediate, max_iter=max_iter)
            self._finising_components(**stage_kwargs)

        return converged

    def _starting_components(self, source, intermediate, destination, u_name, v_name, component_id):
        """ Map `source` into `intermediate` as `u`/`v` rows, smaller endpoint in `u`, larger in `v`.

        `destination` and `component_id` are accepted but not used by this step.
        """
        logger.info("Starting mrcc")

        def mapper(row):
            first, second = row[u_name], row[v_name]
            yield {'u': min(first, second), 'v': max(first, second)}

        self.yt.run_map(mapper, [source], [intermediate])

    def _finising_components(self, source, intermediate, destination, u_name, v_name, component_id):
        """ Recreate `destination` from `intermediate` with the caller's column names.

        The destination is removed, created with an all-string schema, filled
        by a map operation and finally sorted by component id and both edge
        columns. `source` is accepted but not used by this step.
        """
        logger.info("Finishing mrcc")

        def mapper(row):
            label = row['u']
            yield {
                u_name: label,
                v_name: row['v'],
                component_id: label,  # TODO: may be add equivalent of PAGE_RANK()
            }

        schema = [
            {'name': column, 'type': 'string'}
            for column in (u_name, v_name, component_id)
        ]
        # Order matters: drop any previous table, create it with the schema,
        # fill it, and only then sort it.
        self.yt.remove(destination, force=True)
        self.yt.create('table', destination, attributes={'schema': schema, 'optimize_for': 'scan', })
        self.yt.run_map(mapper, [intermediate], [destination])
        self.yt.run_sort(destination, sort_by=[component_id, u_name, v_name])


class ExtendedMRConnectedComponentsYQL(MRConnectedComponentsYQL):

    """ Connected components on YQL with custom column names and an optional separate output table """

    # Resource keys of the query templates run before and after the iterations.
    STARTING_TEMPLATE = '/yql/starting.sql'
    FINISHING_TEMPLATE = '/yql/finishing.sql'

    def find_connected_components(
            self, source, destination=None, u_name='u', v_name='v', component_id='component_id', max_iter=15):
        """ Run the starting query, the base iterations on an intermediate table, then the finishing query.

        :param source: input edge table.
        :param destination: output table; defaults to `source`.
        :param u_name: first edge column name, passed to the query templates.
        :param v_name: second edge column name, passed to the query templates.
        :param component_id: component column name, passed to the query templates.
        :param max_iter: forwarded to the base implementation.
        :return: the base implementation's convergence flag. The finishing query
            runs whether or not the algorithm converged.

        NOTE(review): this class never drops the intermediate table itself;
        presumably the finishing query takes care of it -- confirm.
        """
        target = source if destination is None else destination
        intermediate = self.get_temporary_table(target)
        stage_kwargs = dict(
            source=source, intermediate=intermediate, destination=target,
            u_name=u_name, v_name=v_name, component_id=component_id)

        self._starting_components(**stage_kwargs)
        converged = super(ExtendedMRConnectedComponentsYQL, self).find_connected_components(
            graph_path=intermediate, max_iter=max_iter)
        self._finising_components(**stage_kwargs)

        return converged

    @staticmethod
    def get_temporary_table(destination):
        """ Name of the intermediate table: the destination with a `_temporary` suffix """
        return '{}_temporary'.format(destination)

    def _starting_components(self, **kwargs):
        """ Execute the query rendered from `STARTING_TEMPLATE` with `kwargs` """
        logger.info("Starting mrcc")
        self._run_template(self.STARTING_TEMPLATE, 'YQL MRCC starting', kwargs)

    def _finising_components(self, **kwargs):
        """ Execute the query rendered from `FINISHING_TEMPLATE` with `kwargs` """
        logger.info("Finishing mrcc")
        self._run_template(self.FINISHING_TEMPLATE, 'YQL MRCC finishing', kwargs)

    def _run_template(self, resource_key, title, context):
        """ Load a template resource, render it with `context`, set the client title and execute it """
        query = self._render_query(rs.find(resource_key), context)
        self.yql.title = title
        self.yql.execute(query, syntax_version=1)


class VDataMRConnectedComponentsYQL(ExtendedMRConnectedComponentsYQL):

    """ Same as ExtendedMRConnectedComponentsYQL, but also keeps the `v-data` column from the initial table.

    The only difference visible in this class is the finishing query template;
    the column handling itself presumably lives in that SQL resource (TODO confirm).
    """

    # Overrides the parent's finishing template; the starting template is inherited.
    FINISHING_TEMPLATE = '/yql/finishing_vdata.sql'
