from library.python import resource
import luigi
import yt.yson as yson
from yt.wrapper import with_context, create_table_switch

from crypta.lib.python import templater
from crypta.lib.python.yt import schema_utils
from crypta.profile.utils.config import config
from crypta.profile.utils.loggers import TimeTracker
from crypta.profile.utils.luigi_utils import BaseYtTask, YtDailyRewritableTarget
from crypta.profile.utils.segment_storage import ProfileBuilder
from crypta.profile.utils.utils import (
    is_valid_uint64,
    segment_storage_by_yandexuid_schema,
    segment_storage_by_crypta_id_schema,
    segment_storage_by_id_with_crypta_id_schema,
)

from crypta.profile.tasks.features.get_crypta_ids import GetCryptaIds
from crypta.profile.runners.export_profiles.lib.profiles_generation.get_segment_storage_by_id import GetDailySegmentsById


class ExpandSegmentsStorage(BaseYtTask):
    """Expand the crossdevice daily segment storage via a YQL query.

    Renders ``/query/expand_segments_storage.yql`` with the crossdevice
    segments table, the vertices and crypta_id/yandexuid tables from config,
    and three output tables. It then runs the query inside a single YT
    transaction. The three outputs are:

    * ``storage_with_crypta_id`` - created sorted by ``crypta_id``;
    * ``crypta_id`` - created sorted by ``crypta_id``;
    * ``yandexuid`` - created sorted by ``yandexuid``.

    Every output gets a ``generate_date`` attribute equal to ``date``.
    """
    date = luigi.Parameter()
    priority = 100
    task_group = 'export_profiles'

    def requires(self):
        return dict(
            segment_storage=GetDailySegmentsById(self.date, 'crossdevice'),
            matching=GetCryptaIds(self.date),
        )

    def output(self):
        destinations = (
            ('storage_with_crypta_id', config.SEGMENTS_STORAGE_BY_ID_WITH_CRYPTA_ID_TABLE),
            ('crypta_id', config.EXPANDED_SEGMENTS_STORAGE_BY_CRYPTA_ID_TABLE),
            ('yandexuid', config.EXPANDED_SEGMENTS_STORAGE_BY_YANDEXUID_TABLE),
        )
        return {
            name: YtDailyRewritableTarget(path, self.date)
            for name, path in destinations
        }

    def run(self):
        # (output key, schema dict, sort column) for each table pre-created before the query.
        table_layouts = (
            ('crypta_id', segment_storage_by_crypta_id_schema, 'crypta_id'),
            ('yandexuid', segment_storage_by_yandexuid_schema, 'yandexuid'),
            ('storage_with_crypta_id', segment_storage_by_id_with_crypta_id_schema, 'crypta_id'),
        )

        with TimeTracker(monitoring_name=self.__class__.__name__):
            with self.yt.Transaction() as tx:
                targets = self.output()

                for name, schema_dict, sort_column in table_layouts:
                    sorted_schema = schema_utils.yt_schema_from_dict(schema_dict, sort_by=[sort_column])
                    self.yt.create_empty_table(targets[name].table, schema=sorted_schema)

                template_vars = {
                    'merged': self.input()['segment_storage'].table,
                    'vertices_no_multi_profile': config.VERTICES_NO_MULTI_PROFILE,
                    'cryptaid_yandexuid': config.CRYPTAID_YANDEXUID_TABLE,

                    'storage': targets['storage_with_crypta_id'].table,
                    'expanded_cryptaid': targets['crypta_id'].table,
                    'expanded_yandexuid': targets['yandexuid'].table,
                }
                query_text = templater.render_template(
                    resource.find('/query/expand_segments_storage.yql'),
                    vars=template_vars,
                    strict=True,
                )

                # The query runs inside the same transaction as the table creation.
                self.yql.query(query_text, tx)

                for target in targets.values():
                    self.yt.set_attribute(target.table, 'generate_date', self.date)


def indevice_split_mapper(row):
    """Route one indevice segment storage row to one of two output tables.

    ``update_time`` is always removed from ``row`` first.

    * ``id_type == 'yandexuid'`` with an ``id`` accepted by
      ``is_valid_uint64``: ``id`` is moved into ``yandexuid`` (as a
      ``YsonUint64``), ``id``/``id_type`` are removed, and the row goes to
      table 0.
    * ``id_type == 'yandexuid'`` with an invalid ``id``: the row is dropped.
    * any other ``id_type``: the row goes unchanged (apart from
      ``update_time``) to table 1.
    """
    del row['update_time']

    if row['id_type'] != 'yandexuid':
        yield create_table_switch(1)
        yield row
        return

    if not is_valid_uint64(row['id']):
        # Malformed yandexuids are silently discarded.
        return

    row['yandexuid'] = yson.YsonUint64(row['id'])
    del row['id']
    del row['id_type']

    yield create_table_switch(0)
    yield row


@with_context
def indevice_join_with_yandexuid(key, rows, context):
    """Expand indevice storage rows to every yandexuid matched to the same id.

    Reducer over rows sharing one reduce key. In the caller the key is
    ``(id, id_type)``. Rows from input table 0 are segment storage rows.
    Rows from input table 1 carry a ``yandexuid`` column (presumably the
    indevice id->yandexuid matching; confirm against
    ``config.INDEVICE_YANDEXUID``).

    For each distinct yandexuid seen on table 1, one copy of the storage row
    is emitted. In each copy ``id`` and ``id_type`` are removed and
    ``yandexuid`` is set.

    :param key: reduce key (not used directly).
    :param rows: iterable of rows from both input tables for this key.
    :param context: YT job context; ``context.table_index`` identifies the
        input table of the current row.
    :yields: a new dict per matched yandexuid. Nothing is yielded if there is
        no storage row or no matched yandexuid.
    """
    storage_row = None
    good_yandexuids = set()

    for row in rows:
        if context.table_index == 0:
            # If several storage rows share a key, the last one wins.
            storage_row = row
        elif context.table_index == 1 and storage_row is None:
            # Matching rows without a preceding storage row: nothing to expand.
            # NOTE: this shortcut relies on table-0 rows being delivered before
            # table-1 rows within a key.
            return
        else:
            good_yandexuids.add(row['yandexuid'])

    if storage_row is None or not good_yandexuids:
        return

    del storage_row['id']
    del storage_row['id_type']

    for yandexuid in good_yandexuids:
        # Yield a fresh dict each time. Re-yielding one mutated dict would make
        # every row collected by a buffering consumer alias the last yandexuid.
        yield dict(storage_row, yandexuid=yandexuid)


class ExpandSegmentsStorageIndevice(BaseYtTask):
    """Expand the indevice daily segment storage to yandexuid-keyed rows.

    Pipeline, all inside one YT transaction:

    1. ``indevice_split_mapper`` writes rows keyed by a valid yandexuid
       directly to the output table. Rows keyed by other id types go to a
       temporary table.
    2. The temporary table is sorted by ``(id, id_type)`` and reduced together
       with ``config.INDEVICE_YANDEXUID`` by ``indevice_join_with_yandexuid``.
       The result is appended to the output table.
    3. The output table is sorted by ``yandexuid`` and stamped with a
       ``generate_date`` attribute.
    """
    date = luigi.Parameter()
    priority = 100
    task_group = 'export_profiles'

    def requires(self):
        return dict(
            segment_storage=GetDailySegmentsById(self.date, 'indevice'),
            matching=GetCryptaIds(self.date),
        )

    def output(self):
        destination = config.EXPANDED_INDEVICE_SEGMENTS_STORAGE_BY_YANDEXUID_TABLE
        return YtDailyRewritableTarget(destination, self.date, allow_empty=True)

    def run(self):
        join_key = ['id', 'id_type']

        with TimeTracker(monitoring_name=self.__class__.__name__):
            with self.yt.Transaction():
                with self.yt.TempTable() as unmatched_table:
                    expanded_table = self.output().table

                    self.yt.create_empty_table(expanded_table, schema=segment_storage_by_yandexuid_schema)

                    # Output index 0 -> expanded_table, index 1 -> unmatched_table.
                    self.yt.run_map(
                        indevice_split_mapper,
                        self.input()['segment_storage'].table,
                        [expanded_table, unmatched_table],
                    )

                    self.yt.run_sort(unmatched_table, sort_by=join_key)

                    # Append, so rows written by the map above are kept.
                    self.yt.run_reduce(
                        indevice_join_with_yandexuid,
                        [unmatched_table, config.INDEVICE_YANDEXUID],
                        self.yt.TablePath(expanded_table, append=True),
                        reduce_by=join_key,
                    )

                    self.yt.run_sort(expanded_table, sort_by='yandexuid')

                    self.yt.set_attribute(expanded_table, 'generate_date', self.date)


class YandexuidProfileBuilder(ProfileBuilder):
    """Reducer that folds all segment rows of one yandexuid into a single row.

    The output row holds the reduce key's ``yandexuid``, updated with
    whatever mapping ``ProfileBuilder.combine_segments`` returns for the
    group's rows.
    """

    def __call__(self, key, rows):
        profile = dict(yandexuid=key['yandexuid'])
        profile.update(self.combine_segments(rows))
        yield profile


class CombineSegmentsByYandexuid(BaseYtTask):
    """Merge crossdevice and indevice yandexuid expansions into one table.

    Reduces the yandexuid output of ``ExpandSegmentsStorage`` together with
    the output of ``ExpandSegmentsStorageIndevice`` by ``yandexuid`` using
    ``YandexuidProfileBuilder``. It then sorts the result by ``yandexuid`` and
    stamps it with a ``generate_date`` attribute.
    """
    date = luigi.Parameter()
    priority = 100
    task_group = 'export_profiles'

    def requires(self):
        return dict(
            crypta_id_expansion=ExpandSegmentsStorage(self.date),
            indevice_expansion=ExpandSegmentsStorageIndevice(self.date),
        )

    def output(self):
        destination = config.SEGMENTS_STORAGE_BY_YANDEXUID_TABLE
        return YtDailyRewritableTarget(destination, self.date)

    def run(self):
        with TimeTracker(monitoring_name=self.__class__.__name__):
            with self.yt.Transaction():
                combined_table = self.output().table

                self.yt.create_empty_table(combined_table, schema=segment_storage_by_yandexuid_schema)

                upstream = self.input()
                source_tables = [
                    upstream['crypta_id_expansion']['yandexuid'].table,
                    upstream['indevice_expansion'].table,
                ]
                self.yt.run_reduce(
                    YandexuidProfileBuilder(),
                    source_tables,
                    combined_table,
                    reduce_by='yandexuid',
                )

                # Re-sort the reduce output by yandexuid.
                self.yt.run_sort(combined_table, sort_by='yandexuid')

                self.yt.set_attribute(combined_table, 'generate_date', self.date)
