from cached_property import cached_property
import logging

import library.python.resource as rs
import tvmauth
import yt.wrapper as yt

from crypta.lab.lib import common as utils
from crypta.lab.lib.describe import (
    _direct_join_mapper,
    _direct_join_reducer,
)
from crypta.lab.lib.native_operations import (
    TComputeMatchingIdMapper,
    TJoinCryptaIDReducer,
    TJoinIdentifiersReducer,
    TRenameIdentifierMapper,
    TValidateMapper,
    TSubsamplesJoinReducer,
)
from crypta.lab.lib.specs import (
    YtSpecs,
)
from crypta.lab.proto import sample_pb2
import crypta.lab.proto.matching_pb2 as MatchingProto
import crypta.lab.proto.other_pb2 as OtherProto
import crypta.lib.proto.identifiers.id_pb2 as IdProto
from crypta.lib.python import (
    templater,
    time_utils
)
from crypta.lib.python.bt import (
    tasks,
    workflow
)
import crypta.lib.python.bt.conf.conf as conf
import crypta.lib.python.bt.workflow.targets.table as target_table
from crypta.lib.python.swagger import _to_proto, swagger
import crypta.lib.python.tvm.helpers as tvm_helpers
from crypta.lib.python.yt import schema_utils
from crypta.siberia.bin.common.describing.experiment.proto import describing_experiment_pb2
from crypta.siberia.bin.common.siberia_client import SiberiaClient

logger = logging.getLogger(__name__)

# Column names of the subsamples info table (see subsample_info_schema);
# GROUP_NAME_ID is also the name the sample's grouping key is renamed to.
GROUP_NAME_ID = 'group_name'
SAMPLE_SIZE = 'sample_size'

# Node names created inside a sample's directory.
SUBSAMPLE_NODE_NAME = 'subsamples'
SUBSAMPLES_INFO_TABLE_NAME = 'subsamples_info'

# Upper bound on the number of ids sent to Siberia in one describe request;
# bigger tables are cut down with sample_rows_from_table.
MAX_DESCRIBING_IDS_COUNT = 10 ** 5

# Characters replaced with '-' in group names before they become table names.
BANNED_SYMBOLS = set('/@&*')


def describe_ids_by_crypta_id(ids, ttl):
    """Send ``ids`` to Siberia for describing and return the created user set id.

    A TVM service ticket for the ``siberia`` destination is obtained with the
    credentials from ``conf.proto.Tvm``. The request is made with the
    ``by_crypta_id`` CryptaIdUserDataVersion experiment.

    :param ids: ``IdProto.TIds`` message with the identifiers to describe.
    :param ttl: ttl forwarded to ``SiberiaClient.user_sets_describe_ids``.
    :return: the ``UserSetId`` field of the Siberia response (not converted).
    """
    settings = tvmauth.TvmApiClientSettings(
        self_tvm_id=conf.proto.Tvm.SourceTvmId,
        self_secret=conf.proto.Tvm.Secret,
        dsts={'siberia': conf.proto.Siberia.Tvm.DestinationTvmId},
        localhost_port=tvm_helpers.get_tvm_test_port(),
    )
    client_for_tvm = tvmauth.TvmClient(settings)
    service_ticket = client_for_tvm.get_service_ticket_for('siberia')

    siberia_client = SiberiaClient(conf.proto.Siberia.Host, conf.proto.Siberia.Port)
    experiment = describing_experiment_pb2.TDescribingExperiment(CryptaIdUserDataVersion='by_crypta_id')

    response = siberia_client.user_sets_describe_ids(
        ids,
        experiment=experiment,
        ttl=ttl,
        tvm_ticket=service_ticket,
    )
    return response.UserSetId


def subsample_info_schema():
    """Return the YT schema of the subsamples info table.

    Columns, in this order: group name (string), sample size (uint64) and
    ``_yql_column_0`` (string).
    """
    column_types = {}
    column_types[GROUP_NAME_ID] = 'string'
    column_types[SAMPLE_SIZE] = 'uint64'
    column_types['_yql_column_0'] = 'string'
    return schema_utils.yt_schema_from_dict(column_types)


def subsample_info_attrs():
    """Return the attributes for the subsamples info table: YT schema plus YQL row spec."""
    # NOTE(review): the row spec lists only group name and sample size, while the
    # schema also has `_yql_column_0` — confirm this mismatch is intended.
    struct_members = [
        [GROUP_NAME_ID, ['DataType', 'String']],
        [SAMPLE_SIZE, ['DataType', 'Uint64']],
    ]
    yql_row_spec = {
        'StrictSchema': True,
        'Type': ['StructType', struct_members],
    }
    return {
        'schema': subsample_info_schema(),
        '_yql_row_spec': yql_row_spec,
    }


def clear_group_name(name):
    """Make a group name safe to use as a table name.

    Backslashes are removed and every character from ``BANNED_SYMBOLS`` is
    replaced with ``'-'``. All of these are single-character substitutions that
    never produce one another, so the order of the steps does not matter.
    """
    sanitized = name.replace('\\', '')
    for banned in BANNED_SYMBOLS:
        sanitized = sanitized.replace(banned, '-')
    return sanitized


# TODO(unretrofied): rework api usage for further tests and proper inheritance in subclasses
def get_api():
    """Build a swagger client for the lab API from ``conf.api.url`` and ``conf.api.oauth``."""
    api_url, oauth_token = conf.api.url, conf.api.oauth
    return swagger(api_url, oauth_token)


def get_sample(api, sample_id):
    """Fetch a sample from the lab API.

    :param api: swagger API client exposing ``lab.getSample``.
    :param sample_id: id of the sample to fetch.
    :return: the result of ``api.lab.getSample(id=sample_id)``.
    :raises: any error raised by the API call, re-raised unchanged after logging
        (a 404 is logged as a warning, anything else as an error).
    """
    try:
        return api.lab.getSample(id=sample_id).result()
    except Exception as e:
        if getattr(e, 'status_code', None) == 404:
            logger.warning('No such sample %s', sample_id)
        else:
            logger.error('API error while getting sample %s: %s', sample_id, str(e))
        # Bare `raise` keeps the original traceback (`raise e` loses it on Python 2).
        raise


def get_view(api, sample_id, view_id):
    """Fetch one view of a sample from the lab API.

    :param api: swagger API client exposing ``lab.getSampleView``.
    :param sample_id: id of the sample owning the view.
    :param view_id: id of the view to fetch.
    :return: the result of ``api.lab.getSampleView(id=sample_id, view_id=view_id)``.
    :raises: any error raised by the API call, re-raised unchanged after logging
        (a 404 is logged as a warning, anything else as an error).
    """
    try:
        return api.lab.getSampleView(id=sample_id, view_id=view_id).result()
    except Exception as e:
        if getattr(e, 'status_code', None) == 404:
            logger.warning('No such view %s for sample %s', view_id, sample_id)
        else:
            logger.error('API error while getting view %s of sample %s: %s', view_id, sample_id, str(e))
        # Bare `raise` keeps the original traceback (`raise e` loses it on Python 2).
        raise


class GetSubsamplesInfo(tasks.YQLTaskV1, workflow.IndependentTask, utils.WithApi):
    """Fill ``<sample>/subsamples_info`` by running the ``subsamples_info.yql`` query.

    The output table (schema from ``subsample_info_attrs``) is sorted by group
    name afterwards and stamped with a ``_day`` attribute holding today's Moscow
    date; that attribute is this task's target.
    """
    sample_id = workflow.Parameter()
    src_view_id = workflow.Parameter()
    # Passed to the YQL template as-is (presumably caps the number of groups — see the .yql).
    max_groups_count = workflow.Parameter()

    @cached_property
    def sample(self):
        # Shared helper: same API call, but failures are logged like in the other tasks.
        return get_sample(self.api, self.sample_id)

    @cached_property
    def src_view(self):
        return get_view(self.api, self.sample_id, self.src_view_id)

    @cached_property
    def _attribute(self):
        return '_day'

    @cached_property
    def _today(self):
        return time_utils.get_current_moscow_datetime().date().isoformat()

    @cached_property
    def subsamples_info_output_table(self):
        """Rich path of the subsamples info table, carrying its schema and YQL row spec."""
        path = yt.TablePath(
            yt.ypath_join(conf.paths.lab.samples, self.sample_id, SUBSAMPLES_INFO_TABLE_NAME),
            attributes=subsample_info_attrs(),
        )

        return path

    @cached_property
    def query(self):
        """Render the YQL query from the bundled ``/crypta/lab/subsamples_info.yql`` template."""
        return templater.render_template(
            rs.find('/crypta/lab/subsamples_info.yql'),
            dict(
                subsample_info_dst=self.subsamples_info_output_table,
                src_view=self.src_view.Path,
                id_column=self.sample.idKey,
                grouping_key=self.sample.groupingKey,
                max_groups_count=self.max_groups_count,
            )
        )

    def run(self, **kwargs):
        super(GetSubsamplesInfo, self).run(**kwargs)
        # Sorted by group name so PrepareSubsamples can join-reduce on it.
        self.yt.run_sort(self.subsamples_info_output_table, sort_by=GROUP_NAME_ID)
        self.yt.set_attribute(self.subsamples_info_output_table, self._attribute, self._today)

    def targets(self):
        yield target_table.HasAttribute(self.yt, self.subsamples_info_output_table, self._attribute, self._today)


class PrepareSubsamples(tasks.YtTask, utils.WithApi):
    """Split a sample view into one table per group under ``<sample>/subsamples``.

    Group names come from the subsamples info table (sanitized with
    ``clear_group_name``). Subsamples whose recorded size exceeds
    ``MAX_DESCRIBING_IDS_COUNT`` are cut down to that many rows in place. Every
    subsample table is then marked with ``_ready``, which is this task's target.
    """
    sample_id = workflow.Parameter()
    src_view_id = workflow.Parameter()
    max_groups_count = workflow.Parameter()

    def requires(self):
        yield GetSubsamplesInfo(
            sample_id=self.sample_id,
            src_view_id=self.src_view_id,
            max_groups_count=self.max_groups_count,
        )

    def get_sample(self):
        # The bare name resolves to the module-level helper, not to this method.
        return get_sample(self.api, self.sample_id)

    @cached_property
    def subsamples_dir(self):
        """Path of the subsamples directory; the map node is created on first access if missing."""
        path = yt.ypath_join(conf.paths.lab.samples, self.sample_id, SUBSAMPLE_NODE_NAME)
        if not self.yt.exists(path):
            self.yt.create('map_node', path)

        return path

    @cached_property
    def subsamples_info_dst(self):
        return yt.YPath(yt.ypath_join(conf.paths.lab.samples, self.sample_id, SUBSAMPLES_INFO_TABLE_NAME))

    @cached_property
    def subsamples_info(self):
        """``SubsampleInfo`` messages read from the subsamples info table, with sanitized group names."""
        return [
            sample_pb2.SubsampleInfo(group_name=clear_group_name(row[GROUP_NAME_ID]), sample_size=row[SAMPLE_SIZE])
            for row in self.yt.read_table(self.subsamples_info_dst)
        ]

    @cached_property
    def _attribute_ready(self):
        return '_ready'

    def _subsample_path(self, subsample):
        """Path of the table holding ``subsample``'s rows."""
        return yt.ypath_join(self.subsamples_dir, subsample.group_name)

    def run(self, **kwargs):
        sample = self.get_sample()

        sample_table = yt.ypath_join(conf.paths.lab.samples, sample.id, self.src_view_id)
        # Join-reduce below requires its inputs sorted by the join key.
        self.yt.run_sort(sample_table, sort_by=sample.groupingKey)

        sample_table = yt.TablePath(
            sample_table,
            rename_columns={sample.groupingKey: GROUP_NAME_ID},
        )
        subsamples_tables = [self._subsample_path(subsample) for subsample in self.subsamples_info]

        subsamples_info_table = utils._foreign(self.subsamples_info_dst)
        logger.info('Splitting original sample into subsamples...')

        state = OtherProto.TGroupNameState(GroupName=GROUP_NAME_ID)

        self.native_join_reduce(
            TSubsamplesJoinReducer,
            source=[subsamples_info_table, sample_table],
            destination=subsamples_tables,
            join_by=[GROUP_NAME_ID],
            spec=YtSpecs.JOIN_REDUCE_HEAVY_JOBS,
            state=state.SerializeToString(),
        )
        logger.info('Splitting finished with subsamples: %s', str(subsamples_tables))

        logger.info('Sampling large tables to fit Siberia ids limit...')

        for subsample, subsample_table in zip(self.subsamples_info, subsamples_tables):
            if subsample.sample_size > MAX_DESCRIBING_IDS_COUNT:
                self.yt.sample_rows_from_table(subsample_table, subsample_table, MAX_DESCRIBING_IDS_COUNT)
            self.yt.set_attribute(subsample_table, self._attribute_ready, True)

    def targets(self):
        path = yt.ypath_join(conf.paths.lab.samples, self.sample_id, SUBSAMPLE_NODE_NAME)
        # Use the task's own YT client, not the module-level default client.
        subsamples_dir_exists = self.yt.exists(path)
        yield target_table.Exists(self.yt, path)

        if subsamples_dir_exists:
            for subsample in self.subsamples_info:
                yield target_table.HasAttribute(
                    self.yt, self._subsample_path(subsample), self._attribute_ready, True
                )


class DescribeSubsamples(tasks.YtTask, utils.WithApi):
    """Describe each prepared subsample in Siberia and register the user sets in the lab API.

    For every group in the subsamples info table, the ids from its subsample table
    are sent to Siberia. The resulting ``<user_set_id>,<group_name>`` strings are
    passed to ``createSubsamples``. Afterwards the subsamples directory and every
    table in it get ``_described``, and the source view table gets ``_described_old``.
    """
    sample_id = workflow.Parameter()
    src_view_id = workflow.Parameter()
    max_groups_count = workflow.Parameter()

    def requires(self):
        yield PrepareSubsamples(
            sample_id=self.sample_id,
            src_view_id=self.src_view_id,
            max_groups_count=self.max_groups_count,
        )

    @cached_property
    def subsamples_dir(self):
        return yt.ypath_join(conf.paths.lab.samples, str(self.sample_id), SUBSAMPLE_NODE_NAME)

    @cached_property
    def subsamples_info(self):
        """``SubsampleInfo`` messages read from the subsamples info table, with sanitized group names."""
        path = yt.ypath_join(conf.paths.lab.samples, str(self.sample_id), SUBSAMPLES_INFO_TABLE_NAME)

        return [
            sample_pb2.SubsampleInfo(group_name=clear_group_name(row[GROUP_NAME_ID]), sample_size=row[SAMPLE_SIZE])
            for row in self.yt.read_table(path)
        ]

    @cached_property
    def sample(self):
        return get_sample(self.api, self.sample_id)

    @cached_property
    def src_view(self):
        return get_view(self.api, self.sample_id, self.src_view_id)

    @cached_property
    def attribute_described(self):
        return '_described'

    def targets(self):
        path = yt.ypath_join(conf.paths.lab.samples, self.sample_id, SUBSAMPLE_NODE_NAME)

        if self.yt.exists(path):
            for subsample_table in self.yt.list(self.subsamples_dir, absolute=True):
                yield target_table.HasAttribute(self.yt, subsample_table, self.attribute_described, True)

    def _read_ids(self, table):
        """Read ``table`` into ``IdProto.TIds`` using the sample's id type and id column."""
        id_name = self.sample.idName
        id_key = self.sample.idKey
        return IdProto.TIds(
            Ids=[IdProto.TId(Type=id_name, Value=str(row[id_key])) for row in self.yt.read_table(table)]
        )

    def run(self, **kwargs):
        logger.info('Describing %s of %s', self.src_view.Path, self.sample_id)

        user_sets = {}

        for subsample in self.subsamples_info:
            subsample_table = yt.ypath_join(self.subsamples_dir, subsample.group_name)
            ids = self._read_ids(subsample_table)
            # UserSetId is not necessarily a str (DescribeSingleSample wraps it in
            # str() too), and ','.join below only accepts strings.
            user_sets[subsample.group_name] = str(describe_ids_by_crypta_id(ids, ttl=self.sample.ttl))

        # TODO properly rename args here and in the API
        # Each entry is "<user_set_id>,<group_name>".
        # NOTE(review): group names may contain ',' (not in BANNED_SYMBOLS) — confirm
        # the API splits each entry on the first comma only.
        self.api.lab.createSubsamples(
            id=self.sample_id,
            ids=[','.join((user_set_id, group_name)) for group_name, user_set_id in user_sets.items()],
        ).result()

        self.yt.set_attribute(self.subsamples_dir, self.attribute_described, True)

        for subsample_table in self.yt.list(self.subsamples_dir, absolute=True):
            self.yt.set_attribute(subsample_table, self.attribute_described, True)

        self.yt.set_attribute(yt.ypath_join(conf.paths.lab.samples, self.sample_id, self.src_view_id), '_described_old', True)


# TODO(unretrofied): keep the old name of the description task (Describe) for backward compatibility; delete it after tests
class Describe(tasks.YtTask, utils.WithApi):
    """Backward-compatible entry point that only depends on DescribeSubsamples.

    It does no work of its own. It is complete once the source view table carries
    the ``_described_old`` attribute, which DescribeSubsamples sets at the end of its run.
    """
    sample_id = workflow.Parameter()
    src_view_id = workflow.Parameter()
    max_groups_count = workflow.Parameter()

    def requires(self):
        yield DescribeSubsamples(
            sample_id=self.sample_id,
            src_view_id=self.src_view_id,
            max_groups_count=self.max_groups_count,
        )

    def targets(self):
        view_table = yt.ypath_join(conf.paths.lab.samples, self.sample_id, self.src_view_id)

        if not self.yt.exists(view_table):
            return
        yield target_table.HasAttribute(self.yt, view_table, '_described_old', True)

    def run(self, **kwargs):
        # Intentionally empty: all the work happens in DescribeSubsamples.
        return None


class DescribeSingleSample(workflow.IndependentTask, tasks.YtTask, utils.WithApi):
    """Describe a whole sample view in Siberia as one user set.

    A view with more than ``MAX_DESCRIBING_IDS_COUNT`` rows is first sampled down
    to that size in a temporary table with the same schema. The resulting user set
    id is stored through the lab API (``setUserSetId``), and the view table is
    marked ``_described``, which is this task's target.
    """
    sample_id = workflow.Parameter()
    src_view_id = workflow.Parameter()

    @cached_property
    def src_view(self):
        return get_view(self.api, self.sample_id, self.src_view_id)

    @cached_property
    def sample(self):
        return get_sample(self.api, self.sample_id)

    @cached_property
    def attribute_described(self):
        return '_described'

    def targets(self):
        yield target_table.HasAttribute(self.yt, self.src_view.Path, self.attribute_described, True)

    def get_ids_from_table(self, table):
        """Read ``table`` and wrap the sample's id column values (as strings) into ``IdProto.TIds``."""
        table_rows = self.yt.read_table(table)
        id_messages = [
            IdProto.TId(Type=self.sample.idName, Value=str(row[self.sample.idKey]))
            for row in table_rows
        ]
        return IdProto.TIds(Ids=id_messages)

    def run(self, **kwargs):
        view_path = self.src_view.Path
        logger.info('Describing %s of %s', view_path, self.sample_id)

        if self.yt.row_count(view_path) > MAX_DESCRIBING_IDS_COUNT:
            view_schema = self.yt.get_attribute(view_path, 'schema')
            with self.yt.TempTable(attributes={'schema': view_schema}) as sampled_table:
                self.yt.sample_rows_from_table(view_path, sampled_table, MAX_DESCRIBING_IDS_COUNT)
                ids = self.get_ids_from_table(sampled_table)
        else:
            ids = self.get_ids_from_table(view_path)

        user_set_id = str(describe_ids_by_crypta_id(ids, self.sample.ttl))

        self.api.lab.setUserSetId(id=self.sample_id, user_set_id=user_set_id).result()
        self.yt.set_attribute(view_path, self.attribute_described, True)


class CreateStandardViews(workflow.IndependentTask, tasks.YtTask, utils.WithApi):
    """Build the standard views of a sample: the yandexuid (matched) view and the invalid view.

    Both destination views are switched to PROCESSING through the lab API and
    filled with YT operations. Each is then switched to READY; the yandexuid view
    is switched to ERROR instead when its table ends up empty. On success both
    tables get a ``_day`` attribute with today's Moscow date, which is the target.
    """
    sample_id = workflow.Parameter()
    src_view_id = workflow.Parameter()
    invalid_view_id = workflow.Parameter()
    yandexuid_view_id = workflow.Parameter()

    @cached_property
    def _today(self):
        return time_utils.get_current_moscow_datetime().date().isoformat()

    @cached_property
    def _attribute(self):
        return '_day'

    @cached_property
    def sample(self):
        return get_sample(self.api, self.sample_id)

    @cached_property
    def src_view(self):
        return get_view(self.api, self.sample_id, self.src_view_id)

    @cached_property
    def yandexuid_dst_view(self):
        return get_view(self.api, self.sample_id, self.yandexuid_view_id)

    @cached_property
    def invalid_view(self):
        return get_view(self.api, self.sample_id, self.invalid_view_id)

    def needs_matching(self, the_type):
        """Return True for every id type except ``crypta_id``."""
        return the_type != 'crypta_id'

    def is_matching_key(self, item):
        """Return True if schema column ``item`` is named like the yandexuid view's matching key."""
        return 'name' in item and item['name'] == self.yandexuid_dst_view.Options.Matching.Key

    def get_schema(self, dst_view, src_view):
        """Build the schema of the yandexuid destination table.

        The schema starts from the source view schema when ``dst_view`` has
        ``IncludeOriginal`` set; otherwise it starts empty. The column attributes
        sort_order, type_v2, required, group and type_v3 are dropped. Then:

        * ids needing matching: an existing ``GroupID`` column is renamed to
          ``__GroupID``, and the matching key column is added as ``string`` if absent;
        * ``crypta_id`` ids: the sample id column is renamed to the matching key,
          and ``GroupID`` is renamed to ``__GroupID``.

        Finally a ``GroupID`` string column is appended.
        """
        schema = None
        sample = self.sample
        yandexuid_dst_view = self.yandexuid_dst_view

        if dst_view.Options.Matching.IncludeOriginal:
            schema = self.yt.get_attribute(src_view.Path, 'schema')
        if not schema:
            schema = yt.yson.YsonList()
        if schema.attributes.get('unique_keys'):
            # unique_keys needs sorted key columns, and sort_order is stripped below.
            schema.attributes['unique_keys'] = False
            schema.attributes['strict'] = False

        for item in schema:
            for schema_attribute in ('sort_order', 'type_v2', 'required', 'group', 'type_v3'):
                if schema_attribute in item:
                    item.pop(schema_attribute)

        logger.info('Source schema %s', schema)

        if self.needs_matching(sample.idName):
            key_added = False

            for item in schema:
                if 'name' in item:
                    if item['name'] == yandexuid_dst_view.Options.Matching.Key:
                        key_added = True
                    elif item['name'] == 'GroupID':
                        item['name'] = '__' + item['name']
            if not key_added:
                schema.extend([{'name': yandexuid_dst_view.Options.Matching.Key, 'type': 'string'}])
        else:
            # NOTE(review): without IncludeOriginal the schema is empty here, so no
            # matching key column is present — confirm TRenameIdentifierMapper's output fits.
            for item in schema:
                if 'name' in item:
                    if item['name'] == sample.idKey:
                        item['name'] = yandexuid_dst_view.Options.Matching.Key
                    elif item['name'] == 'GroupID':
                        item['name'] = '__' + item['name']

        schema.extend([{'name': 'GroupID', 'type': 'string'}])
        logger.info('Destination schema %s', schema)

        return schema

    def update_sample_view_state(self, view_id, state):
        """Set the lab API state of view ``view_id`` of this sample and return the API result."""
        return self.api.lab.updateSampleViewState(
            id=self.sample_id,
            view_id=view_id,
            state=state,
        ).result()

    def create_view(self, view_id):
        return self.update_sample_view_state(view_id, 'PROCESSING')

    def set_sample_view_ready(self, view_id):
        return self.update_sample_view_state(view_id, 'READY')

    def fail_sample_view(self, view_id):
        return self.update_sample_view_state(view_id, 'ERROR')

    def targets(self):
        for path in [self.yandexuid_dst_view.Path, self.invalid_view.Path]:
            yield target_table.HasAttribute(self.yt, path, self._attribute, self._today)

    def _fill_matched_view(self, src_view, dst_path, state, with_id_value, crypta_ids):
        """Run the matching pipeline from the source view into ``dst_path``.

        The steps are:

        1. source view -> ``with_id_value``, via the Direct users join for Direct
           client ids or via ``TComputeMatchingIdMapper`` otherwise;
        2. join with graph vertices -> ``crypta_ids`` (``TJoinCryptaIDReducer``);
        3. join with ``vertices_by_crypta_id`` -> ``dst_path`` (``TJoinIdentifiersReducer``).
        """
        vertices_by_crypta_id = conf.paths.graph.vertices_by_crypta_id
        serialized_state = state.SerializeToString()
        id_type = src_view.Options.Matching.IdType

        if id_type == MatchingProto.ELabIdentifierType.Name(MatchingProto.LAB_ID_DIRECT_CLIENT_ID):
            # Direct client ids are joined with the Direct users table by ClientID.
            self.map_reduce(
                _direct_join_mapper(state),
                _direct_join_reducer(state),
                source=[conf.paths.direct.users, src_view.Path],
                destination=with_id_value,
                reduce_by='ClientID',
                sort_by=['ClientID', '__table_index']
            )
        else:
            self.native_map(
                TComputeMatchingIdMapper,
                source=src_view.Path,
                destination=with_id_value,
                state=serialized_state,
            )

        if id_type == MatchingProto.ELabIdentifierType.Name(MatchingProto.LAB_ID_CRYPTA_ID):
            id_vertices = vertices_by_crypta_id
            join_column = 'cryptaId'
        else:
            id_vertices = conf.paths.graph.vertices_no_multi_profile
            join_column = 'id'

        # Join-reduce inputs must be sorted by the join column.
        self.sort(
            source=with_id_value,
            destination=with_id_value,
            sort_by=join_column,
        )
        self.native_join_reduce(
            TJoinCryptaIDReducer,
            source=[utils._foreign(id_vertices), with_id_value],
            destination=crypta_ids,
            join_by=join_column,
            state=serialized_state,
        )
        self.sort(
            source=crypta_ids,
            destination=crypta_ids,
            sort_by='cryptaId',
        )
        self.native_join_reduce(
            TJoinIdentifiersReducer,
            source=[utils._foreign(vertices_by_crypta_id), crypta_ids],
            destination=dst_path,
            join_by='cryptaId',
            state=serialized_state,
        )

    def run(self, **kwargs):
        src_view = self.src_view
        # Reuse the cached API objects (get_schema reads the same ones) instead of
        # fetching the sample and the yandexuid view a second time.
        sample = self.sample

        yandexuid_dst_view = self.yandexuid_dst_view
        logger.info('Will store yandexuids in %s', yandexuid_dst_view)
        self.create_view(self.yandexuid_view_id)

        invalid_dst_view = self.invalid_view
        logger.info('Will store invalid records in %s', invalid_dst_view)
        logger.info(self.create_view(self.invalid_view_id))

        schema = self.get_schema(yandexuid_dst_view, src_view)
        attributes = {'schema': schema}

        src_proto = _to_proto(MatchingProto.TMatchingOptions, src_view.Options.Matching)
        dst_proto = _to_proto(MatchingProto.TMatchingOptions, yandexuid_dst_view.Options.Matching)
        state = OtherProto.TSourceDestinationState(
            Source=src_proto,
            Destination=dst_proto
        )

        with self.yt.TempTable(prefix='with_id_value_') as with_id_value, \
            self.yt.TempTable(prefix='crypta_id_') as crypta_ids:  # noqa
            self.yt.create(
                'table',
                yandexuid_dst_view.Path,
                attributes=attributes,
                force=True,
            )

            if invalid_dst_view:
                # TValidateMapper writes into the invalid view (presumably the rows failing validation).
                self.native_map(
                    TValidateMapper,
                    source=src_view.Path,
                    destination=invalid_dst_view.Path,
                    state=state.SerializeToString(),
                )

                logger.info(self.set_sample_view_ready(self.invalid_view_id))

            if self.needs_matching(sample.idName):
                logger.info('Type [%s] needs matching', sample.idName)
                self._fill_matched_view(src_view, yandexuid_dst_view.Path, state, with_id_value, crypta_ids)
            else:
                logger.info('Type [%s] needs no matching', sample.idName)
                self.native_map(
                    TRenameIdentifierMapper,
                    source=src_view.Path,
                    destination=yandexuid_dst_view.Path,
                    state=OtherProto.TRenameIdentifierMapperState(
                        Sample=_to_proto(sample_pb2.Sample, sample),
                        Source=_to_proto(MatchingProto.TMatchingOptions, src_view.Options.Matching),
                        Destination=_to_proto(MatchingProto.TMatchingOptions, yandexuid_dst_view.Options.Matching),
                    ).SerializeToString(),
                )

            self.yt.run_merge(src_view.Path, src_view.Path, spec=dict(combine_chunks=True))

            # An empty result is an error: fail the view and leave the targets unset.
            if self.yt.is_empty(yandexuid_dst_view.Path):
                logger.info(self.fail_sample_view(self.yandexuid_view_id))
                return
            logger.info('Non-empty')

            logger.info(self.set_sample_view_ready(self.yandexuid_view_id))

        self.yt.set_attribute(yandexuid_dst_view.Path, self._attribute, self._today)
        # Guarded the same way as the TValidateMapper step above.
        if invalid_dst_view:
            self.yt.set_attribute(invalid_dst_view.Path, self._attribute, self._today)
