#!/usr/bin/env python
# -*- coding: utf-8 -*-

from functools import partial
from os.path import join

import luigi

from crypta.profile.utils.config import config
from crypta.profile.utils.segment_utils.processors import DayProcessor, LogProcessor
from crypta.profile.utils.segment_utils.builders import RegularSegmentBuilder
from crypta.profile.utils.luigi_utils import ExternalInput, YtDailyRewritableTarget, BaseYtTask
from crypta.profile.runners.segments.lib.coded_segments.data.visited_organizations import organizations_categories_to_segment_ids


def get_expanded_category_ids(record, category_parent_ids):
    """Map an organization row to its rubric ids, their ancestors and main names.

    Used as a YT mapper: yields exactly one output row per input row.

    :param record: organization row with ``permalink``, ``rubrics`` (list of
        dicts with ``rubric_id``) and ``names`` (list of dicts with ``type`` and
        ``value`` = {``locale``, ``value``}). A missing or null ``rubrics`` /
        ``names`` is treated as an empty list.
    :param category_parent_ids: dict mapping a rubric id to its parent rubric
        id; rubrics that are not keys of the dict are treated as roots.
    :yields: dict with ``permalink``, ``categories`` (the row's own rubric ids),
        ``categories_with_parents`` (those ids plus every ancestor) and, when
        present among the ``main`` names, ``name_ru`` / ``name_en`` (the last
        matching name wins).
    """
    output_record = {
        'permalink': record['permalink'],
    }

    category_ids = set()
    category_ids_with_parents = set()
    for category_info in record.get('rubrics') or ():
        rubric_id = category_info['rubric_id']
        category_ids.add(rubric_id)
        category_ids_with_parents.add(rubric_id)

        # Walk up towards the root. Stop as soon as we reach an id that is
        # already collected: its ancestors were added together with it, and
        # this also guarantees termination if the rubric tree has a cycle.
        current_id = rubric_id
        while current_id in category_parent_ids:
            current_id = category_parent_ids[current_id]
            if current_id in category_ids_with_parents:
                break
            category_ids_with_parents.add(current_id)

    output_record['categories'] = list(category_ids)
    output_record['categories_with_parents'] = list(category_ids_with_parents)

    for name_info in record.get('names') or ():
        if name_info['type'] != 'main':
            continue
        locale = name_info['value']['locale']
        if locale in ('ru', 'en'):
            output_record['name_' + locale] = name_info['value']['value']

    yield output_record


class OrganizationCategoryDictionary(BaseYtTask):
    """Build the organization -> rubric dictionary table for ``date``.

    Reads child -> parent rubric links from ``config.SPRAV_RUBRICS``, maps every
    organization of ``config.SPRAV_COMPANIES`` through
    ``get_expanded_category_ids`` and writes the result to
    ``config.ORGANIZATION_CATEGORIES``. The output is sorted by ``permalink``
    and tagged with a ``generate_date`` attribute equal to ``date``.
    """
    date = luigi.Parameter()
    task_group = 'coded_segments'

    def requires(self):
        return {
            'category_tree': ExternalInput(config.SPRAV_RUBRICS),
            'organization_categories': ExternalInput(config.SPRAV_COMPANIES),
        }

    def run(self):
        with self.yt.Transaction():
            inputs = self.input()

            # Rubric id -> parent rubric id; root rubrics (null parent) are left out,
            # so "not a key" means "no parent" for get_expanded_category_ids.
            category_parent_ids = {}
            for row in self.yt.read_table(
                    self.yt.TablePath(inputs['category_tree'].table, columns=['id', 'parent_rubric_id'])
            ):
                if row['parent_rubric_id'] is not None:
                    category_parent_ids[row['id']] = row['parent_rubric_id']

            output_table = self.output().table
            self.yt.create_empty_table(
                output_table,
                schema={
                    'permalink': 'int64',
                    'categories': 'any',
                    'categories_with_parents': 'any',
                    'name_ru': 'string',
                    'name_en': 'string',
                },
                force=True,
            )
            # The parent map is bound into the mapper, so it travels with the job.
            self.yt.run_map(
                partial(get_expanded_category_ids, category_parent_ids=category_parent_ids),
                self.yt.TablePath(
                    inputs['organization_categories'].table,
                    columns=['permalink', 'names', 'rubrics'],
                ),
                output_table,
            )
            self.yt.run_sort(output_table, sort_by='permalink')
            self.yt.set_attribute(
                output_table,
                'generate_date',
                self.date
            )

    def output(self):
        # The original passed this path through os.path.join with a single
        # argument, which returns it unchanged; the path is now used directly.
        return YtDailyRewritableTarget(
            config.ORGANIZATION_CATEGORIES,
            date=self.date,
        )


# YQL run once per day by OrganizationVisitsDayProcessor. It groups the day's
# visits table by mmetric_devid (exposed as `id`, with id_type 'mm_device_id')
# and collects the distinct permalinks seen for each device into `permalinks`.
# {input_table} and {output_table} are filled in with str.format; the output
# table is overwritten (WITH TRUNCATE).
day_processor_query = """
INSERT INTO `{output_table}` WITH TRUNCATE
SELECT id, 'mm_device_id' AS id_type, AGGREGATE_LIST(DISTINCT permalink) AS permalinks
FROM `{input_table}`
GROUP BY mmetric_devid AS id
"""


class OrganizationVisitsDayProcessor(DayProcessor):
    """Turn one day's organization-visits table into per-device permalink lists.

    The input is ``<config.ORG_VISITS_DIRECTORY>/<date>``. The transformation is
    the ``day_processor_query`` YQL, run inside the processor's transaction.
    """

    def requires(self):
        daily_visits_path = join(config.ORG_VISITS_DIRECTORY, self.date)
        return ExternalInput(daily_visits_path)

    def process_day(self, inputs, output_path):
        source_table = inputs.table
        rendered_query = day_processor_query.format(
            input_table=source_table,
            output_table=output_path,
        )
        self.yql.query(rendered_query, transaction=self.transaction)


# YQL template used by OrganizationVisitors.build_segment. Stages:
#   $permalinks_factory - custom UDAF that unions the per-day `permalinks`
#                         lists into one set (null lists count as empty).
#   $merged             - per (id, id_type), the union of permalinks over all
#                         rows of {merged_visits_table}, flattened to one row
#                         per permalink.
#   $categories         - inner join with {organization_categories} on
#                         permalink, keeping `categories_with_parents`.
#   $script             - embedded Python 2 UDF. {categories_to_segments} is
#                         replaced with str() of the category -> segment-id
#                         dict. get_categories returns the category ids (as
#                         strings) that are keys of that dict, or None.
#   final INSERT        - one row per (id, id_type, segment_name),
#                         deduplicated by GROUP BY, written WITH TRUNCATE.
# Literal YQL braces are doubled ({{ }}) because the template goes through
# str.format.
# NOTE(review): $get_categories is evaluated twice per row (SELECT and WHERE);
# filtering on `segment_names IS NOT NULL` in an outer query would avoid that.
get_segments_query_template = """
$permalinks_factory = AggregationFactory(
    "UDAF",
    ($item, $parent) -> {{ return ToSet($item) ?? SetCreate(Int64) }},
    ($state, $item, $parent) -> {{ return SetUnion($state, ToSet($item) ?? SetCreate(Int64)) }},
    ($state1, $state2) -> {{ return SetUnion($state1, $state2) }},
    ($state) -> {{ return DictKeys($state) }},
);

$merged = (
    SELECT id, id_type, permalink
    FROM (
        SELECT id, id_type, AGGREGATE_BY(permalinks, $permalinks_factory) AS permalinks
        FROM (
            SELECT id, id_type, Yson::ConvertToInt64List(permalinks) AS permalinks
            FROM `{merged_visits_table}`
        )
        GROUP BY id, id_type
    )
    FLATTEN LIST BY permalinks AS permalink
);

$categories = (
    SELECT id, id_type, organization_categories.categories_with_parents AS categories
    FROM $merged AS merged
    INNER JOIN `{organization_categories}` AS organization_categories
    USING (permalink)
);

$script = @@
organizations_categories_to_segment_ids = {categories_to_segments}
interesting_org_categories = set(organizations_categories_to_segment_ids.keys())

def get_categories(categories):
    categories_intersection = interesting_org_categories.intersection(categories)
    if categories_intersection:
        return [str(category) for category in categories_intersection]
    else:
        return None
@@;

$get_categories = Python2::get_categories(Callable<(List<Int64>?)->List<String>?>, $script);

INSERT INTO `{output_table}` WITH TRUNCATE
SELECT id, id_type, segment_name
FROM (
    SELECT id, id_type, segment_name
    FROM (
        SELECT
            id,
            id_type,
            $get_categories(Yson::ConvertToInt64List(categories)) AS segment_names
        FROM $categories
        WHERE $get_categories(Yson::ConvertToInt64List(categories)) IS NOT NULL
    )
    FLATTEN LIST BY segment_names AS segment_name
)
GROUP BY id, id_type, segment_name;
"""


class OrganizationVisitors(RegularSegmentBuilder):
    """Segments of ids that visited organizations in selected rubrics.

    Combines ``number_of_days`` daily outputs of OrganizationVisitsDayProcessor
    (via LogProcessor) with the OrganizationCategoryDictionary table. The
    ``get_segments_query_template`` YQL then emits one row per
    (id, id_type, segment_name), where segment_name is a rubric id as a string.
    """
    # Segment name (rubric id rendered as a string, matching what the YQL UDF
    # get_categories returns) -> segment id.
    # .items() instead of the Python-2-only .iteritems(): same result on
    # Python 2 and does not raise AttributeError on Python 3.
    name_segment_dict = {
        str(category): segment_id for category, segment_id in organizations_categories_to_segment_ids.items()
    }

    # NOTE(review): meaning of 547 is defined by RegularSegmentBuilder - confirm there.
    keyword = 547
    # Passed to LogProcessor as the number of daily visit tables to combine.
    number_of_days = 35

    def requires(self):
        return {
            'OrgDictionary': OrganizationCategoryDictionary(self.date),
            'OrgVisits': LogProcessor(
                OrganizationVisitsDayProcessor,
                self.date,
                self.number_of_days,
            ),
        }

    def build_segment(self, inputs, output_path):
        self.yql.query(
            query_string=get_segments_query_template.format(
                merged_visits_table=inputs['OrgVisits'].table,
                organization_categories=inputs['OrgDictionary'].table,
                # Rendered with str() straight into the embedded Python UDF source.
                categories_to_segments=organizations_categories_to_segment_ids,
                output_table=output_path
            ),
            transaction=self.transaction,
        )
