from bisect import bisect_left
from random import randint
from time import sleep

import os.path as os_path
import struct
import uuid as uuid_rand

from yt.wrapper import yson

import yt.wrapper as yt_general


class RevealTable(object):
    """Mapper that emits each input pair in both orientations.

    For a record holding the two input columns it yields one row with the
    values in their original order and one row with the values swapped,
    written under the ``left_col`` / ``right_col`` output names.
    """

    def __init__(self, left_col, right_col, input_columns):
        self.left_col, self.right_col = left_col, right_col
        # Names of the two columns to read from every incoming record.
        self.inc = input_columns

    def __call__(self, record):
        first = record[self.inc[0]]
        second = record[self.inc[1]]
        # Direct orientation first, then the reversed one.
        for left_value, right_value in ((first, second), (second, first)):
            yield {self.left_col: left_value, self.right_col: right_value}


class DistinctRecords(object):
    """Reducer that collapses a key group into a single row.

    Exactly one row is emitted per group, built only from the group key; the
    rows of the group themselves are never read.
    """

    def __init__(self, left_col, right_col):
        self.left_col, self.right_col = left_col, right_col

    def __call__(self, keys, _):
        columns = (self.left_col, self.right_col)
        yield dict((column, keys[column]) for column in columns)


def create_uint64_two_column_table(tbl, columns, yt):
    """(Re)create ``tbl`` with a strict schema of two ``uint64`` columns.

    ``columns[0]`` and ``columns[1]`` give the column names. If the table
    already exists it is removed first, so the result is always an empty
    table stored with ``optimize_for=scan``.
    """
    # The schema is built before any YT call so that a malformed ``columns``
    # argument fails without touching the cluster.
    schema = yson.YsonList(
        [{'name': columns[index], 'type': 'uint64'} for index in (0, 1)]
    )
    schema.attributes["strict"] = True

    if yt.exists(tbl):
        yt.remove(tbl, force=True)

    table_attributes = {'schema': schema, 'optimize_for': 'scan'}
    yt.create_table(tbl, attributes=table_attributes, ignore_existing=True)


@yt_general.aggregator
class FileCrafter(object):
    """Aggregating mapper that dumps its rows into a binary file on YT.

    Every distinct consecutive ``(left_col, right_col)`` pair is written to a
    local file as 16 bytes: two big-endian unsigned 64-bit integers. The file
    is then uploaded to ``path + <hex of the smallest pair seen>`` inside the
    transaction identified by ``transaction_id``.

    Parameters:
        path: prefix of the destination YT path; the 32-character hex form of
            the smallest pair is appended to it.
        local_path: path of the scratch file inside the job sandbox.
        left_col, right_col: names of the columns to read from every row.
        transaction_id: id of the transaction the upload is attached to.
    """

    def __init__(
            self,
            path, local_path,
            left_col, right_col,
            transaction_id
    ):
        self.path = path
        self.local_path = local_path
        self.left_col = left_col
        self.right_col = right_col
        self.transaction_id = transaction_id

    def __call__(self, records):
        # 'f' * 32 is the largest possible 32-digit hex string, so any real
        # pair compares lower. It stays the suffix if there are no rows.
        smallest_hex = 'f' * 32
        previous_pair = None
        with open(self.local_path, 'wb') as local_file:
            for record in records:
                pair = (record[self.left_col], record[self.right_col])
                # Only consecutive duplicates are dropped.
                if pair == previous_pair:
                    continue
                previous_pair = pair
                # Same bytes as decoding '%016x%016x' from hex, but does not
                # depend on the Python 2 only 'hex' string codec.
                local_file.write(struct.pack('>QQ', *pair))
                pair_hex = '%016x%016x' % pair
                if pair_hex < smallest_hex:
                    smallest_hex = pair_hex

        with open(self.local_path, 'rb') as local_file:
            with yt_general.Transaction(
                    transaction_id=self.transaction_id):
                while True:
                    # Random delay before every attempt, as in the original
                    # implementation.
                    sleep(randint(0, 15))
                    try:
                        # A failed attempt may have consumed part of the
                        # stream; rewind so a retry uploads the whole file.
                        local_file.seek(0)
                        yt_general.write_file(
                            self.path + smallest_hex,
                            local_file
                        )
                        break
                    except Exception:
                        # Best-effort: retry until the upload succeeds, while
                        # letting KeyboardInterrupt/SystemExit propagate.
                        continue
        return


@yt_general.reduce_aggregator
class GraphConnectivityFinder(object):
    """Reduce aggregator that assigns every vertex to a connected group.

    Neighbour lookups are served from a local binary file made of 16-byte
    records: two big-endian unsigned 64-bit integers, a vertex followed by
    one of its neighbours. Lookups use binary search, so records must be
    ordered by the first integer; ``init_reader`` asserts this for a sample.

    For each key group the closure of reachable vertices is collected and one
    row ``{new_left_col: vertex, new_right_col: min(closure)}`` is yielded per
    vertex. The same vertex may be emitted by several jobs or groups.

    Parameters:
        local_path: path of the binary pairs file inside the job sandbox.
        left_col, right_col: column names read from the reduce input.
        new_left_col, new_right_col: column names of the emitted rows.
    """

    def __init__(self, local_path,
                 left_col, right_col,
                 new_left_col, new_right_col):
        self.local_path = local_path
        self.left_col = left_col
        self.right_col = right_col
        self.new_left_col = new_left_col
        self.new_right_col = new_right_col
        self.file_pointer = None
        self.file_size = 0
        self.pairs_count = 0
        # Sampled first-integers of evenly spaced records; used to narrow
        # down the binary search range without touching the file.
        self.pre_search_list = None
        self.pre_search_len = 1024
        # Distance (in records) between two neighbouring samples.
        self.step = 0

    def decode_int_fp(self):
        """Read 8 bytes at the current position as a big-endian unsigned int."""
        # Equivalent to int(raw.encode('hex'), 16) without relying on the
        # Python 2 only 'hex' codec.
        return struct.unpack('>Q', self.file_pointer.read(8))[0]

    def init_reader(self):
        """Open the pairs file and build the in-memory sample index."""
        self.file_pointer = open(self.local_path, 'rb')
        self.file_size = os_path.getsize(self.local_path)
        self.pairs_count = self.file_size // 16
        self.pre_search_list = [0] * self.pre_search_len
        # With fewer records than samples the step is 0 and every sample is
        # the first record; lookups then fall through to the tail search.
        self.step = self.pairs_count // self.pre_search_len
        for i in range(self.pre_search_len):
            self.file_pointer.seek(self.step * i * 16)
            self.pre_search_list[i] = self.decode_int_fp()
        self.file_pointer.seek(0)
        assert sorted(self.pre_search_list) == self.pre_search_list
        return

    def get_all_pairs(self, vertice):
        """Return the list of second integers of all records for ``vertice``.

        Raises an exception when the vertex is not present in the file.
        """
        pos_index = bisect_left(
            self.pre_search_list,
            vertice,
            0,
            self.pre_search_len
        )
        if pos_index >= self.pre_search_len:
            # Larger than every sample: search from the last sample to EOF.
            assert vertice > self.pre_search_list[-1]
            left, right = (pos_index - 1) * self.step, self.pairs_count
            return self.find_exact_index(
                left, right, vertice)
        assert vertice <= self.pre_search_list[pos_index]
        if self.pre_search_list[pos_index] == vertice:
            return self.read_all_pairs(
                pos_index * self.step, vertice)
        if pos_index == 0:
            # Smaller than the first record, so it cannot be in the file.
            # Without this guard index -1 would wrap to the last sample.
            raise Exception("Impossible %s" % vertice)
        assert self.pre_search_list[pos_index-1] < vertice
        assert vertice < self.pre_search_list[pos_index]
        left = (pos_index - 1) * self.step
        right = pos_index * self.step
        return self.find_exact_index(
            left, right, vertice)

    def find_exact_index(self, left, right, vertice):
        """Binary-search records ``[left, right)`` for ``vertice``.

        On a hit, delegates to ``read_all_pairs``; otherwise raises.
        """
        self.file_pointer.seek(left * 16)
        assert self.decode_int_fp() < vertice
        # ``right`` may be one past the last record; check the last one then.
        if right == self.pairs_count:
            self.file_pointer.seek((right-1) * 16)
        else:
            self.file_pointer.seek(right * 16)
        assert vertice <= self.decode_int_fp()
        while left < right:
            mid = (left + right) // 2
            self.file_pointer.seek(mid * 16)
            midval = self.decode_int_fp()
            if midval < vertice:
                left = mid + 1
            elif midval > vertice:
                right = mid
            else:
                return self.read_all_pairs(
                    mid, vertice)
        self.file_pointer.seek((left-1) * 16)
        assert self.decode_int_fp() < vertice
        self.file_pointer.seek(left * 16)
        assert vertice < self.decode_int_fp()
        raise Exception("Impossible %s" % vertice)

    def read_all_pairs(self, start_pos, vertice):
        """Collect neighbours of ``vertice`` around record ``start_pos``.

        ``start_pos`` must point at a record of ``vertice``; the run of equal
        records is scanned forwards and then backwards from it.
        """
        vertices = list()
        cur_pos = start_pos
        while cur_pos < self.pairs_count:
            self.file_pointer.seek(cur_pos * 16)
            cur_pos += 1
            vert = self.decode_int_fp()
            if vert != vertice:
                break
            # The second read continues right after the first integer.
            vertices.append(self.decode_int_fp())
        assert vertices
        cur_pos = start_pos - 1
        while cur_pos >= 0:
            self.file_pointer.seek(cur_pos * 16)
            cur_pos -= 1
            vert = self.decode_int_fp()
            if vert != vertice:
                break
            vertices.append(self.decode_int_fp())
        return vertices

    def __call__(self, row_groups):
        """Yield ``{new_left_col: vertex, new_right_col: group}`` rows.

        ``row_groups`` is iterated as ``(keys, records)`` pairs. Groups whose
        key vertex was already emitted by this call are skipped.
        """
        try:
            self.init_reader()
            all_verts = set()
            for keys, records in row_groups:
                if keys[self.left_col] in all_verts:
                    continue
                vertices = set()
                vertices.add(keys[self.left_col])
                for record in records:
                    vertices.add(record[self.right_col])
                # Expand the closure until no unseen neighbour is left.
                work_to_do = set(vertices)
                while work_to_do:
                    vertice_to_check = work_to_do.pop()
                    all_pairs = set(self.get_all_pairs(vertice_to_check))
                    new_pairs = all_pairs - vertices
                    vertices |= new_pairs
                    work_to_do |= new_pairs
                group_elem = min(vertices)
                all_verts |= vertices
                for vertice in vertices:
                    yield {
                        self.new_left_col: vertice,
                        self.new_right_col: group_elem
                    }
        finally:
            # Close the file even when a lookup fails or the consumer stops
            # iterating early.
            if self.file_pointer is not None:
                self.file_pointer.close()
                self.file_pointer = None
        return


def find_graph_connectivity(
        pairs_table, destination_table,
        input_columns, output_columns,
        yt=yt_general):
    """Compute connected groups of the graph stored in ``pairs_table``.

    Pipeline (all inside one YT transaction):
      1. every pair is emitted in both orientations and sorted;
      2. the sorted pairs are dumped to binary chunk files by ``FileCrafter``
         and concatenated into one file;
      3. ``GraphConnectivityFinder`` reduces ``pairs_table`` using that file
         and emits ``(vertex, group)`` rows;
      4. the rows are de-duplicated and ``destination_table`` is sorted by the
         vertex column.

    Parameters:
        pairs_table: path of the input table; it must have both input columns
            typed ``uint64`` and be sorted by exactly ``input_columns``.
        destination_table: path of the result table; it is recreated.
        input_columns: two column names holding the edge endpoints.
        output_columns: two column names, vertex column and group column.
        yt: YT client or module to use; defaults to ``yt.wrapper``.

    Raises:
        Exception: on invalid arguments, an unsuitable input table, or when
            the main stage fails on 11 attempts in a row.
    """
    if len(input_columns) != 2:
        raise Exception(
            "Invalid format of input_columns argument, len must be equal 2")
    if len(output_columns) != 2:
        raise Exception(
            "Invalid format of output_columns argument, len must be equal 2")
    vert_column, group_column = output_columns
    left_column = 'left'
    right_column = 'right'

    schema = yt.get(pairs_table + '/@schema')
    # Move the two input columns to the front, in the requested order; every
    # other column gets the same large rank and stays behind them.
    sorted_schema = sorted(
        schema,
        key=lambda x: (
            input_columns.index(x['name'])
            if x['name'] in input_columns
            else 666
        )
    )
    # The length check turns a would-be IndexError on a table with fewer
    # than two columns into the regular validation error.
    if not (len(sorted_schema) >= 2 and
            sorted_schema[0]['name'] == input_columns[0] and
            sorted_schema[1]['name'] == input_columns[1] and
            sorted_schema[0]['type'] == 'uint64' and
            sorted_schema[1]['type'] == 'uint64'):
        raise Exception("Incorrect format of input table")
    if not yt.get(pairs_table + "/@sorted"):
        raise Exception("Input table must be sorted")
    # Compare as plain lists so a tuple of column names is accepted as well.
    if list(yt.get(pairs_table + "/@sorted_by")) != list(input_columns):
        raise Exception("Input table must be sorted by input columns")

    # FileCrafter jobs upload their chunk files to YT themselves.
    yt.config['allow_http_requests_to_yt_from_job'] = True
    temp_folder = '//tmp/' + str(uuid_rand.uuid4()) + '/'
    temp_root = temp_folder.rstrip('/')
    all_file = temp_folder + 'all_file'
    local_name = 'dbfile'

    with yt.Transaction() as transaction,\
            yt.TempTable() as all_pairs_table,\
            yt.TempTable() as dummy_dest_table,\
            yt.TempTable() as destination_not_distincted:
        yt.mkdir(temp_root)
        # Create all pairs (direct and reversed) and sort them.
        create_uint64_two_column_table(
            all_pairs_table, [left_column, right_column], yt)
        create_uint64_two_column_table(
            dummy_dest_table, [left_column, right_column], yt)

        yt.run_map(
            RevealTable(left_column, right_column, input_columns),
            pairs_table,
            all_pairs_table,
            spec={
                "data_size_per_job": 1024 * 1024 * 4
            }
        )
        yt.run_sort(
            all_pairs_table,
            sort_by=[left_column, right_column]
        )

        counter = 0
        while True:
            counter += 1
            try:
                # Drop whatever a failed attempt left behind. A stale
                # 'all_file' or stale chunks would otherwise be listed below
                # and concatenated as sources.
                for leftover in yt.list(temp_root):
                    yt.remove(temp_folder + leftover, force=True)

                # Create chunk files that represent the sorted pairs table.
                yt.run_map(
                    FileCrafter(
                        temp_folder, 'tmpfile',
                        left_column, right_column,
                        transaction.transaction_id
                    ),
                    all_pairs_table,
                    dummy_dest_table,
                    ordered=True,
                    spec={
                        "tmpfs_path": ".",
                        "copy_files": True,
                        "data_size_per_job": 1024 * 1024 * 4
                    }
                )
                # Chunk names are the hex of their smallest pair, so sorting
                # the names restores the global order for the united file.
                files = [
                    temp_folder + f
                    for f in sorted(yt.list(temp_root))
                ]
                yt.concatenate(
                    files,
                    all_file
                )
                for tmp_file in files:
                    yt.remove(tmp_file, force=True)

                create_uint64_two_column_table(
                    destination_not_distincted, [vert_column, group_column], yt)
                create_uint64_two_column_table(
                    destination_table, [vert_column, group_column], yt)

                # Main process.
                yt.run_reduce(
                    GraphConnectivityFinder(
                        local_name,
                        input_columns[0], input_columns[1],
                        vert_column, group_column
                    ),
                    pairs_table,
                    destination_not_distincted,
                    reduce_by=input_columns[0],
                    yt_files=[
                        yt.TablePath(
                            all_file,
                            attributes={'file_name': local_name}
                        )
                    ],
                    # That is enough if pairs_table has < 256M pairs.
                    memory_limit=1024 * 1024 * 1024 * (8+2),
                    spec={
                        "tmpfs_path": ".",
                        "copy_files": True,
                        "data_size_per_job": 1024 * 1024 * 2
                    }
                )
                break
            except Exception as error:
                # KeyboardInterrupt/SystemExit are no longer swallowed, and
                # the last failure is reported instead of being discarded.
                if counter > 10:
                    raise Exception(
                        "Too many fails in gcf operation: %r" % (error,))
                continue

        # De-duplicate and sort the result.
        yt.run_sort(
            destination_not_distincted,
            sort_by=[vert_column, group_column]
        )
        yt.run_reduce(
            DistinctRecords(vert_column, group_column),
            destination_not_distincted,
            destination_table,
            reduce_by=[vert_column, group_column],
            spec={
                "data_size_per_job": 1024 * 1024 * 4
            }
        )
        yt.run_sort(
            destination_table,
            sort_by=vert_column
        )
        yt.remove(temp_root, recursive=True)
    return


if __name__ == "__main__":
    # Module filter for pickling: accept every module whose name does not
    # contain 'hashlib'.
    yt_general.config['pickling']['module_filter'] = (
        lambda module: 'hashlib' not in getattr(module, '__name__', '')
    )

    find_graph_connectivity(
        pairs_table="//home/crypta/team/cberg/fuzzy/yuid_pairs_result",
        destination_table="//home/crypta/team/cberg/fuzzy/yuid_pairs_groupped.tmp",
        input_columns=['yandexuid_left', 'yandexuid_right'],
        output_columns=['vertice', 'group']
    )
