from datetime import datetime
from functools import partial
from itertools import islice
import logging
import os

from yt.wrapper import (
    yson,
    with_context,
)

from cached_property import cached_property
from crypta.lab.lib.common import (
    WithApi,
    _foreign
)
from crypta.lab.lib.native_operations import (
    TGenericJoinReducer,
    TTransformUserDataToUserDataStats,
    TMergeUserDataStats,
    TIdentityMapper,
    TUniqueReducer,
    TComputeMatchingIdMapper,
    TJoinCryptaIDReducer,
    TJoinIdentifiersReducer,
    TRenameIdentifierMapper,
    TAddGroupingKey,
    TValidateMapper,
)
from crypta.lab.lib.specs import (
    _spec,
    YtSpecs,
)
from crypta.lab.lib.tables import (
    SampleStatsStorage,
)
from crypta.lab.proto.describe_pb2 import TDatedSample
import crypta.lab.proto.matching_pb2 as Proto
from crypta.lab.proto.matching_pb2 import TMatchingOptions
from crypta.lab.proto.other_pb2 import (
    TJoinOptions,
    TRenameIdentifierMapperState,
    TSourceDestinationState,
)
from crypta.lab.proto.sample_pb2 import (
    Sample,
    TSampleStats,
)
from crypta.lib.proto.user_data.user_data_stats_pb2 import (
    TUserDataStats,
    TUserDataStatsOptions,
)
from crypta.lib.python import (
    proto as proto_utils,
    templater,
)
import crypta.lib.python.bt.conf.conf as conf
from crypta.lib.python.bt.tasks import (
    YQLTaskV1 as YQLTask,
    YtTask,
)
from crypta.lib.python.bt.workflow import (
    IndependentTask,
    Parameter,
)
from crypta.lib.python.swagger import _to_proto
from crypta.lib.python.yt import schema_utils
from crypta.siberia.bin.common.convert_to_user_data_stats.py.native_operations import TConvertToUserDataStatsMapper
from crypta.siberia.bin.common.proto.crypta_id_user_data_pb2 import TCryptaIdUserData
from crypta.siberia.bin.common.yt_describer.proto.yt_describer_config_pb2 import TYtDescriberConfig
from crypta.siberia.bin.common.yt_describer.py import describe

# Module-level logger named after this module; used by all helpers and tasks below.
logger = logging.getLogger(__name__)


def _batched(iterable, size):
    iterator = iter(iterable)
    while iterator:
        batch = list(islice(iterator, size))
        if not batch:
            break
        yield batch


def _direct_join_mapper(state):
    """Build the mapper that tags rows of the two joined tables with their origin.

    Rows of input table 0 are reduced to `ClientID` and `login`. Rows of input
    table 1 get a `ClientID` column copied from the `state.Source.Key` column;
    they are kept whole when `state.Destination.IncludeOriginal` is set and are
    reduced to `ClientID` alone otherwise. Every emitted row carries the index of
    its input table in `__table_index`. Any other table index raises an Exception.
    """
    @with_context
    def _mapper(record, context):
        table_index = context.table_index
        if table_index not in (0, 1):
            raise Exception('Invalid state')

        if table_index == 0:
            yield {'ClientID': record['ClientID'], 'login': record['login'], '__table_index': 0}
            return

        client_id = record[state.Source.Key]
        if not state.Destination.IncludeOriginal:
            yield {'ClientID': client_id, '__table_index': 1}
            return

        # Keep the original row, extended in place with the join key and the origin tag.
        record['ClientID'] = client_id
        record['__table_index'] = 1
        yield record

    return _mapper


def _direct_join_reducer(state):
    def _reducer(key, records):
        login = None
        for record in records:
            table_index = record.pop('__table_index')
            if table_index == 0:
                login = record['login'].lower().replace('.', '-')
            if login and table_index == 1:
                if state.Destination.IncludeOriginal:
                    record['id'] = login
                    if state.Source.Key != 'ClientID':
                        record.pop('ClientID')
                    yield record
                else:
                    yield {'id': login}

    return _reducer


class Describe(YtTask, IndependentTask, WithApi):
    """Describe a sample view and store per-group user data statistics for it.

    The task copies the source view into the destination ("yandexuid") view,
    matching identifiers through the graph tables when the sample id type needs
    it, optionally writes the output of the validation mapper into a separate
    view, joins the destination with the userdata table and pushes the
    aggregated stats into the sample stats storage. View states are reported
    through the Lab API (`PROCESSING`, then `READY` or `ERROR`).
    """

    # Identifier of the sample being described.
    sample_id = Parameter()
    # Identifier of the source view of the sample.
    view_id = Parameter()
    # Identifier of the view that receives the destination table.
    yandexuid_view_id = Parameter()
    # Identifier of the view that receives the output of the validation mapper.
    invalid_view_id = Parameter()

    @cached_property
    def sample_stats_storage(self):
        """Sample stats storage with its table prepared; created once per task instance."""
        storage = SampleStatsStorage(self._init_yt())
        storage.prepare_table()
        return storage

    def needs_matching(self, the_type):
        """Return False for the `yuid` / `yandexuid` id types and True for any other type."""
        return the_type not in ('yuid', 'yandexuid')

    def get_sample(self):
        """Fetch the sample from the Lab API.

        Returns None when the API error carries `status_code == 404`; any other
        error is re-raised unchanged.
        """
        try:
            return self.api.lab.getSample(id=self.sample_id).result()
        except Exception as e:
            if getattr(e, 'status_code', None) == 404:
                return None
            raise

    def get_view(self, view_id):
        """Fetch view `view_id` of the sample from the Lab API.

        Returns None when the API error carries `status_code == 404`; any other
        error is re-raised unchanged.
        """
        try:
            return self.api.lab.getSampleView(id=self.sample_id, view_id=view_id).result()
        except Exception as e:
            if getattr(e, 'status_code', None) == 404:
                return None
            raise

    def get_schema(self, dst_view, src_view):
        """Return the base schema for the destination table.

        The schema of the source table is used when the destination view is
        configured with `IncludeOriginal`; otherwise (or when the source has no
        schema) an empty YSON list is used. A truthy `unique_keys` schema
        attribute is reset to False and the `sort_order`, `type_v2`, `required`,
        `group` and `type_v3` column attributes are removed.
        """
        schema = None
        if dst_view.Options.Matching.IncludeOriginal:
            schema = self.yt.get_attribute(src_view.Path, 'schema')
        if not schema:
            schema = yson.YsonList()
        if schema.attributes.get('unique_keys'):
            schema.attributes['unique_keys'] = False
        for item in schema:
            for attribute in ('sort_order', 'type_v2', 'required', 'group', 'type_v3'):
                item.pop(attribute, None)

        return schema

    def _set_view_state(self, view_id, state):
        """Set the state of view `view_id` of this sample via the Lab API and log the response."""
        logger.info(self.api.lab.updateSampleViewState(
            id=self.sample_id,
            view_id=view_id,
            state=state,
        ).result())

    def _destination_schema(self, sample, src_view, dst_view):
        """Build the schema of the destination table from the source schema.

        The destination always has a string column named after the destination
        matching key and a string `GroupID` column appended at the end; source
        columns that would clash with them are renamed with a `__` prefix.
        """
        schema = self.get_schema(dst_view, src_view)
        logger.info('Source schema %s', schema)

        matching_key = dst_view.Options.Matching.Key
        if self.needs_matching(sample.idName):
            key_added = False
            for item in schema:
                if 'name' not in item:
                    continue
                if item['name'] == matching_key:
                    # The column already carries the destination key name: only force its type.
                    item['type'] = 'string'
                    key_added = True
                elif item['name'] == 'GroupID':
                    item['name'] = '__' + item['name']
            if not key_added:
                schema.append({'name': matching_key, 'type': 'string'})
        else:
            for item in schema:
                if 'name' not in item:
                    continue
                if item['name'] == sample.idKey:
                    # The sample id column becomes the destination key column.
                    item['name'] = matching_key
                    item['type'] = 'string'
                elif item['name'] == matching_key or item['name'] == 'GroupID':
                    item['name'] = '__' + item['name']
        schema.append({'name': 'GroupID', 'type': 'string'})

        logger.info('Destination schema %s', schema)
        return schema

    def _rename_state(self, sample, src_view, dst_view):
        """Serialize the state shared by the rename and the grouping-key native mappers."""
        return TRenameIdentifierMapperState(
            Sample=_to_proto(Sample, sample),
            Source=_to_proto(TMatchingOptions, src_view.Options.Matching),
            Destination=_to_proto(TMatchingOptions, dst_view.Options.Matching),
        ).SerializeToString()

    def _match_into_destination(self, sample, src_view, dst_view, state, with_id_value, crypta_ids):
        """Fill the destination table by matching source identifiers through the graph tables.

        `with_id_value` and `crypta_ids` are temporary tables used for the
        intermediate results.
        """
        vertices = conf.paths.graph.vertices_no_multi_profile
        vertices_by_crypta_id = conf.paths.graph.vertices_by_crypta_id
        serialized_state = state.SerializeToString()
        id_type = src_view.Options.Matching.IdType

        if id_type == Proto.ELabIdentifierType.Name(Proto.LAB_ID_DIRECT_CLIENT_ID):
            # Direct client ids are first replaced by logins from the Direct users table.
            self.map_reduce(
                _direct_join_mapper(state),
                _direct_join_reducer(state),
                source=[conf.paths.direct.users, src_view.Path],
                destination=with_id_value,
                reduce_by='ClientID',
                sort_by=['ClientID', '__table_index']
            )
        else:
            self.native_map(
                TComputeMatchingIdMapper,
                source=src_view.Path,
                destination=with_id_value,
                state=serialized_state,
            )

        if id_type == Proto.ELabIdentifierType.Name(Proto.LAB_ID_CRYPTA_ID):
            foreign_table, join_key = vertices_by_crypta_id, 'cryptaId'
        else:
            foreign_table, join_key = vertices, 'id'

        self.sort(
            source=with_id_value,
            destination=with_id_value,
            sort_by=join_key,
        )
        self.native_join_reduce(
            TJoinCryptaIDReducer,
            source=[_foreign(foreign_table), with_id_value],
            destination=crypta_ids,
            join_by=join_key,
            state=serialized_state,
        )
        self.sort(
            source=crypta_ids,
            destination=crypta_ids,
            sort_by='cryptaId',
        )
        self.native_join_reduce(
            TJoinIdentifiersReducer,
            source=[_foreign(vertices_by_crypta_id), crypta_ids],
            destination=dst_view.Path,
            join_by='cryptaId',
            state=serialized_state,
        )
        self.native_map(
            TAddGroupingKey,
            source=dst_view.Path,
            destination=dst_view.Path,
            state=self._rename_state(sample, src_view, dst_view),
        )

    def _compute_group_stats(self, dst_path, uniqs_per_group, with_userdata, stats):
        """Aggregate userdata of the destination table into `stats`, reduced by `GroupID`.

        `uniqs_per_group` and `with_userdata` are temporary tables used for the
        intermediate results.
        """
        self.native_map_reduce(
            TIdentityMapper,
            TUniqueReducer,
            source=dst_path,
            destination=uniqs_per_group,
            reduce_by=['GroupID', 'yuid'],
        )
        self.sort(
            source=uniqs_per_group,
            destination=uniqs_per_group,
            sort_by='yuid',
        )
        self.native_join_reduce(
            TGenericJoinReducer,
            source=[_foreign(conf.paths.lab.data.userdata), uniqs_per_group],
            destination=with_userdata,
            join_by='yuid',
            state=TJoinOptions(
                ForeignTableMaxRecords=1,
                Mode='M_KEEP_BOTH',
            ).SerializeToString(),
        )
        self.native_map_reduce_with_combiner(
            TTransformUserDataToUserDataStats,
            TMergeUserDataStats,
            TMergeUserDataStats,
            with_userdata,
            stats,
            reduce_by='GroupID',
            mapper_state=TUserDataStatsOptions(
                Flags=TUserDataStatsOptions.TFlags(
                    DuplicateWithoutGroupID=True,
                    IgnoreDistributions=True,
                )
            ).SerializeToString(),
            spec=_spec(
                YtSpecs.REDUCER_HIGH_MEMORY_USAGE,
                YtSpecs.REDUCER_BIG_ROWS,
                YtSpecs.MAPPER_BIG_ROWS,
                YtSpecs.MAPPER_VERY_HIGH_MEMORY_USAGE,
                YtSpecs.NO_INTERMEDIATE_COMPRESSION,
            ),
        )

    def _replace_sample_stats(self, stats):
        """Replace the stored stats rows of this sample with the rows of the `stats` table."""
        logger.info('Cleaning sample_id: %s', self.sample_id)
        self.sample_stats_storage.delete_rows(
            self.sample_stats_storage.select_rows(
                'SampleID, GroupID', 'SampleID=\'{}\''.format(self.sample_id)
            )
        )

        logger.info('Pushing stats into the dynamic table')
        for batch in _batched(self.yt.read_table(stats), conf.describe.stats_batch_size):
            self.sample_stats_storage.insert_rows([dict(
                SampleID=self.sample_id,
                # An empty group id is stored as None.
                GroupID=each.get('GroupID') or None,
                Stats=proto_utils.read_proto_from_dict(TUserDataStats(), each).SerializeToString(),
            ) for each in batch])

    def run(self, **kwargs):
        """Run the describe pipeline for the configured sample and views.

        Returns early (after a warning) when the sample, the source view or the
        destination view does not exist. The view for invalid records is
        optional: when it does not exist its state is not touched and the
        validation step is skipped. An empty destination table puts the
        destination view into the `ERROR` state.
        """
        sample = self.get_sample()
        if not sample:
            logger.warning('No such sample %s', self.sample_id)
            return

        src_view = self.get_view(self.view_id)
        if not src_view:
            logger.warning('No such view %s of %s', self.view_id, self.sample_id)
            return

        logger.info('Describing %s of %s', src_view, sample)

        dst_view = self.get_view(self.yandexuid_view_id)
        if not dst_view:
            # Without this guard a missing view fails later with an AttributeError on None.
            logger.warning('No such view %s of %s', self.yandexuid_view_id, self.sample_id)
            return
        logger.info('Will store yandexuids in %s', dst_view)

        invalid_dst_view = self.get_view(self.invalid_view_id)
        logger.info('Will store invalid records in %s', invalid_dst_view)

        self._set_view_state(self.yandexuid_view_id, 'PROCESSING')
        if invalid_dst_view:
            # Only a view that exists (and is later marked READY) is moved to PROCESSING.
            self._set_view_state(self.invalid_view_id, 'PROCESSING')

        attributes = {'schema': self._destination_schema(sample, src_view, dst_view)}

        state = TSourceDestinationState(
            Source=_to_proto(TMatchingOptions, src_view.Options.Matching),
            Destination=_to_proto(TMatchingOptions, dst_view.Options.Matching),
        )
        with self.yt.TempTable(prefix='with_id_value_') as with_id_value, \
                self.yt.TempTable(prefix='crypta_id_') as crypta_ids, \
                self.yt.TempTable(prefix='with_userdata_') as with_userdata, \
                self.yt.TempTable(prefix='uniqs_per_group_') as uniqs_per_group, \
                self.yt.TempTable(prefix='stats_') as stats:

            self.yt.create(
                'table',
                dst_view.Path,
                attributes=attributes,
                force=True,
            )

            if invalid_dst_view:
                self.native_map(
                    TValidateMapper,
                    source=src_view.Path,
                    destination=invalid_dst_view.Path,
                    state=state.SerializeToString(),
                )
                self._set_view_state(self.invalid_view_id, 'READY')

            if self.needs_matching(sample.idName):
                logger.info('Type [%s] needs matching', sample.idName)
                self._match_into_destination(sample, src_view, dst_view, state, with_id_value, crypta_ids)
            else:
                logger.info('Type [%s] needs no matching', sample.idName)
                self.native_map(
                    TRenameIdentifierMapper,
                    source=src_view.Path,
                    destination=dst_view.Path,
                    state=self._rename_state(sample, src_view, dst_view),
                )

            # NOTE(review): this merges the *source* view onto itself (chunk combining);
            # kept as in the original - confirm it is intentional.
            self.yt.run_merge(src_view.Path, src_view.Path, spec=dict(combine_chunks=True))

            # empty result should result in error
            if self.yt.is_empty(dst_view.Path):
                self._set_view_state(self.yandexuid_view_id, 'ERROR')
                return
            logger.info('Non-empty')

            self.yt.run_merge(dst_view.Path, dst_view.Path, spec=dict(combine_chunks=True))

            self._compute_group_stats(dst_view.Path, uniqs_per_group, with_userdata, stats)
            self._replace_sample_stats(stats)

            self._set_view_state(self.yandexuid_view_id, 'READY')


class PastDescribe(YtTask, IndependentTask, WithApi, YQLTask):
    """Describe a dated sample using daily userdata snapshots.

    Every sample row is assigned one of the available userdata dates, ids are
    matched to crypta ids and to the userdata of the assigned dates with YQL
    queries, and a single stats row for the whole sample is stored in the
    sample stats storage. Progress and errors are reported through the Lab API.
    """

    # Identifier of the sample being described.
    sample_id = Parameter()
    # Identifier of the sample view being described.
    view_id = Parameter()

    # Column names used in the intermediate and result tables.
    CRYPTA_ID = 'crypta_id'
    DATE = 'date'
    ID = 'id'
    STATS = 'Stats'

    # Format of the sample dates and of the node names in the userdata directory.
    DATE_FORMAT = '%Y-%m-%d'

    @cached_property
    def src_with_dates_schema(self):
        """Table schema derived from the `TDatedSample` proto."""
        return schema_utils.get_schema_from_proto(TDatedSample)

    @cached_property
    def with_user_data_stats_schema(self):
        """Table schema derived from the `TCryptaIdUserData` proto."""
        return schema_utils.get_schema_from_proto(TCryptaIdUserData)

    @cached_property
    def segment_stats_schema(self):
        """Table schema derived from the `TSampleStats` proto."""
        return schema_utils.get_schema_from_proto(TSampleStats)

    @cached_property
    def sample_stats_storage(self):
        """Sample stats storage with its table prepared; created once per task instance."""
        storage = SampleStatsStorage(self._init_yt())
        storage.prepare_table()
        return storage

    @property
    def userdata(self):
        """Path of the directory with daily userdata keyed by crypta id."""
        return conf.paths.lab.data.crypta_id.userdata_daily

    @property
    def query(self):
        """Not supported: this task renders its YQL queries itself, one per step."""
        raise NotImplementedError('query property is not supported by this class')

    def get_sample(self):
        """Fetch the sample from the Lab API.

        Returns None when the API error carries `status_code == 404`; any other
        error is re-raised unchanged.
        """
        try:
            return self.api.lab.getSample(id=self.sample_id).result()
        except Exception as e:
            if getattr(e, 'status_code', None) == 404:
                return None
            raise

    def get_view(self, view_id):
        """Fetch view `view_id` of the sample from the Lab API.

        Returns None when the API error carries `status_code == 404`; any other
        error is re-raised unchanged.
        """
        try:
            return self.api.lab.getSampleView(id=self.sample_id, view_id=view_id).result()
        except Exception as e:
            if getattr(e, 'status_code', None) == 404:
                return None
            raise

    def update_sample_view_state(self, id, view_id, state):
        """Set `state` on the sample itself and on view `view_id`; return the view update result.

        Note that the sample-level update always targets `self.sample_id`,
        while the view-level update uses the `id` argument.
        """
        # TODO(evgeniia-r): CRYPTA-16023
        self.update_sample_state(id=self.sample_id, state=state)
        return self.api.lab.updateSampleViewState(id=id, view_id=view_id, state=state).result()

    # TODO(evgeniia-r): CRYPTA-16023
    def update_sample_state(self, id, state):
        """Set the state of sample `id` via the Lab API and return the result."""
        return self.api.lab.updateSampleState(id=id, state=state).result()

    # TODO(evgeniia-r): CRYPTA-16023
    def set_sample_view_error(self, id, view_id, error):
        """Mark sample `id` as `ERROR` and record `error` on its view `view_id`."""
        self.update_sample_state(id=id, state='ERROR')
        return self.api.lab.updateSampleViewError(id=id, view_id=view_id, error=error).result()

    def match_view_by_date(self, src_with_dates, src_by_crypta_id, id_to_crypta_id, src_for_describe, src_id_type):
        """Run the `matching_by_date.yql` query over the dated sample table.

        The `idfa_gaid` id type is matched through both the `idfa` and the
        `gaid` matching tables; any other type uses its own matching table.
        """
        ids_list = ['idfa', 'gaid'] if src_id_type == 'idfa_gaid' else [src_id_type]
        id_to_crypta_id_tables = [
            os.path.join(conf.paths.matching.root, id_name, 'crypta_id') for id_name in ids_list
        ]
        query = templater.render_resource(
            '/crypta/lab/matching_by_date.yql',
            strict=True,
            vars={
                'sample_table': src_with_dates,
                'src_id_type': src_id_type,
                'id_to_crypta_id_tables': id_to_crypta_id_tables,
                'view_by_crypta_id': src_by_crypta_id,
                'id_to_crypta_id_table': id_to_crypta_id,
                'describe_sample': src_for_describe,
                'users_to_leave': conf.describe_by_dates.users_num,
            },
        )
        self.yql_client.execute(
            query=query,
            transaction=str(self.transaction_id),
            title='YQL Match yandexuid to crypta_id for past describe',
        )

    def get_distinct_dates(self, src_with_dates):
        """Return the values of the date column produced by `distinct_dates.yql` for `src_with_dates`.

        Raises:
            RuntimeError: when no dates are found; the view error is set to
                `TOO_OLD_DATES` before raising.
        """
        with self.yt.TempTable(prefix='usedDates') as used_dates_table:
            query = templater.render_resource(
                '/crypta/lab/distinct_dates.yql',
                strict=True,
                vars={
                    'date_field': self.DATE,
                    'src_with_dates': src_with_dates,
                    'used_dates_table': used_dates_table,
                },
            )
            self.yql_client.execute(
                query=query,
                transaction=str(self.transaction_id),
                title='YQL Get distinct dates',
            )
            used_dates = [
                row[self.DATE] for row in self.yt.read_table(self.yt.TablePath(used_dates_table, columns=[self.DATE]))
            ]

            if not used_dates:
                logger.info(
                    self.set_sample_view_error(id=self.sample_id, view_id=self.view_id, error='TOO_OLD_DATES')
                )
                raise RuntimeError('No matching user_data for any dates')

            logger.info('Used UserData dates: %s', ' '.join(map(str, used_dates)))
            return used_dates

    def match_src_with_userdata(self, src_by_crypta_id, with_user_data, userdata_dates):
        """Run the `matching_with_userdata.yql` query that joins the view with daily userdata."""
        match_with_userdata_query = templater.render_resource(
            '/crypta/lab/matching_with_userdata.yql',
            strict=True,
            vars={
                'userdata_dates': userdata_dates,
                'src_view': src_by_crypta_id,
                'userdata_dir': self.userdata,
                'userdata_joined': with_user_data,
            },
        )
        self.yql_client.execute(
            query=match_with_userdata_query,
            transaction=str(self.transaction_id),
            title='YQL Match view to daily UserData',
        )

    @staticmethod
    def update_dates_mapper(row, userdata_dates, src_date_key, src_id_key, date_format, dst_id_key, dst_date_key, max_diff):
        """Map a sample row to its id and the userdata date chosen for it.

        Rows whose date cannot be read or parsed are logged and skipped; rows
        dated before `userdata_dates[0]` are skipped as well, so
        `userdata_dates` is expected to be non-empty and sorted ascending.

        The chosen date minimizes the number of days between the row date and a
        strictly earlier userdata date; userdata dates that are not strictly
        earlier are ranked with `max_diff`.
        NOTE(review): a userdata date equal to the row date is therefore ranked
        with `max_diff`, not 0 - confirm this is intended.

        Yields:
            At most one dict with `dst_id_key` (the id as a string) and
            `dst_date_key` (the chosen date formatted with `date_format`).
        """
        try:
            row_date = datetime.strptime(row[src_date_key], date_format)
        except Exception as e:
            logger.error(e)
            return

        if row_date < userdata_dates[0]:
            return
        closest = min(
            userdata_dates,
            key=lambda userdata_date:
            ((row_date - userdata_date).days if userdata_date < row_date else max_diff),
        ).strftime(date_format)

        yield {
            dst_id_key: str(row[src_id_key]),
            dst_date_key: closest,
        }

    def get_available_userdata_dates(self):
        """Return the userdata dates that fall on the configured weekday, sorted ascending."""
        wanted_weekday = conf.describe_by_dates.weekday_for_userdata
        parsed_dates = (
            datetime.strptime(node_name, self.DATE_FORMAT) for node_name in self.yt.list(self.userdata)
        )
        # Sorted explicitly: `update_dates_mapper` treats the first element as the
        # earliest date, so the order of the directory listing must not matter.
        userdata_dates_filtered_by_day = sorted(
            userdata_date for userdata_date in parsed_dates if userdata_date.weekday() == wanted_weekday
        )

        logger.info('Available UserData dates: %s', ' '.join(map(str, userdata_dates_filtered_by_day)))

        return userdata_dates_filtered_by_day

    def describe_by_combined_userdata(self, with_user_data, id_to_crypta_id_table, src_for_describe):
        """Run the YT describer over the joined userdata and return the resulting stats value.

        The describer output table is asserted to hold exactly one row; the
        value of its `Stats` column is returned.
        """
        with_userdata_stats_attrs = {'schema': self.with_user_data_stats_schema, 'optimize_for': 'scan'}
        segment_stats_attrs = {'schema': self.segment_stats_schema, 'optimize_for': 'scan'}

        with self.yt.TempTable(prefix='withUserDataStats', attributes=with_userdata_stats_attrs) as with_user_data_stats, \
                self.yt.TempTable(prefix='segmentStats', attributes=segment_stats_attrs) as segment_stats:
            self.native_map(
                mapper_name=TConvertToUserDataStatsMapper,
                source=with_user_data,
                destination=with_user_data_stats,
            )
            self.sort(with_user_data_stats, with_user_data_stats, sort_by=[self.CRYPTA_ID])

            description_config = TYtDescriberConfig(
                IdToCryptaIdTable=id_to_crypta_id_table,
                CryptaIdUserDataTable=with_user_data_stats,
                TmpDir='//tmp',
                InputTable=src_for_describe,
                OutputTable=segment_stats,
            )

            describe(self.yt, self, description_config)

            assert self.yt.get_attribute(segment_stats, 'row_count') == 1

            stats_rows = self.yt.read_table(self.yt.TablePath(segment_stats, columns=[self.STATS]))
            # The builtin next() works on both Python 2 and 3, unlike the `.next()` method.
            userdata_stats = next(iter(stats_rows))
            return userdata_stats[self.STATS]

    def run(self, **kwargs):
        """Run the dated describe pipeline for the configured sample view.

        Returns early (after a warning) when the sample or the view does not
        exist.

        Raises:
            RuntimeError: when too few rows get a userdata date (view error
                `WRONG_DATES`), when no userdata date is used at all
                (`TOO_OLD_DATES`), or when too few ids are matched with
                userdata (`NOT_MATCHED`).
        """
        sample = self.get_sample()
        if not sample:
            logger.warning('No such sample %s', self.sample_id)
            return

        src_view = self.get_view(self.view_id)
        if not src_view:
            logger.warning('No such view %s of %s', self.view_id, self.sample_id)
            return

        src_path = src_view.Path
        src_id_type = sample.idName
        src_id_key = sample.idKey
        src_date_key = sample.dateKey

        logger.info('Describing %s of dated sample %s', src_view, sample)

        logger.info(self.update_sample_view_state(id=self.sample_id, view_id=self.view_id, state='PROCESSING'))

        src_with_dates_attrs = {'schema': self.src_with_dates_schema, 'optimize_for': 'scan'}
        with self.yt.TempTable(prefix='srcWithDates', attributes=src_with_dates_attrs) as src_with_dates, \
                self.yt.TempTable(prefix='srcByCryptaId') as src_by_crypta_id, \
                self.yt.TempTable(prefix='srcForDescribe') as src_for_describe, \
                self.yt.TempTable(prefix='idToCryptaId') as id_to_crypta_id_table, \
                self.yt.TempTable(prefix='withUserData') as with_user_data:

            user_data_dates = self.get_available_userdata_dates()
            self.map(
                partial(
                    self.update_dates_mapper,
                    userdata_dates=user_data_dates,
                    src_date_key=src_date_key,
                    src_id_key=src_id_key,
                    date_format=self.DATE_FORMAT,
                    dst_id_key=self.ID,
                    dst_date_key=self.DATE,
                    max_diff=conf.describe_by_dates.max_dates_diff,
                ),
                source=src_path,
                destination=src_with_dates,
            )

            dated_rows = self.yt.get_attribute(src_with_dates, 'row_count')
            if dated_rows < conf.describe_by_dates.min_sample_size:
                self.set_sample_view_error(id=self.sample_id, view_id=self.view_id, error='WRONG_DATES')
                # NOTE(review): the message below looks truncated; kept byte-for-byte.
                raise RuntimeError('Not enough rows with available userdata for the date or ')

            logger.info('Size of sample with updated dates: %i', dated_rows)

            self.match_view_by_date(
                src_with_dates=src_with_dates,
                src_by_crypta_id=src_by_crypta_id,
                id_to_crypta_id=id_to_crypta_id_table,
                src_for_describe=src_for_describe,
                src_id_type=src_id_type,
            )
            used_dates = self.get_distinct_dates(src_by_crypta_id)

            self.yt.set_attribute(src_path, 'min_used_dates', min(used_dates))

            self.match_src_with_userdata(
                src_by_crypta_id=src_by_crypta_id,
                with_user_data=with_user_data,
                userdata_dates=used_dates,
            )

            matched_rows = self.yt.get_attribute(with_user_data, 'row_count')
            if matched_rows < conf.describe_by_dates.min_sample_size:
                self.set_sample_view_error(id=self.sample_id, view_id=self.view_id, error='NOT_MATCHED')
                raise RuntimeError('Not enough ids with userdata')

            logger.info('Number of ids matched with userdata: %i', matched_rows)

            userdata_stats = self.describe_by_combined_userdata(with_user_data, id_to_crypta_id_table, src_for_describe)

            logger.info('Deleting old stats for sample')

            self.sample_stats_storage.delete_rows(
                self.sample_stats_storage.select_rows(
                    'SampleID, GroupID', 'SampleID=\'{}\''.format(self.sample_id)
                )
            )

            logger.info('Pushing stats to dynamic table')
            self.sample_stats_storage.insert_rows([dict(SampleID=self.sample_id, GroupID=None, Stats=userdata_stats)])

        logger.info(self.update_sample_view_state(id=self.sample_id, view_id=self.view_id, state='READY'))
