import os
from textwrap import dedent
from datetime import datetime, timedelta
import calendar

from datacloud.config.yt import YT_PROXY
from datacloud.dev_utils.yql.yql_helpers import execute_yql
from datacloud.dev_utils.time.patterns import FMT_DATE

# Login of the current OS user, or None when $USER is unset; make_apply_cmd
# uses it to build the path to a per-user arcadia checkout.
USER = os.getenv('USER')
# Suffix returned by suffixes_to_iterate_by when no folds are configured.
DEFAULT_SUFFIX = ''
# Extra suffix that suffixes_to_iterate_by appends after the numbered folds.
CONFUSED_MARK = 'confused'


class FilterLogsReducer:
    """Reducer that keeps recent, de-duplicated records of each ``external_id`` group.

    For every key group a record is yielded only when all of these hold:

    * its ``timestamp`` lies in ``[retro_ts - days_to_take days, retro_ts)``, where the
      retro date is parsed from the last ``'_'``-separated part of ``external_id``;
    * at least one of the ``emb_of`` columns is truthy;
    * at least one ``emb_of`` column differs from the previously *yielded* record.
    """

    def __init__(self, emb_of, good_eids, days_to_take):
        """
        :param emb_of: names of the columns checked for non-emptiness and used to
            drop consecutive duplicates.
        :param good_eids: iterable of external ids to keep, or None to keep every key.
            It is materialized once into a frozenset. This keeps the per-key
            membership test O(1) and prevents a one-shot iterator from being
            exhausted by the first lookup.
        :param days_to_take: width of the time window, in days, that ends at the
            retro date.
        """
        self.emb_of = emb_of
        self.good_eids = frozenset(good_eids) if good_eids is not None else None
        self.days_to_take = days_to_take

    def __call__(self, key, recs):
        """Yield the filtered records of one key group.

        :param key: mapping with an ``external_id`` whose last ``'_'``-separated part
            is a date in ``FMT_DATE`` format.
        :param recs: iterable of records (mappings) with a ``timestamp`` column
            (seconds, compared against UTC epoch bounds) and every ``emb_of`` column.
        """
        external_id = key['external_id']
        if self.good_eids is not None and external_id not in self.good_eids:
            return

        # Only the trailing segment carries the retro date, so split once from the right.
        retro_dt = datetime.strptime(external_id.rsplit('_', 1)[-1], FMT_DATE)
        # utctimetuple() + calendar.timegm() treat the naive datetime as UTC,
        # independent of the machine's local timezone.
        retro_ts = calendar.timegm(retro_dt.utctimetuple())
        min_dt = retro_dt - timedelta(days=self.days_to_take)
        min_ts = calendar.timegm(min_dt.utctimetuple())

        # Start with all-None values so the first non-empty record always differs.
        prev_rec = {c: None for c in self.emb_of}
        for rec in recs:
            if min_ts <= rec['timestamp'] < retro_ts and any(rec[c] for c in self.emb_of) \
               and any(rec[c] != prev_rec[c] for c in self.emb_of):
                yield rec
                prev_rec = rec


# YQL template: applies the DSSM model to every row of %(input_table)s.
# Each row's external_id, yuid, timestamp, title and url, plus the model output
# as an `emb` column, are written to %(output_table)s, which is truncated first.
# NULL `title`/`url` values are replaced with "" before the model is applied.
# Placeholders (Python %-style; presumably substituted by execute_yql — confirm):
#   input_table, output_table -- table paths;
#   cols_mapping -- field list for AsStruct(...), presumably the output of make_cols_mapping;
#   layer_to_get -- name of the model output to extract.
# NOTE(review): the model is read from a file named "model.dssm", which the caller
# must attach to the query; the attachment is not visible here.
apply_dssm_q = dedent("""
    PRAGMA yt.DataSizePerJob="25M";
    PRAGMA yt.DefaultMemoryLimit="15G";
    $dssm_model = Dssm::LoadModel(FilePath("model.dssm"));

    $table = (
        SELECT
            `external_id`,
            `yuid`,
            `timestamp`,
            `title` ?? "" as `title`,
            `url` ?? "" as `url`
        FROM `%(input_table)s`
    );

    INSERT INTO `%(output_table)s` WITH TRUNCATE
    select
        `external_id`,
        `yuid`,
        `timestamp`,
        `title`,
        `url`,
        Dssm::Apply(
            $dssm_model,
            AsStruct(%(cols_mapping)s),
            "%(layer_to_get)s"
        ) as `emb`
    FROM $table;
""")


def make_cols_mapping(dssm_cols_map):
    """Render ``(source_column, model_input)`` pairs as a YQL struct field list.

    The output has the form used inside ``AsStruct(...)`` in ``apply_dssm_q``.
    A ``None`` source is rendered as the empty string literal ``""``; any other
    source is backtick-quoted. Example::

        [('title', 'query'), (None, 'doc')]  ->  '`title` as query, "" as doc'
    """
    def render(source, target):
        quoted = '""' if source is None else '`{}`'.format(source)
        return '{} as {}'.format(quoted, target)

    return ', '.join(render(source, target) for source, target in dssm_cols_map)


def make_cmd_cols_mapping(dssm_cols_map):
    """Turn ``(source_column, model_input)`` pairs into command-line arguments.

    The result contains all ``--column <name>`` pairs first, then all
    ``--header <name>`` pairs, in input order. A ``None`` source column is
    passed as ``empty_column``.
    """
    column_args = []
    header_args = []
    for source, target in dssm_cols_map:
        column_args.extend(('--column', 'empty_column' if source is None else source))
        header_args.extend(('--header', target))
    return column_args + header_args


def make_apply_cmd(input_table, applied_dssm, path_to_model, batch_size, layer_to_get, dssm_cols_map,
                   ds_per_job=25, binary_path=None, memory_limit=18096):
    """Build the argv list that runs the ``yt_dssm_apply`` tool over a table.

    Judging by the flags, the tool reads ``input_table``, applies the model and
    writes all input columns plus an ``emb`` vector column to ``applied_dssm``
    (presumably; confirm against the tool's docs).

    :param input_table: source table (``--src``).
    :param applied_dssm: destination table (``--dst``).
    :param path_to_model: model passed as ``--model``.
    :param batch_size: value for ``--batch_size``.
    :param layer_to_get: model output name (``--model_output``).
    :param dssm_cols_map: iterable of ``(source_column_or_None, model_input)`` pairs,
        rendered by ``make_cmd_cols_mapping``.
    :param ds_per_job: value for ``--data_size_per_job``; units are defined by the tool.
    :param binary_path: path to the ``yt_dssm_apply`` executable. Defaults to the
        build inside ``/home/$USER/arcadia``.
    :param memory_limit: value for ``--memory_limit``. The default keeps the
        previously hard-coded value.
    :returns: list of strings whose first element is the executable path.
    :raises ValueError: if ``binary_path`` is omitted and ``$USER`` is unset or empty.
        In that case the default path would point at ``/home/None/...``.
    """
    if binary_path is None:
        if not USER:
            raise ValueError(
                'make_apply_cmd: $USER is not set; pass binary_path explicitly')
        binary_path = ('/home/{}/arcadia/quality/nirvana_tools/conveyor_operations/'
                       'train_dssm/dssm/yt_apply/yt_dssm_apply').format(USER)

    cmd = [
        binary_path,
        '--proxy', YT_PROXY,
        '--src', input_table,
        '--dst', applied_dssm,
        '--model', path_to_model,
        '--result_column', 'emb',
        '--output_mode', 'OM_VECTOR',
        '--batch_size', str(batch_size),
        '--column2output', 'all_columns',
        '--model_output', layer_to_get,
        '--memory_limit', str(memory_limit),
        '--null_absent_as_empty',
        '--use_erasure_codec',
        '--model_type', 'dssm3',
        '--data_size_per_job', str(ds_per_job)
    ]
    cmd += make_cmd_cols_mapping(dssm_cols_map)

    return cmd


# YQL template that rewrites %(tensors)s in place. All columns except `target`
# are kept, and `target` is replaced with the %(target_name)s column of
# %(eid_target)s cast to Int64, joined on `external_id`.
# The INNER JOIN drops tensor rows whose external_id is absent from %(eid_target)s.
# The query reads from and overwrites (WITH TRUNCATE) the same table.
# NOTE(review): `SELECT * WITHOUT target` presumably requires %(tensors)s to
# already have a `target` column — confirm.
join_target_q = dedent("""
    INSERT INTO `%(tensors)s` WITH TRUNCATE
    SELECT
        tensors.*,
        CAST(eid_target.`%(target_name)s` as Int64) as `target`
    FROM (
        SELECT * WITHOUT `target`
        FROM `%(tensors)s`
    ) as tensors
    INNER JOIN `%(eid_target)s` as eid_target
    USING(`external_id`)
""")


def join_target(tensors, yt_client, yql_client, path_config, plconfig):
    """Overwrite the ``target`` column of the ``tensors`` table using ``join_target_q``.

    :param tensors: table that is read and rewritten in place.
    :param yt_client: not used by this function; kept for signature compatibility.
    :param yql_client: client passed through to ``execute_yql``.
    :param path_config: provides ``eid_target``, the table holding the target values.
    :param plconfig: provides ``target``, the name of the target column in that table.
    """
    query_params = {
        'tensors': tensors,
        'eid_target': path_config.eid_target,
        'target_name': plconfig.target,
    }
    execute_yql(
        query=join_target_q,
        yql_client=yql_client,
        params=query_params,
        set_owners=False,
        syntax_version=1,
    )


# YQL template: copies the rows of %(applied)s whose external_id has mark
# %(mark)s in %(eid2mark)s into %(output_table)s (truncated first), ordered by
# (external_id, timestamp).
# NOTE(review): %(mark)s is spliced into a single-quoted YQL string without
# escaping, so only pass trusted values that contain no quotes.
split_applied_q = dedent("""
    INSERT INTO `%(output_table)s` WITH TRUNCATE
    SELECT applied.*
    FROM `%(applied)s` as applied
    INNER JOIN `%(eid2mark)s` as eid2mark
    USING(`external_id`)
    WHERE eid2mark.`mark` == '%(mark)s'
    ORDER BY `external_id`, `timestamp`
""")


def suffixes_to_iterate_by(plconfig):
    """Return the split suffixes to iterate over for ``plconfig``.

    If ``plconfig.n_folds`` is falsy (0 or None), this is the single
    ``DEFAULT_SUFFIX``. Otherwise it is the fold numbers ``1..n_folds``
    followed by ``CONFUSED_MARK``.
    """
    n_folds = plconfig.n_folds
    if n_folds:
        suffixes = list(range(1, n_folds + 1))
        suffixes.append(CONFUSED_MARK)
        return suffixes
    return [DEFAULT_SUFFIX]
