#!/usr/bin/env python
# -*- coding: utf-8 -*-

from crypta.profile.utils import luigi_utils
from crypta.profile.utils.config import config
from crypta.profile.utils.segment_utils.builders import (
    LaLParams,
    RegularSegmentBuilder,
)

# Direct client industry name -> internal segment name.
# The query compares these keys against String::Strip(industry), so the keys
# must not carry leading/trailing whitespace.
industry_to_segment_name = dict([
    (u'Медицина', u'medicine'),
    (u'Путешествия и туризм', u'tourism'),
    (u'Кафе, рестораны, доставка еды', u'food'),
])

# YQL script executed by DirectClientsByIndustry.build_segment via str.format.
# Placeholders:
#   industry_to_segment_name     comma-separated AsTuple("<industry>", "<segment>")
#                                pairs forming the body of AsDict(...)
#   direct_industries            table with client_id and
#                                curr_counterparty_industry.industry
#   direct_users                 table joined on ClientID, providing puids (uid)
#   indevice_yandexuid_matching  table joined on (id, id_type) to map puids
#                                to yandexuids
#   output_table                 receives the 'food' segment rows
#   sample_table                 receives the 'medicine' and 'tourism' rows
# Flow: client_id -> stripped industry -> segment_name; client -> puid via
# ClientID; puid -> yandexuid via the matching table.
# Output columns: id (yandexuid cast to String), segment_name, id_type.
# Because the template goes through str.format, any literal '{' or '}' added
# to the YQL text must be doubled ('{{' / '}}').
# NOTE(review): $yandexuids applies no GROUP BY/DISTINCT, so duplicate
# (id, segment_name) rows are possible when several puids match the same
# yandexuid — confirm whether downstream consumers deduplicate.
segment_query = u"""
$industry_to_segment_name = AsDict(
{industry_to_segment_name}
);

$client_id_with_industry = (
    SELECT
        client_id AS ClientID,
        CASE
            WHEN DictContains($industry_to_segment_name, String::Strip(industry)) THEN $industry_to_segment_name[String::Strip(industry)]
            ELSE NULL
        END AS segment_name
    FROM `{direct_industries}`
    GROUP BY client_id, curr_counterparty_industry.industry AS industry
);

$puids_with_industries = (
    SELECT
        CAST(puids.uid AS String) AS id,
        'puid' AS id_type,
        industries.segment_name AS segment_name
    FROM `{direct_users}` AS puids
    INNER JOIN $client_id_with_industry AS industries
    USING(ClientID)
    WHERE industries.segment_name IS NOT NULL
    GROUP BY puids.uid, industries.segment_name
);

$yandexuids = (
    SELECT
        CAST(yandexuids.yandexuid AS String) AS id,
        puids.segment_name AS segment_name,
        'yandexuid' AS id_type
    FROM $puids_with_industries AS puids
    INNER JOIN `{indevice_yandexuid_matching}` AS yandexuids
    USING (id, id_type)
);

INSERT INTO `{output_table}` WITH TRUNCATE
SELECT *
FROM $yandexuids
WHERE segment_name == 'food';

INSERT INTO `{sample_table}` WITH TRUNCATE
SELECT *
FROM $yandexuids
WHERE segment_name IN AsSet('medicine', 'tourism');
"""


def _yql_string_literal(value):
    """Quote *value* as a double-quoted YQL string literal.

    Backslashes and double quotes are escaped so that a value containing them
    cannot terminate the literal early and break (or alter) the generated query.
    """
    escaped = value.replace(u'\\', u'\\\\').replace(u'"', u'\\"')
    return u'"{}"'.format(escaped)


class DirectClientsByIndustry(RegularSegmentBuilder):
    """Segments of Direct users derived from their clients' industries.

    ``build_segment`` runs ``segment_query``. Users in the ``food`` industry
    segment are written to the task output. ``medicine`` and ``tourism`` users
    are written to a temporary sample table, which is handed to
    ``prepare_samples_for_lal`` as look-alike (LaL) input.
    """

    # Segment stored in the task output table. The tuple is consumed by
    # RegularSegmentBuilder; its meaning is not visible in this file.
    name_segment_dict = {
        'food': (549, 1984),
    }

    def requires(self):
        """Declare the external inputs of the task.

        Returns:
            dict with ``'DirectUsers'`` (``config.DIRECT_USERS`` wrapped in
            ``luigi_utils.ExternalInput``) and ``'DirectIndustries'``
            (``config.DIRECT_INDUSTRIES`` for ``self.date`` wrapped in
            ``luigi_utils.ExternalInputDate``).
        """
        return {
            'DirectUsers': luigi_utils.ExternalInput(config.DIRECT_USERS),
            'DirectIndustries': luigi_utils.ExternalInputDate(config.DIRECT_INDUSTRIES, date=self.date),
        }

    def build_segment(self, inputs, output_path):
        """Run the industry query and prepare LaL samples.

        Args:
            inputs: mapping with ``'DirectUsers'`` and ``'DirectIndustries'``
                entries exposing a ``.table`` path (presumably the targets
                declared in ``requires`` — confirm against the base class).
            output_path: table that receives the ``food`` segment rows.
        """
        # Render the Python mapping as the body of the query's AsDict(...).
        # Values are escaped so industry names with quotes/backslashes stay
        # valid YQL literals.
        industry_to_segment_name_string = u',\n'.join(
            u'AsTuple({}, {})'.format(_yql_string_literal(key), _yql_string_literal(value))
            for key, value in industry_to_segment_name.iteritems()
        )

        with self.yt.TempTable() as sample_table:
            self.yql.query(
                segment_query.format(
                    direct_users=inputs['DirectUsers'].table,
                    direct_industries=inputs['DirectIndustries'].table,
                    indevice_yandexuid_matching=config.INDEVICE_YANDEXUID,
                    industry_to_segment_name=industry_to_segment_name_string,
                    output_table=output_path,
                    sample_table=sample_table,
                ),
                transaction=self.transaction,
            )

            # Segments the query writes into sample_table: name -> segment id
            # handed to LaLParams.
            segment_name_to_id_dict = {
                'tourism': 1759,
                'medicine': 1758,
            }

            lals_params = [
                LaLParams(
                    name=segment_name,
                    id=segment_id,
                    type='lal_internal',
                    coverage=3000000,
                    include_input=False,
                )
                for segment_name, segment_id in segment_name_to_id_dict.iteritems()
            ]

            # Kept inside the ``with``: sample_table is a temporary table and
            # presumably disappears once TempTable exits — TODO confirm.
            self.prepare_samples_for_lal(
                input_table=sample_table,
                id_field='id',
                lals_params=lals_params,
            )
