import collections
import datetime
import itertools
import logging
import os
import re
import uuid

import dateutil.parser
import tvmauth
import yt.wrapper as yt
from yt import yson
from yt.common import date_string_to_timestamp

from cached_property import cached_property
import crypta.lib.python.bt.conf.conf as conf
from crypta.audience.lib.tasks.base import (
    YQLTaskV1,
    YtTask,
)
from crypta.audience.lib.tasks.constants import (
    CRYPTAID_SOURCEID,
)
from crypta.audience.lib.affinity.affinity import (
    _compute_affinities,
)
from crypta.audience.lib.tasks.audience.sql import (
    write_segment_on_input,
    write_segment_on_output,
    write_segment_properties,
    write_segment_goal_relation,
)
from crypta.audience.lib.tasks.audience.tables import (
    Attributes,
    GeneralStorageOutput,
    Matching,
    Input,
    InputBatch,
    InvalidMatchingParameters,
    Output,
    RegularSegmentStateStorage,
    SegmentsGoalsRelationStorage,
    SegmentPropertiesStorage,
    StatsStorage,
)
from crypta.audience.lib.communality import (
    _compute_communality,
    _cosine_quantile_score,
    _raw_trace,
)
from crypta.audience.lib.tasks.math import (
    _similarity,
)
from crypta.audience.lib.tasks.native_operations import (
    TBatchify,
    TExtractGeneralStorageOutput,
)
from crypta.audience.proto import tables_pb2
from crypta.audience.proto.other_pb2 import (
    TMapping,
)
from crypta.audience.proto.states_pb2 import (
    TGeneralStorageState,
)
from crypta.lib.proto.identifiers import id_type_pb2
from crypta.lib.python import (
    templater,
    time_utils,
)
from crypta.lab.lib.tables import (
    UserData,
    UserDataStats,
)
from crypta.lib.python.bt.workflow import (
    IndependentTask,
    Parameter,
)
import crypta.lib.python.tvm.helpers as tvm_helpers
from crypta.lib.python.yt import schema_utils
import crypta.lib.python.yql.executer as yql
import crypta.siberia.bin.common.create_user_set_from_sample_reducer.py as sampler
from crypta.siberia.bin.common.describing.mode.python import describing_mode
from crypta.siberia.bin.common import sample_stats_getter
from crypta.siberia.bin.common.siberia_client import SiberiaClient


# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Maximum number of input tables handed to a single map operation
# (used as the chunk size in _prepare_goals_users).
MAX_INPUT_PATHS_COUNT = 3000

# An input table that cannot be batched, with the reason and the status to report for it.
WrongInput = collections.namedtuple('WrongInput', ['id', 'table', 'reason', 'status'])
# Key by which input tables are grouped before being split into batches.
BatchingKey = collections.namedtuple('BatchingKey', ['matching_table', 'target_type'])
# Per-segment values that are later passed to write_segment_on_input.
SegmentDatabaseRecord = collections.namedtuple('SegmentDatabaseRecord', ['id', 'audience_id', 'row_count', 'time'])

# YQL query template: for every value of {group_by} in {source}, counts distinct
# {id_value_column} values and distinct yuids and writes them to {destination}.
COMPUTE_UNIQ_ID_VALUE_QUERY_TEMPLATE = """
PRAGMA yt.InferSchema = '1';

INSERT INTO `{destination}` WITH TRUNCATE
SELECT
    COUNT(DISTINCT {id_value_column}) as uniq_id_value_count,
    COUNT(DISTINCT yuid) as uniq_yuid_count,
    {group_by} as group_id,
FROM `{source}`
GROUP BY {group_by};
"""


def _get_list_batches_with_priorities(task, dir_path, dir_only=False, table_only=False, skip_filter=None, add_ts_priority_to_out=False):
    """Yield unlocked children of ``dir_path`` together with their priority.

    :param task: task object exposing a ``yt`` client
    :param dir_path: Cypress directory to list
    :param dir_only: skip children whose type is ``table``
    :param table_only: skip children whose type is not ``table``
    :param skip_filter: optional predicate; children for which it is truthy are skipped
    :param add_ts_priority_to_out: when set, yield ``(priority, timestamp_priority, node)``
        instead of ``(priority, node)``

    The priority is the node's priority attribute; when it is missing or zero
    the negated creation timestamp is used instead.
    """
    requested_attributes = [
        InputBatch.Attributes.PRIORITY,
        Attributes.CREATION_TIME,
        Attributes.TYPE,
        Attributes.LOCK_COUNT,
    ]
    nodes = task.yt.list(dir_path, absolute=True, attributes=requested_attributes, max_size=None)

    for node in nodes:
        node_attributes = node.attributes

        # Locked nodes are skipped.
        if node_attributes.get(Attributes.LOCK_COUNT):
            continue

        is_table = node_attributes.get(Attributes.TYPE) == 'table'
        if (is_table and dir_only) or (not is_table and table_only):
            continue

        if skip_filter and skip_filter(node):
            continue

        explicit_priority = int(node_attributes.get(InputBatch.Attributes.PRIORITY, 0))
        creation_timestamp = date_string_to_timestamp(node_attributes.get(Attributes.CREATION_TIME))
        timestamp_priority = -int(creation_timestamp)
        # A missing/zero explicit priority falls back to the timestamp-based one.
        priority = explicit_priority or timestamp_priority

        if add_ts_priority_to_out:
            yield priority, timestamp_priority, node
        else:
            yield priority, node


def _prepare_stats(_stats, with_filter=False):
    """Build a ``UserDataStats.Proto`` message from a stats record.

    :param _stats: a ``UserDataStats.Record`` or a value accepted by its constructor
    :param with_filter: also copy the record's filter into the result
    :return: the populated proto; its main covariance and mean2 are cleared and
        the precomputed trace is set when ``_raw_trace`` returns a truthy value
    """
    if isinstance(_stats, UserDataStats.Record):
        record = _stats
    else:
        record = UserDataStats.Record(_stats)

    result = UserDataStats.Proto()
    leading_sections = (
        ('Attributes', record.get_attributes),
        ('Affinities', record.get_affinities),
        ('Identifiers', record.get_identifiers),
        ('Stratum', record.get_stratum),
        ('Distributions', record.get_distributions),
    )
    for section_name, getter in leading_sections:
        getattr(result, section_name).CopyFrom(getter())

    # The trace is taken before the covariance is dropped below.
    precomputed_trace = _raw_trace(result)
    if precomputed_trace:
        result.Distributions.Precomputed.Trace = precomputed_trace

    main_distribution = result.Distributions.Main
    main_distribution.Covariance.Clear()
    main_distribution.Mean2.Clear()

    result.Counts.CopyFrom(record.get_counts())
    if with_filter:
        result.Filter.CopyFrom(record.get_filter())
    result.SegmentInfo.CopyFrom(record.get_segment_info())
    result.GroupID = record.get_group_id()

    return result


def _get_empty_stats():
    """Return the stats proto that ``_prepare_stats`` builds from an empty dict."""
    no_stats = {}
    return _prepare_stats(no_stats)


def _parse_yt_time(time):
    """Parse a time string with ``dateutil``, ignoring any timezone information."""
    parsed = dateutil.parser.parse(time, ignoretz=True)
    return parsed


def _related_goals(meta, segment_id):
    """Return the related goals of ``segment_id`` from batch ``meta`` as a list of strings."""
    segment_meta = meta[InputBatch.Meta.SEGMENTS][segment_id]
    goals = segment_meta[InputBatch.Meta.RELATED_GOALS]
    return [str(goal) for goal in goals]


def _get_goals_schema():
    """Return the YT schema with two string columns: segment id and yuid."""
    column_types = {
        InputBatch.Fields.SEGMENT_ID: "string",
        UserData.Fields.YUID: "string",
    }
    return schema_utils.yt_schema_from_dict(column_types)


def _with_segment_id(mapping, shift=0, string_format=True):
    """Create a context-aware mapper that tags each record with a segment id.

    :param mapping: lookup from (shifted) input table index to segment id
    :param shift: offset added to the table index before the lookup
    :param string_format: when true, emit the segment id and id value as is
        under the ``InputBatch``/``UserData`` field names; otherwise emit
        both as ``YsonUint64`` under ``segment_id``/``user_id``
    """
    @yt.with_context
    def mapper(record, context):
        table_index = context.table_index or 0
        segment_id = mapping[shift + table_index]
        id_value = record[Input.Fields.ID_VALUE]

        if not string_format:
            yield {
                "segment_id": yson.YsonUint64(int(segment_id)),
                "user_id": yson.YsonUint64(int(id_value))
            }
            return

        yield {
            InputBatch.Fields.SEGMENT_ID: segment_id,
            UserData.Fields.YUID: id_value
        }

    return mapper


def _prepare_goals_users(task, goals, destination, string_format):
    """Collect the users of the given goals into ``destination``.

    For every distinct goal the goal table is looked up first in the sharded
    layout (``<related_goals>/<goal % 10>/<goal>``) and then in the flat one
    (``<related_goals>/<goal>``); goals without a table are ignored.  The found
    tables are mapped into ``destination`` in chunks of at most
    ``MAX_INPUT_PATHS_COUNT`` input tables, appending after the first chunk.

    :param task: task object exposing ``yt``, ``ypath_join`` and ``map``
    :param goals: iterable of goal ids
    :param destination: output table path
    :param string_format: passed through to ``_with_segment_id``
    :return: ``True`` if at least one goal table was found (and therefore
        written to ``destination``), ``False`` otherwise
    """
    def _get_path(goal):
        path = task.ypath_join(conf.paths.audience.related_goals, str(int(goal) % 10), goal)
        if not task.yt.exists(path):
            path = task.ypath_join(conf.paths.audience.related_goals, goal)
            if not task.yt.exists(path):
                path = None

        return path

    # An explicit list: it is measured, sliced and tested for emptiness below.
    goal_paths = [
        path
        for path in (_get_path(goal) for goal in set(goals))
        if path is not None
    ]
    index_to_table_mapping = _index_to_table_mapping(goal_paths, proto=False)

    for from_ in range(0, len(goal_paths), MAX_INPUT_PATHS_COUNT):
        # The mapper sees table indices relative to the current chunk, hence the shift.
        mapper = _with_segment_id(index_to_table_mapping, shift=from_, string_format=string_format)
        task.map(
            mapper=mapper,
            source=goal_paths[from_:from_ + MAX_INPUT_PATHS_COUNT],
            destination=destination,
        )
        # Every chunk after the first one must append instead of overwriting.
        destination = task.yt.TablePath(destination, append=True)

    # The previous ``goal_paths > 0`` compared a list with an int, which is
    # always true on Python 2, so the "nothing found" case was never reported.
    return bool(goal_paths)


def _compute_stats_per_goal(task, goals):
    """Compute stats for the users of every goal.

    :return: the per-segment stats mapping produced by
        ``_compute_stats_per_segment`` grouped by segment id, or an empty
        dict when ``_prepare_goals_users`` reports that nothing was prepared
    """
    table_attributes = {"schema": _get_goals_schema()}
    with task.yt.TempTable(prefix='stats_per_goal_', attributes=table_attributes) as goals_users:
        prepared = _prepare_goals_users(task, goals, goals_users, string_format=True)
        if not prepared:
            return {}

        stats_per_goal, _ = _compute_stats_per_segment(
            task,
            goals_users,
            group_by=InputBatch.Fields.SEGMENT_ID,
            compute_uniq_id_value=False,
        )
        return stats_per_goal


def _compute_stats_per_segment(
    task,
    users,
    group_by,
    compute_uniq_id_value=True,
    id_value_column=Matching.Fields.ID_VALUE,
    id_type=id_type_pb2.EIdType.YANDEXUID,
    user_data_stats_options=None,
    experiment="by_crypta_id",
):
    """Compute user data stats for every group of the ``users`` table.

    User sets are created from a sample of ``users`` via
    ``sampler.create_user_set_from_sample`` and their stats are fetched with
    ``sample_stats_getter.get_stats``.  When ``compute_uniq_id_value`` is set,
    exact distinct counts are additionally computed by a YQL query and used to
    fill ``Counts`` and to derive a per-group scale factor.

    :param task: task object exposing ``yt`` and the native map/map-reduce runners
    :param users: source table with users
    :param group_by: name of the column that identifies the group (segment)
    :param compute_uniq_id_value: whether to run ``_compute_uniq_id_value``
    :param id_value_column: column whose distinct values are counted
    :param id_type: identifier type passed to the sampler
    :param user_data_stats_options: passed through to the sampler
    :param experiment: passed through to the sampler
    :return: tuple ``(stats_per_segment, scales_per_segment)``; the second item
        is a ``defaultdict`` with ``1.`` as the default scale
    """
    suffix = '_by_{}'.format(group_by)

    with task.yt.TempTable(prefix='user_sets' + suffix) as user_sets_temporary:
        sampler.create_user_set_from_sample(
            task.yt,
            task.native_map_reduce_with_combiner,
            task.native_map,
            source=users,
            destination=user_sets_temporary,
            tvm_settings={
                "source_id": conf.proto.Tvm.TvmId,
                "destination_id": conf.proto.Siberia.Tvm.DestinationTvmId,
                "secret": conf.proto.Tvm.TvmSecret,
            },
            group_id_column=group_by,
            id_type=id_type,
            id_column=UserData.Fields.YUID,
            sample_size=conf.proto.Options.SiberiaSampling.SampleSize,
            siberia_host=conf.proto.Siberia.Host,
            siberia_port=conf.proto.Siberia.Port,
            max_ids_per_second=conf.proto.Options.SiberiaSampling.MaxIdsPerSecond,
            max_jobs=conf.proto.Options.SiberiaSampling.MaxDescribeJobs,
            user_data_stats_options=user_data_stats_options,
            describing_mode=describing_mode.SLOW,
            experiment=experiment,
        )

        # A TVM service ticket is needed to query Siberia for the stats.
        tvm_settings = tvmauth.TvmApiClientSettings(
            self_tvm_id=conf.proto.Tvm.TvmId,
            self_secret=conf.proto.Tvm.TvmSecret,
            dsts={"siberia": conf.proto.Siberia.Tvm.DestinationTvmId},
            localhost_port=tvm_helpers.get_tvm_test_port(),
        )

        tvm_client = tvmauth.TvmClient(tvm_settings)
        siberia_client = SiberiaClient(conf.proto.Siberia.Host, conf.proto.Siberia.Port)
        tvm_ticket = tvm_client.get_service_ticket_for("siberia")

        all_stats = sample_stats_getter.get_stats(
            task.yt,
            siberia_client,
            tvm_ticket,
            user_sets_temporary,
        )

        stats_per_segment = {}
        scales_per_segment = collections.defaultdict(lambda: 1.)

        # Each row of the user sets table links a group to its user set id.
        for row in task.yt.read_table(user_sets_temporary):
            stats = all_stats[row[sample_stats_getter.USER_SET_ID]].UserDataStats

            # The serialized filter, when present in the row, is parsed into the stats.
            filter = row.get("filter")
            if filter:
                stats.Filter.ParseFromString(filter)

            stats_per_segment[row[group_by]] = stats

    if compute_uniq_id_value:
        rows = _compute_uniq_id_value(task, users, group_by, id_value_column=id_value_column)
        for row in rows:
            # Groups without sampled stats get empty stats so that counts can still be stored.
            stats = stats_per_segment.setdefault(row["group_id"], _get_empty_stats())
            if stats.Counts.UniqYuid != 0:
                # TODO(CRYPTA-13107) maybe upscaling is not needed if segment is smaller than sample size
                # Scale = exact distinct yuid count / yuid count in the stats.
                scales_per_segment[row["group_id"]] = float(row["uniq_yuid_count"]) / stats.Counts.UniqYuid
            else:
                scales_per_segment[row["group_id"]] = 1.
                stats.Counts.UniqYuid = long(row["uniq_yuid_count"])

            stats.Counts.UniqIdValue = long(row["uniq_id_value_count"])

    return stats_per_segment, scales_per_segment


def _compute_uniq_id_value(task, users, group_by, id_value_column):
    """Run the distinct-count YQL query over ``users`` and return its rows.

    :param task: task object exposing ``yt``, ``proxy``, ``pool`` and ``transaction_id``
    :param users: source table
    :param group_by: column to group by
    :param id_value_column: column whose distinct values are counted
    :return: list of result rows; an empty list when the row count attribute
        of ``users`` is zero
    """
    source_row_count = task.yt.get_attribute(users, Attributes.ROW_COUNT, 0)
    if 0 == source_row_count:
        return []

    with task.yt.TempTable(prefix="uniq_id_count_") as counts_table:
        run_query = yql.get_executer(task.proxy, task.pool or conf.yt.pool, conf.paths.audience.tmp)
        query_text = COMPUTE_UNIQ_ID_VALUE_QUERY_TEMPLATE.format(
            source=users,
            group_by=group_by,
            id_value_column=id_value_column,
            destination=counts_table,
        )
        run_query(query_text, transaction=task.transaction_id, syntax_version=1)

        # Materialize the rows before the temporary table goes away.
        result_rows = list(task.yt.read_table(counts_table))
        return result_rows


def _index_to_table_mapping(tables, proto=True):
    """Map the position of every table to the base name of its path.

    :param tables: iterable of table paths
    :param proto: return the ``TMapping`` message itself when true,
        otherwise a plain dict copy of its mapping
    """
    result = TMapping()
    for index, table in enumerate(tables):
        table_name = os.path.basename(str(table))
        result.Mapping[index] = table_name

    if proto:
        return result
    return dict(result.Mapping)


def _foreign(table):
    """Wrap ``table`` into a ``yt.TablePath`` marked as foreign."""
    foreign_path = yt.TablePath(table, foreign=True)
    return foreign_path


def _batches(tables):
    """Group ``tables`` into batches, in the given order.

    A batch is emitted as soon as its total row count reaches
    ``MinBatchSizeInRows`` or its length reaches ``MaxBatchSizeInSegments``.
    The remaining incomplete batch is emitted only if its oldest table was
    modified more than ``IncompleteBatchMinAgeSec`` seconds ago.
    """
    pending = []
    pending_rows = 0
    oldest_modification = datetime.datetime.max

    for table in tables:
        pending.append(table)
        pending_rows += table.attributes.get(Attributes.ROW_COUNT, 0)
        modified_at = _parse_yt_time(table.attributes.get(Attributes.MODIFICATION_TIME))
        oldest_modification = min(modified_at, oldest_modification)

        enough_rows = pending_rows >= conf.proto.Options.Input.MinBatchSizeInRows
        if enough_rows or len(pending) >= conf.proto.Options.Input.MaxBatchSizeInSegments:
            yield pending
            pending, pending_rows = [], 0
            oldest_modification = datetime.datetime.max

    if not pending:
        return

    age = datetime.datetime.utcnow() - oldest_modification
    min_age = datetime.timedelta(seconds=conf.proto.Options.Input.IncompleteBatchMinAgeSec)
    if age > min_age:
        yield pending


def _get_last_table(task, directory, table):
    DATE_PATTERN = re.compile(r'(\d{4}-\d{2}-\d{2})')
    dates = sorted(filter(lambda x: DATE_PATTERN.match(x), task.yt.list(directory)))

    for date in dates[::-1]:
        path = task.ypath_join(directory, date, table)
        if task.yt.exists(path) and task.yt.row_count(path) > 0:
            return path
    raise ValueError("Can't find matching table")


def _output_stats(stats, scale_factor=1.0):
    """Convert a stats proto into the dict stored as overall output stats.

    :param stats: stats proto with ``Counts`` and ``Attributes``
    :param scale_factor: multiplier applied (with rounding) to the yuid count
        and to every attribute count; the unique id value count is not scaled
    :return: dict with unique counts and sex/age/region/device histograms;
        attribute entries with a zero code are omitted, and sex/age/device
        codes are shifted down by one in the keys
    """
    def scale(value):
        return long(round(value * scale_factor))

    def histogram(entries, code_field, key_shift):
        # Keys are stringified codes (minus the shift), values are scaled counts.
        result = {}
        for entry in entries:
            code = getattr(entry, code_field)
            if code != 0:
                result[str(code - key_shift)] = scale(entry.Count)
        return result

    attributes = stats.Attributes
    output = {}
    output[Output.Stats.UNIQ_ID_VALUE] = stats.Counts.UniqIdValue
    output[Output.Stats.UNIQ_YUID] = scale(stats.Counts.UniqYuid)
    output[Output.Stats.SEX] = histogram(attributes.Gender, 'Gender', 1)
    output[Output.Stats.AGE] = histogram(attributes.Age, 'Age', 1)
    output[Output.Stats.REGION] = histogram(attributes.Region, 'Region', 0)
    output[Output.Stats.DEVICE] = histogram(attributes.Device, 'Device', 1)
    return output


def _collect_output_stats(audience_id, segment_type, this_segment_stats, this_scale_factor,
                          global_stats, similarity_scores,
                          communality_quantiles, trace=None):
    """Assemble the output attributes of a processed segment.

    :param audience_id: value stored as the segment id attribute
    :param segment_type: passed to ``_compute_communality``
    :param this_segment_stats: stats of the segment
    :param this_scale_factor: scale factor for the overall stats
    :param global_stats: stats the affinities are computed against
    :param similarity_scores: stored as related goals similarity
    :param communality_quantiles: passed to ``_compute_communality``
    :param trace: covariance trace; computed with ``_raw_trace`` when ``None``
    :return: dict of output attributes with the status set to ``DONE``
    """
    covariance_trace = _raw_trace(this_segment_stats) if trace is None else trace

    communality = _compute_communality(this_segment_stats, segment_type, communality_quantiles)
    interests_affinity = _compute_affinities(Output.CRYPTA_INTERESTS, this_segment_stats, global_stats)
    segments_affinity = _compute_affinities(Output.CRYPTA_SEGMENTS, this_segment_stats, global_stats)
    overall_stats = _output_stats(this_segment_stats, this_scale_factor)

    return {
        Output.Attributes.SEGMENT_ID: audience_id,
        Output.Attributes.CRYPTA_RELATED_GOALS_SIMILARITY: similarity_scores,
        Output.Attributes.COMMUNALITY: communality,
        Output.Attributes.COVARIANCE_TRACE: covariance_trace,
        Output.Attributes.INTERESTS_AFFINITY: interests_affinity,
        Output.Attributes.SEGMENTS_AFFINITY: segments_affinity,
        Output.Attributes.OVERALL_STATS: overall_stats,
        Output.Attributes.STATUS: Output.Statuses.DONE,
    }


def _create_table(yt, path, schema=None, recursive=True, ignore_existing=True):
    kwargs = dict(
        recursive=recursive,
        ignore_existing=ignore_existing
    )
    if schema is not None:
        kwargs.update(attributes={Attributes.SCHEMA: GeneralStorageOutput.SCHEMA})
    yt.create('table', path, **kwargs)


class BatchifyInputs(YtTask, IndependentTask):
    """Spawn one ``BatchifyInputsWithSegmentType`` task per segment type and shard."""

    def run(self, **kwargs):
        """Yield batching tasks for every segment type found among the input tables.

        Does nothing (besides logging a warning) when the input path is missing.
        """
        input_path = conf.paths.audience.input
        if not self.yt.exists(input_path):
            logger.warning('%s does not exist', input_path)
            return

        input_tables = self.yt.search(
            input_path,
            node_type='table',
            attributes=[Input.Attributes.CRYPTA_SEGMENT_INFO],
        )
        # Collect all distinct segment types before spawning anything.
        segment_types = set()
        for table in input_tables:
            segment_info = table.attributes.get(Input.Attributes.CRYPTA_SEGMENT_INFO, {})
            segment_types.add(segment_info.get("segment_type"))

        for segment_type in segment_types:
            if segment_type == "uploading":
                shard_count = conf.proto.Options.Input.UploadingShardCount
            elif segment_type == "geo":
                shard_count = conf.proto.Options.Input.GeoShardCount
            else:
                shard_count = conf.proto.Options.Input.DefaultShardCount

            for shard in range(shard_count):
                yield BatchifyInputsWithSegmentType(segment_type=segment_type, shard_count=shard_count, shard=shard)


class BatchifyInputsWithSegmentType(YtTask, IndependentTask):
    """Merge the input tables of one segment type and shard into batch tables.

    Valid inputs are grouped by batching key, split into batches, written to
    new batch tables and removed.  Invalid or empty inputs get an output table
    with error attributes instead.
    """

    # Segment type whose input tables are processed by this task.
    segment_type = Parameter()
    # Total number of shards and the index of the shard handled by this task.
    shard_count = Parameter(parse=int, default=1)
    shard = Parameter(parse=int, default=0)

    def _output_error_table(self, id, reason, status):
        """Create an output table for segment ``id`` carrying error attributes.

        The table gets the error reason, the given status and overall stats
        built from an empty stats proto.
        """
        output_path = conf.paths.audience.output
        table_path = self.ypath_join(output_path, id)

        logger.warning('Segment %s is wrong: %s', id, reason)

        if not self.yt.exists(output_path):
            self.yt.create('map_node', output_path, recursive=True)
        self.yt.create('table', table_path, ignore_existing=True)
        self.yt.set_attribute(table_path, Output.Attributes.ERROR, reason)
        self.yt.set_attribute(table_path, Output.Attributes.STATUS, status)
        self.yt.set_attribute(table_path, Output.Attributes.OVERALL_STATS, _output_stats(UserDataStats.Proto()))

    def _arrange_inputs(self, all_input_tables):
        """Classify input tables.

        :return: tuple of (tables grouped by batching key, database records
            for the accepted tables, list of ``WrongInput`` for rejected ones)
        """
        wrong_inputs = []
        segment_database_records = []
        input_tables_per_batching_key = collections.defaultdict(list)

        for table in all_input_tables:
            attributes = table.attributes
            row_count = attributes.get(Attributes.ROW_COUNT, None)
            modification_time = _parse_yt_time(attributes.get(Attributes.MODIFICATION_TIME, None))
            matching_type = attributes.get(Input.Attributes.MATCHING_TYPE, None)
            id_type = attributes.get(Input.Attributes.ID_TYPE, None)
            status = attributes.get(Input.Attributes.STATUS, None)
            device_matching_type = attributes.get(Input.Attributes.DEVICE_MATCHING_TYPE, None)
            audience_segment_id = attributes.get(Input.Attributes.SEGMENT_ID, None)
            source_id = attributes.get(Input.Attributes.CRYPTA_SEGMENT_INFO, {}).get("source_id")
            # The table name serves as the segment id.
            segment_id = os.path.basename(table)

            # Only tables with the NEW status are processed.
            if status != Output.Statuses.NEW:
                logger.info('Skipping %s with status %s', table, status)
                continue

            if not audience_segment_id:
                logger.warning('Skipping %s with no %s', table, Input.Attributes.SEGMENT_ID)
                continue

            try:
                matching_table = Matching.get(matching_type, id_type, device_matching_type)
            except InvalidMatchingParameters as e:
                wrong_inputs.append(WrongInput(id=segment_id, table=table, reason=str(e), status=Output.Statuses.FAILED))
                continue

            # Empty inputs are reported as wrong inputs with the DONE status.
            if row_count == 0:
                wrong_inputs.append(WrongInput(id=segment_id, table=table, reason='Empty', status=Output.Statuses.DONE))
                continue

            input_tables_per_batching_key[self._create_batching_key(matching_table, source_id, id_type)].append(table)

            segment_database_records.append(
                SegmentDatabaseRecord(id=segment_id, audience_id=audience_segment_id, row_count=row_count, time=modification_time)
            )

        return input_tables_per_batching_key, segment_database_records, wrong_inputs

    def _find_inputs(self, path):
        """Search ``path`` for tables whose segment type equals ``self.segment_type``."""
        return self.yt.search(
            path,
            node_type='table',
            attributes=[
                Attributes.ROW_COUNT,
                Attributes.MODIFICATION_TIME,
                Input.Attributes.SEGMENT_ID,
                Input.Attributes.SEGMENT_PRIORITY,
                Input.Attributes.STATUS,
                Input.Attributes.ID_TYPE,
                Input.Attributes.MATCHING_TYPE,
                Input.Attributes.DEVICE_MATCHING_TYPE,
                Input.Attributes.CRYPTA_RELATED_GOALS,
                Input.Attributes.CRYPTA_SEGMENT_INFO,
            ],
            object_filter=lambda obj: obj.attributes.get(Input.Attributes.CRYPTA_SEGMENT_INFO, {}).get("segment_type") == self.segment_type,
        )

    @staticmethod
    def _priority_desc_time_asc(table):
        """Sort key: higher priority first, then ascending modification time value."""
        priority = int(table.attributes.get(Input.Attributes.SEGMENT_PRIORITY, 0))
        modification_time = table.attributes.get(Attributes.MODIFICATION_TIME)
        return (-priority, modification_time)

    @staticmethod
    def _create_batching_key(matching_table, source_id, content_type):
        """Build the ``BatchingKey`` for a table.

        The target type is CRYPTA_ID for sources listed in ``CRYPTAID_SOURCEID``,
        the content type itself when it is CRYPTA_ID, PUID or YUID, and YUID otherwise.
        """
        if source_id in CRYPTAID_SOURCEID:
            target_type = Matching.CRYPTA_ID
        elif content_type in {Matching.CRYPTA_ID, Matching.PUID, Matching.YUID}:
            target_type = content_type
        else:
            target_type = Matching.YUID

        return BatchingKey(matching_table=matching_table, target_type=target_type)

    def run(self, **kwargs):
        """Batchify the inputs of this shard and record their state.

        :raises Exception: when ``shard`` is not less than ``shard_count``
        """
        input_path = conf.paths.audience.input

        if self.shard >= self.shard_count:
            raise Exception('Invalid configuration')

        if not self.yt.exists(input_path):
            logger.warning('%s does not exist', input_path)
            return

        input_tables = self._find_inputs(input_path)
        # NOTE(review): sharding relies on hash() of the path string being the same
        # in every task process - confirm that string hash randomization is off.
        shard_input_tables = [
            table for table in input_tables
            if (hash(str(table)) % self.shard_count == self.shard)
        ]

        input_tables_per_batching_key, segment_database_records, wrong_inputs = self._arrange_inputs(shard_input_tables)

        for batching_key, these_input_tables in input_tables_per_batching_key.items():

            these_input_tables_ordered = sorted(these_input_tables, key=self._priority_desc_time_asc)
            # Process just a few batches at a time (at most NumBatches per batching key).
            batches = list(_batches(these_input_tables_ordered))[:conf.proto.Options.Input.NumBatches]

            for these_input_tables_batch in batches:
                # Every batch is written to a new, uniquely named table.
                batch_path = self.ypath_join(conf.paths.audience.batches, str(uuid.uuid4()))
                index_to_table_mapping = _index_to_table_mapping(these_input_tables_batch)
                self.native_map(TBatchify, these_input_tables_batch, batch_path, index_to_table_mapping.SerializeToString())
                # Per-segment meta, keyed by the input table name.
                segments = {
                    os.path.basename(each): {
                        InputBatch.Meta.RELATED_GOALS: list(each.attributes.get(Input.Attributes.CRYPTA_RELATED_GOALS, [])),
                        InputBatch.Meta.AUDIENCE_SEGMENT_ID: each.attributes.get(Input.Attributes.SEGMENT_ID, None),
                        InputBatch.Meta.SEGMENT_INFO: each.attributes.get(Input.Attributes.CRYPTA_SEGMENT_INFO, {}),
                        InputBatch.Meta.ID_TYPE: each.attributes.get(Input.Attributes.ID_TYPE, None),
                        InputBatch.Meta.DEVICE_MATCHING_TYPE: each.attributes.get(Input.Attributes.DEVICE_MATCHING_TYPE, None),
                    } for each in these_input_tables_batch
                }

                self.yt.set_attribute(batch_path, InputBatch.Meta.META, {
                    InputBatch.Meta.MATCHING_TABLE: batching_key.matching_table,
                    InputBatch.Meta.SEGMENTS: segments,
                    InputBatch.Meta.TARGET_TYPE: batching_key.target_type,
                })
                # The batch priority is the sum of the priorities of its segments.
                priority = sum(int(table.attributes.get(Input.Attributes.SEGMENT_PRIORITY, 0)) for table in these_input_tables_batch)
                self.yt.set_attribute(batch_path, InputBatch.Attributes.PRIORITY, priority)

                # The inputs are removed once they have been written into the batch.
                for each in these_input_tables_batch:
                    self.yt.remove(each)

        with RegularSegmentStateStorage(self._init_yt(), experiment_mode=True).batched_inserter as state_inserter:
            for segment_id, table, reason, status in wrong_inputs:
                self._output_error_table(segment_id, reason, status)
                self.yt.remove(table)
                # The attributes below come from the already fetched table object.
                row_count = table.attributes.get(Attributes.ROW_COUNT, 0L)
                modification_time = _parse_yt_time(table.attributes.get(Attributes.MODIFICATION_TIME, None))
                audience_segment_id = table.attributes.get(Input.Attributes.SEGMENT_ID, None)
                write_segment_on_input(segment_id=audience_segment_id, row_count=row_count, time=modification_time, state_inserter=state_inserter)
                write_segment_on_output(segment_id=audience_segment_id, row_count=0L, time=modification_time, state_inserter=state_inserter)

            # It is safer to write all the records at the end because SQL transaction time is limited.
            for segment in segment_database_records:
                write_segment_on_input(segment_id=segment.audience_id, row_count=segment.row_count, time=segment.time, state_inserter=state_inserter)


class EnqueueBatches(YtTask, IndependentTask):
    """Spawn a ``ProcessBatch`` task for every unlocked batch table."""

    def run(self, **kwargs):
        """Yield processing tasks for table batches; warn if the batches path is missing."""
        batches_path = conf.paths.audience.batches

        if self.yt.exists(batches_path):
            prioritized_batches = _get_list_batches_with_priorities(self, batches_path, table_only=True)
            for priority, batch in prioritized_batches:
                yield ProcessBatch(batch=batch, priority=priority)
        else:
            logger.warning('Path %s does not exist', batches_path)


class EnqueueHeavyBatch(YtTask, IndependentTask):
    """Spawn ``ProcessBatch`` tasks for the top directory batches."""

    def run(self, **kwargs):
        """Yield processing tasks for at most four non-table batches.

        Batches marked as being in processing are skipped; the remaining ones
        are taken in descending order of ``(priority, batch)``.
        """
        batches_path = conf.paths.audience.batches

        if not self.yt.exists(batches_path):
            logger.warning('Path %s does not exist', batches_path)
            return

        def is_in_processing(batch):
            return self.yt.get_attribute(batch, InputBatch.Attributes.IN_PROCESSING, False)

        candidates = _get_list_batches_with_priorities(
            self,
            batches_path,
            dir_only=True,
            skip_filter=is_in_processing,
        )
        top_batches = sorted(candidates, reverse=True)[:4]

        for priority, batch in top_batches:
            yield ProcessBatch(batch=batch, priority=priority)


class SplitMapper(object):
    """Mapper that routes records to the output table of their segment."""

    def __init__(self, segment_to_index_mapping, ts):
        """Store the segment id -> output table index mapping and the timestamp."""
        self.ts = ts
        self.segment_to_index_mapping = segment_to_index_mapping

    def __call__(self, record):
        """Emit ``record`` into the table of its segment.

        Records whose yuid cannot be converted to ``YsonUint64`` (``ValueError``)
        are dropped.  The segment id and the external id are removed from the
        record and the send field is set to 0 before it is emitted.
        """
        try:
            yson.YsonUint64(record[Output.Fields.YUID])
        except ValueError:
            # Not a valid unsigned integer: skip the record.
            return

        segment_id = record.pop(InputBatch.Fields.SEGMENT_ID)
        record.pop(InputBatch.Fields.EXTERNAL_ID, None)
        record[Output.Fields.SEND] = 0

        output_index = self.segment_to_index_mapping[segment_id]
        yield yt.create_table_switch(output_index)
        yield record


class ProcessBatch(YQLTaskV1, YtTask, IndependentTask):
    """Task that processes a single input batch; see ``run`` for the sequence of steps."""

    # Path of the batch node to process (checked with yt.exists, locked and finally removed in ``run``).
    batch = Parameter()
    # NOTE(review): presumably the run-time limit enforced by the workflow base class — confirm.
    timeout = datetime.timedelta(hours=12)

    @staticmethod
    def _segments_schema():
        """Return the YT schema of the matched-results table: four string columns."""
        string_columns = (
            Output.Fields.YUID,
            Output.Fields.ID_VALUE,
            InputBatch.Fields.SEGMENT_ID,
            InputBatch.Fields.EXTERNAL_ID,
        )
        return schema_utils.yt_schema_from_dict({column: "string" for column in string_columns})

    @property
    def query(self):
        return

    @cached_property
    def stats_storage(self):
        """StatsStorage built on the client from ``self._init_yt()``, with ``prepare_table()`` already called."""
        storage = StatsStorage(self._init_yt())
        storage.prepare_table()
        return storage

    def get_segment_type(self, segment_id):
        """Return the segment type label; here it is 'audience' regardless of ``segment_id``."""
        segment_type = 'audience'
        return segment_type

    def extract_meta_from_table(self):
        """Build the batch meta dict from the rows of the batch's meta table.

        Each row contributes one entry to the segments dict, keyed by
        ``str(row['segment_id'])``. The matching table is resolved via
        ``Matching.get`` from the rows in order, until a non-None value is
        obtained.

        Returns a dict with the MATCHING_TABLE and SEGMENTS keys.
        """
        segments = {}
        matching_table = None

        for row in self.yt.read_table(InputBatch.meta(self.batch)):
            segment_key = str(row['segment_id'])
            row_meta = row['meta']
            id_type = row_meta.get(Input.Attributes.ID_TYPE, None)
            device_matching_type = row_meta.get(Input.Attributes.DEVICE_MATCHING_TYPE, None)

            segments[segment_key] = {
                InputBatch.Meta.RELATED_GOALS: list(row_meta.get(Input.Attributes.CRYPTA_RELATED_GOALS, [])),
                InputBatch.Meta.AUDIENCE_SEGMENT_ID: row_meta.get(Input.Attributes.SEGMENT_ID, None),
                InputBatch.Meta.SEGMENT_INFO: row_meta.get(Input.Attributes.CRYPTA_SEGMENT_INFO, {}),
                InputBatch.Meta.ID_TYPE: id_type,
                InputBatch.Meta.DEVICE_MATCHING_TYPE: device_matching_type,
            }

            if matching_table is None:
                matching_table = Matching.get(
                    row_meta.get(Input.Attributes.MATCHING_TYPE, None),
                    id_type,
                    device_matching_type,
                )

        return {
            InputBatch.Meta.MATCHING_TABLE: matching_table,
            InputBatch.Meta.SEGMENTS: segments,
        }

    @staticmethod
    def extract_from_segment(segment, field):
        """Return ``field`` from the segment's SEGMENT_INFO dict, or None when either is absent."""
        segment_info = segment.get(InputBatch.Meta.SEGMENT_INFO, {})
        return segment_info.get(field)

    def try_to_add_ids_to_general_storage(self):
        """Extract ids of phone / email / IDFA-GAID / CRM segments into the general storage queues.

        A TGeneralStorageState is filled with one entry per segment whose
        'content_type' is one of those id types. If at least one segment
        qualifies, two tables named after the batch are created (device queue
        and email/phone queue) and TExtractGeneralStorageOutput is run over the
        batch users with the serialized state. Otherwise nothing is done.
        """
        state = TGeneralStorageState()
        state.Timestamp = self.ts

        # Id types that get TableIndex 1; the remaining supported type gets 0.
        second_table_types = {Matching.PHONE, Matching.EMAIL, Matching.CRM}
        supported_types = {Matching.PHONE, Matching.EMAIL, Matching.IDFA_GAID, Matching.CRM}

        for segment_id, segment_meta in self.meta[InputBatch.Meta.SEGMENTS].items():
            id_type = self.extract_from_segment(segment_meta, 'content_type')
            if id_type not in supported_types:
                continue

            table_index = 1 if id_type in second_table_types else 0

            segment = state.Segments[str(segment_id)]
            segment.SegmentID = segment_meta.get(InputBatch.Meta.AUDIENCE_SEGMENT_ID)
            segment.SegmentType = self.extract_from_segment(segment_meta, 'segment_type')
            segment.IdType = id_type
            segment.TableIndex = table_index

        if len(state.Segments) == 0:
            return

        batch_name = os.path.basename(self.batch)
        device_storage = os.path.join(conf.paths.storage.device_queue, batch_name)
        email_phone_storage = os.path.join(conf.paths.storage.email_phone_queue, batch_name)

        for storage_table in (device_storage, email_phone_storage):
            _create_table(self.yt, storage_table, GeneralStorageOutput.SCHEMA)

        # NOTE(review): TableIndex presumably selects the output by position in this list — confirm.
        self.native_map(
            TExtractGeneralStorageOutput,
            self.batch_users,
            [device_storage, email_phone_storage],
            state=state.SerializeToString()
        )

    def run(self, **kwargs):
        """Process the batch end to end and remove it when done.

        Steps, in order: take an exclusive lock on the batch, push ids to the
        general storage queues, (re)create the per-segment output tables, fill
        a temporary table with matched results, compute stats, write the
        storage meta and storage tables, split the results into the output
        tables, write output stats / goal similarities, write the state to the
        database and finally remove the batch node.

        If the batch node no longer exists, a warning is logged and nothing
        else is done.
        """
        # Check existence BEFORE locking: locking a missing node fails, so a check
        # placed after the lock could never report a batch that has already been
        # processed and removed.
        if not self.yt.exists(self.batch):
            logger.warning('Batch %s disappeared', self.batch)
            return

        self.yt.lock(self.batch, mode="exclusive")

        self.try_to_add_ids_to_general_storage()
        self.create_output()

        with self.yt.TempTable(prefix='matched_', attributes={"schema": self._segments_schema()}) as results:
            self.prepare_matched(results)
            self.compute_stats(results)
            self.output_storage_meta()
            self.output_storage(results)
            self.split_to_output(results)

        self.write_output_stats_and_goal_similarity()
        self.write_to_db()
        self.yt.remove(self.batch, recursive=True)

    @cached_property
    def input_is_table(self):
        """True when the batch node's 'type' attribute equals 'table'."""
        node_type = self.yt.get_attribute(self.batch, 'type')
        return node_type == 'table'

    @cached_property
    def batch_users(self):
        """Path of the table holding the batch users: the batch itself for a table batch, else ``InputBatch.users(batch)``."""
        if self.input_is_table:
            return self.batch
        return InputBatch.users(self.batch)

    @cached_property
    def meta(self):
        """Batch meta: read from the batch's META attribute (None if absent) for a table batch, else built from the meta table."""
        if not self.input_is_table:
            return self.extract_meta_from_table()
        return self.yt.get_attribute(self.batch, InputBatch.Meta.META, None)

    @cached_property
    def ts(self):
        """Value of ``time_utils.get_current_time()`` taken on first access and reused afterwards."""
        current_time = time_utils.get_current_time()
        return current_time

    @cached_property
    def segments_ids(self):
        """Keys of the SEGMENTS dict of the batch meta."""
        segments = self.meta[InputBatch.Meta.SEGMENTS]
        return segments.keys()

    @cached_property
    def segments_mapping(self):
        """Dict from the batch segment key to the stringified AUDIENCE_SEGMENT_ID of that segment."""
        mapping = {}
        for segment_key, segment_meta in self.meta[InputBatch.Meta.SEGMENTS].items():
            mapping[segment_key] = str(segment_meta[InputBatch.Meta.AUDIENCE_SEGMENT_ID])
        return mapping

    @cached_property
    def output_tables(self):
        """Output table paths, one per segment id, in the same order as ``segments_ids``."""
        output_root = conf.paths.audience.output
        return [self.ypath_join(output_root, segment_id) for segment_id in self.segments_ids]

    @cached_property
    def global_stats(self):
        """Prepared stats built from the first row of the crypta_id userdata stats table (must be truthy)."""
        stats_table = conf.paths.lab.data.crypta_id.userdata_stats
        first_row = next(self.yt.read_table(stats_table))
        prepared = _prepare_stats(first_row)
        assert prepared
        return prepared

    @cached_property
    def id_type(self):
        """Identifier type of the batch output.

        Taken from the meta's TARGET_TYPE when present. Otherwise CRYPTA_ID if
        every segment's source id is in CRYPTAID_SOURCEID, else YANDEXUID.
        """
        meta = self.meta
        if InputBatch.Meta.TARGET_TYPE in meta:
            return Matching.TARGET_TYPE_TO_ID_TYPE[meta[InputBatch.Meta.TARGET_TYPE]]

        segments = meta[InputBatch.Meta.SEGMENTS]

        def has_crypta_id_source(segment_id):
            segment_info = segments[segment_id][Input.Attributes.CRYPTA_SEGMENT_INFO]
            return segment_info[Input.Attributes.SOURCE_ID] in CRYPTAID_SOURCEID

        if all(has_crypta_id_source(segment_id) for segment_id in self.segments_ids):
            return id_type_pb2.EIdType.CRYPTA_ID
        return id_type_pb2.EIdType.YANDEXUID

    @cached_property
    def storage_path(self):
        """TablePath of the storage table for this batch, with the schema derived from the proto of its id type."""
        storage_params_by_id_type = {
            id_type_pb2.EIdType.CRYPTA_ID: (conf.paths.storage.crypta_id_queue, tables_pb2.TCryptaIdStorageOutput),
            id_type_pb2.EIdType.YANDEXUID: (conf.paths.storage.for_full, tables_pb2.TYandexuidStorageOutput),
            id_type_pb2.EIdType.PUID: (conf.paths.storage.puid_queue, tables_pb2.TPuidStorageOutput),
        }
        root_path, proto = storage_params_by_id_type[self.id_type]

        batch_name = os.path.basename(self.batch)
        table_schema = schema_utils.get_schema_from_proto(proto)
        return self.yt.TablePath(yt.ypath_join(root_path, batch_name), schema=table_schema)

    @cached_property
    def storage_id_field(self):
        """Name of the identifier field passed to the storage query for this batch's id type."""
        return {
            id_type_pb2.EIdType.CRYPTA_ID: "CryptaID",
            id_type_pb2.EIdType.YANDEXUID: "yandexuid",
            id_type_pb2.EIdType.PUID: "Puid",
        }[self.id_type]

    def create_output(self):
        """Recreate every output table from scratch with Output.SCHEMA and no compression."""
        for table in self.output_tables:
            # force=True: no error if the table does not exist yet.
            self.yt.remove(table, force=True)
            table_path = self.yt.TablePath(table, schema=Output.SCHEMA, compression_codec='none')
            self.yt.create('table', table_path, recursive=True)

    def prepare_matched(self, results):
        """Fill ``results`` from the batch users using the strategy selected by the batch meta.

        Nothing is written when the users table is empty. Otherwise: match
        against the matching table if one is set, filter crypta ids if the id
        type is CRYPTA_ID, and in all other cases just add the id value.
        """
        matching_table_path = self.meta[InputBatch.Meta.MATCHING_TABLE]
        logger.info('Matching with %s', matching_table_path)

        row_count = self.yt.get_attribute(self.batch_users, "row_count")
        if row_count == 0:
            return

        if matching_table_path:
            self.match(results, matching_table_path)
            return

        if self.id_type == id_type_pb2.EIdType.CRYPTA_ID:
            self.filter_crypta_ids(results)
            return

        self.add_id_value(results)

    def match(self, results, matching_table_path):
        """Render and execute /query/match.yql over the batch users, writing to ``results``."""
        template_vars = {
            "input_table": self.batch_users,
            "output_table": results,
            "matching_table": matching_table_path,
            "segments_mapping": self.segments_mapping,
        }
        query = templater.render_resource('/query/match.yql', strict=True, vars=template_vars)
        self.yql_client.execute(query)

    def filter_crypta_ids(self, results):
        """Render and execute /query/filter_crypta_ids.yql over the batch users, writing to ``results``."""
        template_vars = {
            "input_table": self.batch_users,
            "output_table": results,
            "crypta_id_filter": conf.paths.audience.matching.cryptaids,
            "segments_mapping": self.segments_mapping,
        }
        query = templater.render_resource('/query/filter_crypta_ids.yql', strict=True, vars=template_vars)
        self.yql_client.execute(query)

    def add_id_value(self, results):
        """Render and execute /query/add_id_value.yql over the batch users, writing to ``results``."""
        template_vars = {
            "input_table": self.batch_users,
            "output_table": results,
            "segments_mapping": self.segments_mapping,
        }
        query = templater.render_resource('/query/add_id_value.yql', strict=True, vars=template_vars)
        self.yql_client.execute(query)

    def compute_stats(self, results):
        """Compute per-segment and per-goal stats and store them on the task.

        Sets ``self.stats_per_segment`` and ``self.scales_per_segment`` from
        ``results`` grouped by segment id, and ``self.stats_per_goal`` for the
        related goals of all segments in the batch.
        """
        self.stats_per_segment, self.scales_per_segment = _compute_stats_per_segment(
            self,
            results,
            group_by=InputBatch.Fields.SEGMENT_ID,
            id_type=self.id_type,
        )

        # Flatten in a single pass; sum(lists, []) re-copies the accumulated list on every step.
        all_related_goals = list(itertools.chain.from_iterable(
            _related_goals(self.meta, segment_id) for segment_id in self.segments_ids
        ))
        self.stats_per_goal = _compute_stats_per_goal(self, all_related_goals)

    def output_storage_meta(self):
        """Write the SEGMENT_INFO dict of every segment as a row of the queue_segments_info table for this batch."""
        batch_name = os.path.basename(self.batch)
        storage_meta_output = os.path.join(conf.paths.storage.queue_segments_info, batch_name)
        self.yt.create('table', storage_meta_output, recursive=True)

        segments = self.meta.get(InputBatch.Meta.SEGMENTS, {})
        rows = (segment_meta.get(InputBatch.Meta.SEGMENT_INFO, {}) for segment_meta in segments.values())
        self.yt.write_table(storage_meta_output, rows)

    def output_storage(self, results):
        """Create the storage table and fill it from ``results`` with /query/output_storage.yql.

        Geo segments that are neither of the 'condition' geo type nor flagged
        ("1") with any of the listed retargeting options are passed to the
        query as ``excluded_segments``.
        """
        self.yt.create('table', self.storage_path, recursive=True)

        retargeting_flags = (
            'adfox_retargeting',
            'banana_retargeting',
            'display_retargeting',
            'geoadv_retargeting',
            'zen_retargeting',
        )

        excluded_segments = set()
        for segment_meta in self.meta.get(InputBatch.Meta.SEGMENTS, {}).values():
            segment_info = segment_meta.get(InputBatch.Meta.SEGMENT_INFO, {})
            if segment_info.get('segment_type') != 'geo':
                continue
            if segment_info.get('geo_segment_type') == 'condition':
                continue
            if any(segment_info.get(flag) == "1" for flag in retargeting_flags):
                continue
            excluded_segments.add(segment_meta.get(InputBatch.Meta.AUDIENCE_SEGMENT_ID))

        executer = yql.get_executer(self.proxy, self.pool or conf.yt.pool, conf.paths.audience.tmp)

        template_vars = {
            "input_table": results,
            "output_table": self.storage_path,
            "timestamp": self.ts,
            "excluded_segments": excluded_segments,
            "field": self.storage_id_field,
        }
        query = templater.render_resource('/query/output_storage.yql', strict=True, vars=template_vars)

        executer(query, transaction=self.transaction_id, syntax_version=1)

    def split_to_output(self, results):
        """Sort ``results`` by segment id in place, then map each record into its segment's output table."""
        self.sort(results, results, sort_by=InputBatch.Fields.SEGMENT_ID)

        # output_tables is built by iterating segments_ids, so position i in both refers to the same segment.
        index_by_segment = {segment_id: index for index, segment_id in enumerate(self.segments_ids)}
        mapper = SplitMapper(segment_to_index_mapping=index_by_segment, ts=self.ts)

        self.map(mapper, results, self.output_tables)

    def write_to_db(self):
        """Insert the output state, properties and serialized stats of every segment via batched inserters.

        Segments without computed stats are recorded with empty stats.
        """
        output_ts = time_utils.get_current_moscow_datetime()
        storage_client = self._init_yt()

        state_storage = RegularSegmentStateStorage(storage_client, experiment_mode=True)
        properties_storage = SegmentPropertiesStorage(storage_client, experiment_mode=True)
        stats_storage = StatsStorage(storage_client)

        with state_storage.batched_inserter as state_inserter, \
                properties_storage.batched_inserter as properties_inserter, \
                stats_storage.batched_inserter as stats_inserter:
            for segment_id, output_table in zip(self.segments_ids, self.output_tables):
                segment_meta = self.meta[InputBatch.Meta.SEGMENTS][segment_id]
                audience_id = segment_meta[InputBatch.Meta.AUDIENCE_SEGMENT_ID]
                row_count = self.yt.get_attribute(output_table, Attributes.ROW_COUNT)
                segment_type = self.get_segment_type(segment_id)
                segment_stats = self.stats_per_segment.get(segment_id, _get_empty_stats())

                write_segment_on_output(
                    segment_id=audience_id,
                    row_count=row_count,
                    time=output_ts,
                    state_inserter=state_inserter,
                )

                communality = _raw_trace(segment_stats)
                write_segment_properties(
                    segment_id=audience_id,
                    communality=communality,
                    segment_type=segment_type,
                    inserter=properties_inserter,
                    time=output_ts,
                )

                serialized_stats = segment_stats.SerializeToString()
                stats_inserter.insert_row(audience_id, segment_type, serialized_stats, self.ts)

    def write_output_stats_and_goal_similarity(self):
        """Store segment-to-goal similarities and set output stats as attributes on each output table.

        For every segment: the similarity between the segment stats and the
        stats of each related goal (falling back to the global stats when a
        goal has none) is inserted into the segments/goals relation storage.
        The collected output stats are then written as attributes of the
        segment's output table. A segment with no computed stats gets empty
        stats and an ERROR attribute.
        """
        storage_client = self._init_yt()
        with SegmentsGoalsRelationStorage(storage_client, experiment_mode=True).batched_inserter as relation_inserter:
            communality_quantiles = SegmentPropertiesStorage(storage_client, experiment_mode=True).get_communality_quantiles()

            for segment_id, output_table in zip(self.segments_ids, self.output_tables):
                segment_type = self.get_segment_type(segment_id)
                audience_id = self.meta[InputBatch.Meta.SEGMENTS][segment_id][InputBatch.Meta.AUDIENCE_SEGMENT_ID]

                this_segment_stats = self.stats_per_segment.get(segment_id, _get_empty_stats())
                raw_similarities = {
                    str(goal): _similarity(this_segment_stats, self.stats_per_goal.get(goal, self.global_stats))
                    for goal in _related_goals(self.meta, segment_id)
                }
                # .items() instead of the Python 2-only .iteritems(): works on both major versions.
                for goal_id, similarity in raw_similarities.items():
                    write_segment_goal_relation(segment_id=audience_id, goal_id=goal_id, cosine=similarity, inserter=relation_inserter)
                similarity_scores = {goal: _cosine_quantile_score(similarity) for (goal, similarity) in raw_similarities.items()}
                output_stats = _collect_output_stats(audience_id, segment_type,
                                                     this_segment_stats,
                                                     self.scales_per_segment.get(segment_id, 1.),
                                                     self.global_stats,
                                                     similarity_scores,
                                                     communality_quantiles)

                if segment_id not in self.stats_per_segment:
                    logger.warning('Segment %s has no stats', segment_id)
                    output_stats[Output.Attributes.ERROR] = "Failed to compute stats"

                for key, value in output_stats.items():
                    self.yt.set_attribute(output_table, key, value)

                logger.info("Output %s", str(output_stats))
