from crypta.lib.python.bt.workflow import (
    IndependentTask,
    Parameter,
)
from crypta.lib.python.bt.tasks import (
    YtTask,
)
from crypta.lab.lib.common import (
    WithApi,
    _foreign
)
from crypta.lib.python.swagger import _to_proto

from crypta.lab.proto.view_pb2 import (
    TSampleView,
)
from crypta.lab.proto.other_pb2 import (
    TSourceDestinationState,
)
import crypta.lab.proto.view_pb2 as ViewOptions
import crypta.lab.proto.matching_pb2 as MatchingOptions
import crypta.lab.proto.hashing_pb2 as Hashing
from crypta.lab.lib.native_operations import (
    TComputeMatchingIdMapper,
    TJoinCryptaIDReducer,
    TJoinIdentifiersReducer,
    TJoinIdentifiersStatisticReducer
)
import crypta.lib.python.bt.conf.conf as conf

import logging
logger = logging.getLogger(__name__)


# Module-level minimum-size constant. It is not referenced anywhere in this
# part of the file (``Match`` asks the lab API for the minimal size instead);
# presumably it is consumed by other modules — TODO confirm users and meaning.
MINIMAL_SIZE = 0


class Match(YtTask, IndependentTask, WithApi):
    """Match the identifiers of one sample view into another view of the same sample.

    Rows of the source view are converted with ``TComputeMatchingIdMapper``.
    They are then joined against identifier tables taken from ``conf``:
      - per-type tables under ``conf.paths.indevicebytypes`` for the
        ``IN_DEVICE`` scope;
      - the vertex tables under ``conf.paths.graph`` (keyed by ``id`` or by
        ``cryptaId``) for the ``CROSS_DEVICE`` scope.

    The result is written to the destination view's ``Path``. The destination
    view's state is reported through the lab API: ``PROCESSING`` while the
    task works, then ``READY`` if the output has more rows than the view's
    minimal size, or ``NOT_ENOUGH_DATA`` (after removing the output table)
    otherwise.
    """

    # Id of the sample that owns both views.
    sample_id = Parameter()

    # view_id of the view whose rows are matched (read side).
    src_view = Parameter()

    # view_id of the view that receives the matching result (write side).
    dst_view = Parameter()

    def update_sample_view_state(self, id, view_id, state):
        """Set the state of view ``view_id`` of sample ``id`` and return the lab API result."""
        return self.api.lab.updateSampleViewState(id=id, view_id=view_id, state=state).result()

    def get_min_size(self, id, view_id):
        """Return the minimal size the lab API reports for view ``view_id`` of sample ``id``.

        ``run`` compares this value against the destination row count.
        """
        return self.api.lab.getSampleViewMinSize(id=id, view_id=view_id).result()

    def get_view(self, view_id):
        """Fetch view ``view_id`` of this task's sample as a ``TSampleView`` proto.

        Returns:
            The converted view, or None when the API call fails with
            ``status_code == 404`` (view not found).

        Raises:
            Any other exception raised by the API call, re-raised unchanged.
            AssertionError: if the API returned an empty view.
        """
        try:
            view = self.api.lab.getSampleView(
                id=self.sample_id,
                view_id=view_id,
            ).result()
        except Exception as e:
            # API errors carrying an HTTP-like status expose it as ``status_code``.
            if getattr(e, "status_code", None) == 404:
                return None
            # Bare ``raise`` keeps the original traceback intact.
            raise
        assert view
        return _to_proto(TSampleView, view)

    def get_schema(self, dst_view, src_view):
        """Return the leading columns of the destination table schema.

        If the destination's matching options set ``IncludeOriginal``, this is
        the source table's schema with ``sort_order`` stripped from every
        column, so the destination table is created unsorted. Otherwise it is
        an empty list. The caller appends the key column(s) itself.
        """
        if not dst_view.Options.Matching.IncludeOriginal:
            return []

        output_schema = []
        for column in self.yt.get_attribute(src_view.Path, "schema"):
            column.pop("sort_order", None)
            output_schema.append(column)
        return output_schema

    def get_types(self, view_type):
        """Return the identifier name(s) used in table paths for a matching id type.

        ``LAB_ID_IDFA_GAID`` maps to two names (``idfa`` and ``gaid``); every
        other supported type maps to exactly one. A fresh list is returned on
        every call, so callers may modify it.

        Raises:
            ValueError: if ``view_type`` is not a supported matching id type.
        """
        types = {
            MatchingOptions.LAB_ID_YANDEXUID: ["yandexuid"],
            MatchingOptions.LAB_ID_MM_DEVICE_ID: ["mm_device_id"],
            MatchingOptions.LAB_ID_EMAIL: ["email"],
            MatchingOptions.LAB_ID_PHONE: ["phone"],
            MatchingOptions.LAB_ID_PUID: ["puid"],
            MatchingOptions.LAB_ID_LOGIN: ["login"],
            MatchingOptions.LAB_ID_UUID: ["uuid"],
            MatchingOptions.LAB_ID_CRYPTA_ID: ["crypta_id"],
            MatchingOptions.LAB_ID_IDFA_GAID: ["idfa", "gaid"]
        }

        result = types.get(view_type)
        if result is None:
            raise ValueError("invalid value of scope view in matching options")

        return result

    def get_paths(self, base_path, src_view, dst_view, src_md5):
        """Build the indevice table paths for every (source name, destination name) pair.

        Each source name gets an ``_md5`` suffix when ``src_md5`` is true.
        Pairs where either name equals ``crypta_id`` map to
        ``<base_path>/<src>/<dst>``; all other pairs map to
        ``<base_path>/<src>/direct/<dst>``.

        NOTE(review): the ``crypta_id`` check runs after the ``_md5`` suffix
        is applied. A hashed crypta_id source ("crypta_id_md5") therefore
        takes the ``direct/`` layout unless the destination is crypta_id.
        Confirm this is intended.
        """
        type_in = self.get_types(src_view.Options.Matching.IdType)
        type_out = self.get_types(dst_view.Options.Matching.IdType)
        assert type_in != type_out

        paths = []
        for ti in type_in:
            if src_md5:
                ti += "_md5"

            for to in type_out:
                if ti == "crypta_id" or to == "crypta_id":
                    paths.append(base_path + "/" + ti + "/" + to)
                else:
                    paths.append(base_path + "/" + ti + "/direct/" + to)

        return paths

    def _get_existing_view(self, view_id):
        """Return ``get_view(view_id)``, raising ValueError instead of returning None."""
        view = self.get_view(view_id)
        if view is None:
            raise ValueError("View %s of sample %s not found" % (view_id, self.sample_id))
        return view

    def _create_destination_table(self, src_view, dst_view):
        """(Re)create the destination table with the schema the destination view type needs.

        Raises:
            ValueError: if the destination view type is neither MATCHING nor
                CRYPTA_ID_STATISTICS.
        """
        schema = self.get_schema(dst_view, src_view)
        key = dst_view.Options.Matching.Key
        if dst_view.Type == ViewOptions.MATCHING:
            schema.append({"name": key, "type": "string"})
        elif dst_view.Type == ViewOptions.CRYPTA_ID_STATISTICS:
            schema.extend([{"name": key, "type": "any"}, {"name": "ccIdType", "type": "string"}])
        else:
            raise ValueError("Unsupported destination view type: %s" % dst_view.Type)
        logger.info(schema)
        # force=True replaces any table left at this path by a previous run.
        self.yt.create(
            'table',
            dst_view.Path,
            attributes={"schema": schema},
            force=True
        )

    def _run_in_device_matching(self, src_view, dst_view, state):
        """Join normalized source ids with the per-type indevice tables into the destination."""
        path_to_indevice_tables = self.get_paths(
            conf.paths.indevicebytypes.base_path,
            src_view,
            dst_view,
            src_view.Options.Matching.HashingMethod == Hashing.HM_MD5,
        )
        logger.info("indevice matching in %s", path_to_indevice_tables)
        with self.yt.TempTable(prefix='normalize_value_') as normalize_value, self.yt.TempTable(prefix='indevice_table_') as tmp_indevice_table:
            self.native_map(
                TComputeMatchingIdMapper,
                source=src_view.Path,
                destination=normalize_value,
                state=state
            )
            # The join reduce below is keyed by 'id', so sort the mapped rows by it.
            self.sort(
                source=normalize_value,
                destination=normalize_value,
                sort_by="id",
            )

            if len(path_to_indevice_tables) == 1:
                real_indevice_table = path_to_indevice_tables[0]
            else:
                # Several name pairs (e.g. idfa + gaid): merge them into one foreign table.
                self.yt.run_merge(path_to_indevice_tables, tmp_indevice_table)
                real_indevice_table = tmp_indevice_table

            self.native_join_reduce(
                TJoinIdentifiersReducer,
                source=[_foreign(real_indevice_table), normalize_value],
                destination=str(dst_view.Path),
                join_by='id',
                state=state
            )

    def _run_cross_device_matching(self, src_view, dst_view, state, vertices, vertices_by_crypta_id):
        """Resolve source ids to crypta ids, then (unless the target is crypta_id) to destination ids.

        Args:
            vertices: graph vertices table joined by 'id' (non-crypta_id sources).
            vertices_by_crypta_id: graph vertices table joined by 'cryptaId'.
        """
        src_id_type = src_view.Options.Matching.IdType
        dst_id_type = dst_view.Options.Matching.IdType

        with self.yt.TempTable(prefix='with_id_value_') as with_id_value, self.yt.TempTable(prefix='crypta_id_') as crypta_ids:
            self.native_map(
                TComputeMatchingIdMapper,
                source=src_view.Path,
                destination=with_id_value,
                state=state
            )
            if src_id_type == MatchingOptions.LAB_ID_CRYPTA_ID:
                source = [_foreign(vertices_by_crypta_id), with_id_value]
                mapped_row_name = "cryptaId"
            else:
                source = [_foreign(vertices), with_id_value]
                mapped_row_name = "id"

            # When the destination id type is crypta_id, the crypta id join
            # below is the final step and writes straight into the destination.
            if dst_id_type == MatchingOptions.LAB_ID_CRYPTA_ID:
                crypta_ids = str(dst_view.Path)

            if src_id_type == MatchingOptions.LAB_ID_CRYPTA_ID and dst_view.Type == ViewOptions.CRYPTA_ID_STATISTICS:
                # Source rows already carry crypta ids: no vertex join needed.
                crypta_ids = with_id_value
            else:
                self.sort(
                    source=with_id_value,
                    destination=with_id_value,
                    sort_by=mapped_row_name,
                )
                self.native_join_reduce(
                    TJoinCryptaIDReducer,
                    source=source,
                    destination=crypta_ids,
                    join_by=mapped_row_name,
                    state=state
                )

            if dst_id_type != MatchingOptions.LAB_ID_CRYPTA_ID:
                self.sort(
                    source=crypta_ids,
                    destination=crypta_ids,
                    sort_by='cryptaId',
                )

                if dst_view.Type == ViewOptions.MATCHING:
                    identifier_reducer = TJoinIdentifiersReducer
                elif dst_view.Type == ViewOptions.CRYPTA_ID_STATISTICS:
                    identifier_reducer = TJoinIdentifiersStatisticReducer
                else:
                    raise ValueError("Unsupported destination view type: %s" % dst_view.Type)

                self.native_join_reduce(
                    identifier_reducer,
                    source=[_foreign(vertices_by_crypta_id), crypta_ids],
                    destination=str(dst_view.Path),
                    join_by='cryptaId',
                    state=state
                )

    def _finalize_destination(self, dst_view):
        """Compact the destination table and publish READY or NOT_ENOUGH_DATA.

        The output must have strictly more rows than the view's minimal size
        to count as READY; otherwise the output table is removed.
        """
        min_size = self.get_min_size(
            id=self.sample_id,
            view_id=self.dst_view,
        )
        self.yt.run_merge(dst_view.Path, dst_view.Path, spec=dict(combine_chunks=True))
        if self.yt.row_count(dst_view.Path) > min_size:
            final_state = 'READY'
        else:
            self.yt.remove(dst_view.Path)
            final_state = 'NOT_ENOUGH_DATA'
        logger.info(self.update_sample_view_state(
            id=self.sample_id,
            view_id=self.dst_view,
            state=final_state,
        ))

    def run(self, **kwargs):
        """Match ``src_view`` into ``dst_view`` and publish the destination view state.

        Raises:
            ValueError: if either view does not exist, the destination view
                type is unsupported, or the matching scope is invalid.
            AssertionError: if both views share an id type or the destination
                has no matching key.

        NOTE(review): nothing resets the state if a step fails after
        ``PROCESSING`` is set, so the destination view stays ``PROCESSING``.
        """
        vertices = conf.paths.graph.vertices_no_multi_profile
        vertices_by_crypta_id = conf.paths.graph.vertices_by_crypta_id

        src_view = self._get_existing_view(self.src_view)
        dst_view = self._get_existing_view(self.dst_view)

        assert src_view.Options.Matching.IdType != dst_view.Options.Matching.IdType
        assert dst_view.Options.Matching.Key
        logger.info('Matching %s into %s', src_view, dst_view)
        # Need to start indevicebytypes after complete
        logger.info(self.update_sample_view_state(
            id=self.sample_id,
            view_id=self.dst_view,
            state='PROCESSING',
        ))

        # Serialized source/destination matching options handed to the native operations.
        state = TSourceDestinationState(
            Source=src_view.Options.Matching,
            Destination=dst_view.Options.Matching,
        ).SerializeToString()

        self._create_destination_table(src_view, dst_view)

        scope = dst_view.Options.Matching.Scope
        if scope == MatchingOptions.IN_DEVICE:
            self._run_in_device_matching(src_view, dst_view, state)
        elif scope == MatchingOptions.CROSS_DEVICE:
            self._run_cross_device_matching(src_view, dst_view, state, vertices, vertices_by_crypta_id)
        else:
            raise ValueError("Invalid value of scope in matching options")

        self._finalize_destination(dst_view)
