#!/usr/bin/python
# -*- coding: utf-8 -*-
import argparse
import calendar
import logging
import os
import re
import sys
import time

import ads.libs.yql
from yt.wrapper import YtClient, YtHttpResponseError
from yt.transfer_manager.client import TransferManager
from datetime import datetime, timedelta
from bm_mr.resources import home
from contextlib import nested


# YT pool passed to every YQL query issued by this script.
POOL_BROADMATCHING = 'broadmatching'

# Default value for --skip-pretty-check-for when the option is not given;
# presumably a SimDistance that never occurs in real data, so nothing is skipped.
NON_EXISTING_SD = 999999

# Table names used inside every publish input directory and inside the next/preprod/prod directories.
BROAD_MATCH_TABLE_NAME = 'BroadMatchTable'
BROAD_MATCH_WITH_PHRASE_TABLE_NAME = 'BroadMatchWithPhraseTable'
BROAD_MINUSWORD_TABLE_NAME = 'BroadMinuswordTable'
BROAD_PHRASE_TABLE_NAME = 'BroadPhraseTable'
PROCESSED_BANNERS_TABLE_NAME = 'ProcessedBannersTable'
# Known publish inputs, keyed by input name. Per-input settings:
#   'path'       - YT directory holding the input tables;
#   'sd_written' - True when SimDistance is read from the ProcessedBannersTable column,
#                  False when it is read from the "SimDistances" attribute of the directory (see run_yql);
#   'enabled'    - whether the input is used when --override-publish-inputs is not given (see main).
PUBLISH_INPUTS_SETTINGS = {
    'insight_dynamic_speedup': {
        'path': home.broadmatching.bmyt.publish.input.insight_dynamic_speedup.read(),
        'sd_written': True,
        'enabled': True,
    },
    'insight_synonym_speedup': {
        'path': home.broadmatching.bmyt.publish.input.insight_synonym_speedup.read(),
        'sd_written': True,
        'enabled': True,
    },
    'insight_synonym_308_speedup': {
        'path': home.broadmatching.bmyt.publish.input.insight_synonym_308_speedup.read(),
        'sd_written': True,
        'enabled': True,
    },
    'insight_synonym_exp_speedup': {
        'path': home.broadmatching.bmyt.publish.input.insight_synonym_exp_speedup.read(),
        'sd_written': True,
        'enabled': True,
    },
    'insight_synonym_exp_1408_speedup': {
        'path': home.broadmatching.bmyt.publish.input.insight_synonym_exp_1408_speedup.read(),
        'sd_written': True,
        'enabled': True,
    },
    'insight_synonym_exp_1508_speedup': {
        'path': home.broadmatching.bmyt.publish.input.insight_synonym_exp_1508_speedup.read(),
        'sd_written': True,
        'enabled': True,
    },
    'insight_synonym_exp_1608_speedup': {
        'path': home.broadmatching.bmyt.publish.input.insight_synonym_exp_1608_speedup.read(),
        'sd_written': True,
        'enabled': True,
    },
    'insight_synonym_exp_1708_speedup': {
        'path': home.broadmatching.bmyt.publish.input.insight_synonym_exp_1708_speedup.read(),
        'sd_written': True,
        'enabled': True,
    },
    'insight_synonym_exp_1808_speedup': {
        'path': home.broadmatching.bmyt.publish.input.insight_synonym_exp_1808_speedup.read(),
        'sd_written': True,
        'enabled': True,
    },
    'insight_synonym_exp_2008_speedup': {
        'path': home.broadmatching.bmyt.publish.input.insight_synonym_exp_2008_speedup.read(),
        'sd_written': True,
        'enabled': True,
    },
}
# Parameters passed as `params` to every Transfer Manager task created by this script.
TM_PARAMS = {
    'copy_spec': {'pool': 'broadmatching'},
    'postprocess_spec': {'pool': 'broadmatching'},
    'queue_name': 'broadmatching',
}


# Module-level logger.
logger = logging.getLogger(__name__)


def yt_time2ts(yt_time_str):
    """Convert a YT time string (``YYYY-MM-DDTHH:MM:SS.ffffffZ``) to integer Unix seconds.

    The string is treated as UTC; the fractional part is parsed but does not
    contribute to the result.
    """
    parsed = datetime.strptime(yt_time_str, "%Y-%m-%dT%H:%M:%S.%fZ")
    utc_fields = parsed.timetuple()
    return calendar.timegm(utc_fields)


def run_yql(
        src_clus, src_client, publish_inputs, bm_temp_table, bmwp_temp_table, bmw_temp_table, bp_temp_table, bm_prev_table, transaction_id
):
    """Build one YQL query that merges all publish inputs into the four output tables and run it.

    Args:
        src_clus: cluster name, passed as ``db`` to ``ads.libs.yql.run_yql_query``.
        src_client: YT client; used to check existence of ``bm_prev_table`` and to read the
            "SimDistances" attribute of inputs whose ``sd_written`` flag is false.
        publish_inputs: sequence of dicts with at least the keys "path" and "sd_written".
            The position of an input in this sequence is used as its table ID inside the query.
        bm_temp_table: output path for the merged BroadMatch rows.
        bmwp_temp_table: output path for the BroadMatch rows joined with phrase data.
        bmw_temp_table: output path for the merged BroadMinusword rows.
        bp_temp_table: output path for the merged BroadPhrase rows.
        bm_prev_table: path of the previous BroadMatch table; selects which INSERT template is used.
        transaction_id: forwarded to ``run_yql_query``.

    Raises:
        Exception: if the query result is not successful.
    """
    # concat(`...`, `...`) expressions over the minusword / phrase tables of all inputs.
    bmw_tables_str = 'concat({})'.format(','.join(['`' + x["path"] + '/' + BROAD_MINUSWORD_TABLE_NAME + '`' for x in publish_inputs]))
    bp_tables_str = 'concat({})'.format(','.join(['`' + x["path"] + '/' + BROAD_PHRASE_TABLE_NAME + '`' for x in publish_inputs]))

    # Python UDF: for every SimDistance keeps the table ID (map[2]) of the entry with the largest timestamp (map[1]).
    # The UDF source between the @@ markers must stay at column 0 of the string.
    header = '''
    PRAGMA yt.DataSizePerJob = "128000000";


    -- Selects SD to table ID mapping

    $select_tid = Python2::select_tid(
        ParseType("(List<Tuple<List<Int32>?, Int64?, Int32>>) -> Dict<Uint64, Int32>"),
        @@
def select_tid(maps):
    sd_max_ts = dict()
    sd_table = dict()
    for map in maps:
        for sd in map[0]:
            if sd not in sd_max_ts or sd_max_ts[sd] < map[1]:
                sd_table[sd] = map[2]
                sd_max_ts[sd] = map[1]
    return sd_table
        @@
    )
    ;


    '''

    # $tid_mask: per BannerID, a SimDistance -> table ID dict built from the ProcessedBannersTable of every input.
    tid_mask = '''
    $tid_mask = (
    SELECT BannerID, $select_tid(AGGREGATE_LIST(SDMap)) AS SDMap
    FROM (
    '''

    for pi in range(len(publish_inputs)):
        if pi > 0:
            tid_mask += "\nUNION ALL\n"

        if publish_inputs[pi]["sd_written"]:
            # SimDistance is a column of the processed-banners table.
            tid_mask += '''
            SELECT BannerID, AsTuple(AsList(SimDistance), `TimeStamp`, {input_id}) AS SDMap
            FROM `{input_path}/{processed_name}`
            WHERE BannerID != 0
            '''.format(
                input_path=publish_inputs[pi]["path"],
                input_id=pi,
                processed_name=PROCESSED_BANNERS_TABLE_NAME
            )
        else:
            # SimDistances come from the "SimDistances" attribute of the input directory and are inlined as a literal list.
            sd_list = src_client.get_attribute(publish_inputs[pi]["path"], "SimDistances")
            sd_list_str = 'AsList(' + ', '.join(str(x) for x in sd_list) + ')'
            tid_mask += '''
            SELECT BannerID, AsTuple(
                {sd_list},
                `TimeStamp`,
                {input_id}
            ) AS SDMap
            FROM `{input_path}/{processed_name}`
            WHERE BannerID != 0
            '''.format(
                input_path=publish_inputs[pi]["path"],
                input_id=pi,
                sd_list=sd_list_str,
                processed_name=PROCESSED_BANNERS_TABLE_NAME,
            )

    tid_mask += '''
    )
    GROUP BY BannerID
    );
    '''

    # $bm_data: union of the BroadMatch tables of all inputs, keeping from each input only the rows
    # whose (BannerID, SimDistance) is mapped to that input by $tid_mask.
    bm_data = '''
    $bm_data = (
    '''
    for pi in range(len(publish_inputs)):
        if pi > 0:
            bm_data += "\nUNION ALL\n"
        bm_data += '''
        SELECT
            CAST(APCRatio AS int64) as APCRatio,
            T.BannerID AS BannerID,
            Just(BroadPhraseID) AS BroadPhraseID,
            CAST(CTR AS int64) AS CTR,
            Just(CAST(ContextType AS int64)) AS ContextType,
            CAST(Options AS uint64) AS Options,
            CAST(PCTR AS int64) AS PCTR,
            PhraseID,
            Score,
            CAST(SimDistance AS uint64) AS SimDistance,
            `TimeStamp`,
            OrderID
        FROM `{input_path}/{bm_name}` AS T
        JOIN $tid_mask AS TM
            ON TM.BannerID = T.BannerID
        WHERE TM.SDMap[T.SimDistance] = {input_id}

        '''.format(
            input_path=publish_inputs[pi]["path"],
            input_id=pi,
            bm_name=BROAD_MATCH_TABLE_NAME,
        )
    bm_data += '''
    );
    '''

    # Turn on nontrivial PrevScore only in separated SimDistances
    # NOTE(review): in the branch below $prev_data is LEFT JOINed, but none of the selected columns
    # come from EXBM and PrevScore is still set to Score, so the join currently does not
    # influence the written values - confirm whether this is intentional.
    if src_client.exists(bm_prev_table):
        build_bm_table = '''
        $prev_data = (
            SELECT BannerID, BroadPhraseID, PhraseID,
                SimDistance, ContextType,
                MAX_BY(AsTuple(PrevScore, UpdateTime, Score), UpdateTime) AS Prev
            FROM `{bm_prev_table}`
            GROUP BY BannerID, BroadPhraseID, PhraseID, SimDistance, ContextType
        );

        insert into `{bm_temp_table}` with truncate
        select
            CTR,
            Score,
            PCTR,
            BM.PhraseID AS PhraseID,
            APCRatio,
            BM.SimDistance AS SimDistance,
            BM.ContextType AS ContextType,
            BM.BroadPhraseID AS BroadPhraseID,
            Options,
            OrderID,
            BM.BannerID AS BannerID,
            CAST(`TimeStamp` AS UINT64) AS UpdateTime,
            Score AS PrevScore
        from $bm_data AS BM

        LEFT JOIN $prev_data AS EXBM
        ON EXBM.BannerID == BM.BannerID
            AND EXBM.BroadPhraseID == BM.BroadPhraseID
            AND EXBM.PhraseID == BM.PhraseID
            AND EXBM.SimDistance == BM.SimDistance
            AND EXBM.ContextType == BM.ContextType
        ;
        '''.format(
            bm_prev_table=bm_prev_table,
            bm_temp_table=bm_temp_table
        )
    else:
        # No previous table: write $bm_data as is, with PrevScore equal to Score.
        build_bm_table = '''
        insert into `{bm_temp_table}` with truncate
        select
            CTR,
            Score,
            PCTR,
            PhraseID,
            APCRatio,
            SimDistance,
            ContextType,
            BroadPhraseID,
            Options,
            BannerID,
            CAST(`TimeStamp` AS UINT64) as UpdateTime,
            Score AS PrevScore,
            OrderID
        from $bm_data
        ;
        '''.format(
            bm_temp_table=bm_temp_table
        )

    # Remaining outputs:
    #   minuswords - per BannerID the Data/OrderID of the row with the largest TimeStamp;
    #   phrases    - distinct (BroadPhraseID, NormType, Data) rows ordered by BroadPhraseID, NormType;
    #   bm+phrase  - $bm_data inner-joined with min(NormType) and min(Data) per BroadPhraseID.
    query = header + tid_mask + bm_data + build_bm_table + '''

    insert into `{bmw_temp_table}` with truncate
    select
        Just(BannerID) as BannerID,
        Just(MAX_BY(Data, `TimeStamp`)) AS Data,
        MAX_BY(OrderID, `TimeStamp`) AS OrderID
    from {bmw_tables_str}
    GROUP BY BannerID
    ;

    insert into `{bp_temp_table}` with truncate
    select distinct
        Just(BroadPhraseID) as BroadPhraseID,
        NormType,
        Data
    from {bp_tables_str}
    order by BroadPhraseID, NormType
    ;

    insert into `{bmwp_temp_table}` with truncate
    select
        bm.*,
        NormType,
        Data
    from (
        select
            CTR,
            Score,
            PCTR,
            PhraseID,
            APCRatio,
            SimDistance,
            ContextType,
            BroadPhraseID,
            Options,
            BannerID,
            CAST(`TimeStamp` AS UINT64) as UpdateTime,
            Score AS PrevScore,
            OrderID
        from $bm_data
    ) as bm
    inner join (
        select
            BroadPhraseID,
            min(NormType) as NormType,
            min(Data) as Data
        from {bp_tables_str}
        group by BroadPhraseID
    ) as bp
    using (BroadPhraseID);
    ;

    '''.format(
        bmw_temp_table=bmw_temp_table, bp_temp_table=bp_temp_table,
        bmw_tables_str=bmw_tables_str, bp_tables_str=bp_tables_str,
        bmwp_temp_table=bmwp_temp_table,
    )

    # Run the whole query under the caller's transaction; any non-success result is fatal.
    if not ads.libs.yql.run_yql_query(query=query, db=src_clus, pool=POOL_BROADMATCHING, transaction_id=transaction_id).get_results().is_success:
        raise Exception("YQL query failed")


def build_broad_phrase_delta(yt_client, yt_cluster, bp_prev_table, bp_new_table, target_dir):
    """Write the BroadPhrase rows present in the new table but not in the previous one.

    Rows are matched on (BroadPhraseID, NormType). The delta table is created under
    ``target_dir`` and named ``<prev modification ts>_<new modification ts>``; the query
    sets a 14 day expiration interval for it.

    Args:
        yt_client: YT client used for existence and attribute lookups.
        yt_cluster: cluster name passed as ``db`` to the YQL runner.
        bp_prev_table: path of the previous BroadPhrase table (may be missing).
        bp_new_table: path of the new BroadPhrase table.
        target_dir: directory in which the delta table is created.

    Returns:
        Path of the delta table.

    Raises:
        Exception: if the query result is not successful.
    """
    # Without a previous table the new table is diffed against itself, which yields an empty delta.
    if not yt_client.exists(bp_prev_table):
        bp_prev_table = bp_new_table

    modification_times = []
    for table in (bp_prev_table, bp_new_table):
        modification_times.append(yt_client.get_attribute(table, 'modification_time'))
    table_name = "_".join([str(yt_time2ts(mtime)) for mtime in modification_times])
    target_table = "/".join((target_dir, table_name))

    query_template = '''
    PRAGMA yt.ExpirationInterval = "14d";
    INSERT INTO `{target_table}`
    WITH TRUNCATE
    SELECT
        NEW.BroadPhraseID AS BroadPhraseID,
        NEW.NormType AS NormType,
        NEW.Data AS Data
    FROM (
        SELECT BroadPhraseID, NormType
        FROM `{prev_table}`
    ) AS PREV
    RIGHT ONLY JOIN (
        SELECT BroadPhraseID, NormType, Data
        FROM `{new_table}`
    ) AS NEW
    ON PREV.BroadPhraseID == NEW.BroadPhraseID
        AND PREV.NormType == NEW.NormType
    ;
    '''
    query = query_template.format(
        target_table=target_table,
        prev_table=bp_prev_table,
        new_table=bp_new_table,
    )

    results = ads.libs.yql.run_yql_query(query=query, db=yt_cluster, pool=POOL_BROADMATCHING).get_results()
    if not results.is_success:
        raise Exception("YQL query failed")

    return target_table


def are_new_tables_valid(
        src_clus, src_client,
        bm_new_table, bmw_new_table, bp_new_table,
        bm_old_table, bmw_old_table, bp_old_table,
        skip_check_for, skip_schema_check, skip_pretty_check_for
):
    """Validate freshly built tables against sanity thresholds and the previous tables.

    Checks performed:
      * unless ``skip_schema_check``: for every old table that exists, the sorted schema of the
        new table must equal the sorted schema of the old one;
      * total BroadMatch row count between 1e9 and 7.5e9;
      * per SimDistance (< 100000, not listed in ``skip_pretty_check_for``): share of scores
        divisible by 100000 below 2%;
      * per SimDistance: share of rows sitting exactly on the min / max score below 1.5%
        (skipped when min == max);
      * no NULLs in the listed BroadMatch columns, ContextType and SimDistance; numeric columns
        not above 1000000; UpdateTime not above 2147483647;
      * no NULLs in the minusword and phrase tables;
      * if ``bm_old_table`` exists: new/old row count ratio per (ContextType, SimDistance)
        between 0.85 and 1.3, except for pairs listed in ``skip_check_for``.

    Args:
        src_clus: cluster name passed as ``db`` to the YQL runner.
        src_client: YT client used for existence and schema lookups.
        bm_new_table, bmw_new_table, bp_new_table: tables to validate.
        bm_old_table, bmw_old_table, bp_old_table: previous tables; the minusword and phrase
            ones are used for the schema check only.
        skip_check_for: comma separated ``(ContextType,SimDistance)`` pairs, inserted verbatim
            into the query - the caller must validate the format.
        skip_schema_check: when true the schema comparison is skipped.
        skip_pretty_check_for: comma separated SimDistances, inserted verbatim into the query.

    Returns:
        True when the schemas match and the validation query succeeds, False otherwise.
    """
    if not skip_schema_check:
        table_pairs = [(bm_new_table, bm_old_table), (bmw_new_table, bmw_old_table), (bp_new_table, bp_old_table)]
        for src_tbl, dst_tbl in table_pairs:
            if not src_client.exists(dst_tbl):
                continue
            src_schema = src_client.get_attribute(src_tbl, 'schema')
            dst_schema = src_client.get_attribute(dst_tbl, 'schema')
            if sorted(src_schema) != sorted(dst_schema):
                message = "Schema is incorrect:\n{}\nShould be:\n{}\n".format(str(src_schema), str(dst_schema))
                # The extra newline mirrors what the print statement used to append.
                sys.stderr.write(message + "\n")
                return False

    #  TODO: Get all fields from schema?
    almost_all_fields = (
        'APCRatio',
        'BannerID',
        'BroadPhraseID',
        #  'ContextType',
        'CTR',
        'Options',
        'PCTR',
        'PhraseID',
        'PrevScore',
        'Score',
        #  'SimDistance',
        'UpdateTime',
    )
    num_fields = (
        'APCRatio',
        'CTR',
        'PCTR',
        'PrevScore',
        'Score',
    )

    # Generated column lists; ContextType and SimDistance are handled separately in the query
    # because they are the grouping keys.
    sep = ",\n    "
    null_counts = sep.join("COUNT_IF(" + f + " IS NULL) AS " + f + "Null" for f in almost_all_fields)
    too_big_counts = sep.join("COUNT_IF(" + f + " > 1000000) AS " + f + "TooBig" for f in num_fields)
    null_sums = sep.join("SUM(" + f + "Null) ?? 0 AS " + f + "Null" for f in almost_all_fields)
    too_big_sums = sep.join("SUM(" + f + "TooBig) ?? 0 AS " + f + "TooBig" for f in num_fields)
    null_ensures = sep.join("Ensure(" + f + "Null, 0 == " + f + "Null, '" + f + " null')" for f in almost_all_fields)
    too_big_ensures = sep.join("Ensure(" + f + "TooBig, 0 == " + f + "TooBig, '" + f + " too big')" for f in num_fields)

    # In $by_sd the upper bound is MAX(MaxScore). It used to be MAX(MinScore), which made
    # MinScore == MaxScore for every SimDistance with a single ContextType and thereby
    # silently disabled the "Bm score on bounds" check for it.
    query = '''
$by_sd_ct_old = (
    SELECT
        SimDistance, ContextType,
        COUNT(*) AS Total
    FROM `{bm_old_table}`
    GROUP BY SimDistance, ContextType
);

$by_sd_ct = (
    SELECT
        SimDistance, ContextType,
        COUNT(*) AS Total,
        COUNT_IF(0 == (Score % 100000)) AS RoundScoreCount,
        COUNT_IF(0 != (Score % 100000)) AS NonRoundScoreCount,
        MIN(Score) AS MinScore,
        MAX(Score) AS MaxScore,
    ''' + null_counts + ''',
    ''' + too_big_counts + ''',
        COUNT_IF(UpdateTime > 2147483647) AS UpdateTimeTooBig
    FROM `{bm_new_table}`
    GROUP BY SimDistance, ContextType
);

$by_sd = (
    SELECT
        SimDistance,
        SUM(Total) AS Total,
        SUM(RoundScoreCount) AS RoundScoreCount,
        SUM(NonRoundScoreCount) AS NonRoundScoreCount,
        MIN(MinScore) AS MinScore,
        MAX(MaxScore) AS MaxScore
    FROM $by_sd_ct
    GROUP BY SimDistance
);

$totals = (
    SELECT
        SUM(Total) ?? 0 AS Total,
    ''' + null_sums + ''',
        SUM_IF(Total, ContextType IS NULL) ?? 0 AS ContextTypeNull,
        SUM_IF(Total, SimDistance IS NULL) ?? 0 AS SimDistanceNull,
    ''' + too_big_sums + ''',
        SUM(UpdateTimeTooBig) ?? 0 AS UpdateTimeTooBig
    FROM $by_sd_ct
);

    $Grows = 7.5;
    select
        Ensure(2147483647, 2147483647 between 1.0*1000*1000*1000 and $Grows*1000*1000*1000, 'Insanity check'),
        Ensure(2147483648, 2147483648 between 1.0*1000*1000*1000 and $Grows*1000*1000*1000, 'Insanity check'),
        Ensure(Total, Total between 1.0*1000*1000*1000 and $Grows*1000*1000*1000, 'Rows count: ' || CAST(Total AS String))
    from $totals
    ;

    select SimDistance, Ensure(RoundScoreCount * 100.0 / Total, (RoundScoreCount * 100.0 / Total) < 2.0, 'bm score too pretty in SD=' || cast(SimDistance ?? 0 as string))
        -- ToDo: restore 0.005 value. Ratio of Round scores should be less then 0.005%
    from $by_sd
    where SimDistance < 100000 and SimDistance not in ({skip_pretty_check_for})
    ;

    SELECT SimDistance,
        Mins * 100. / Total,
        Ensure(
            Maxs * 100. / Total,
            MinScore == MaxScore OR (
                Mins * 100. / Total < 1.5 AND Maxs * 100. / Total < 1.5 -- ToDo: restore 0.0005 value
            ),
            'Bm score on bounds SD=' || CAST(SimDistance ?? 0 AS STRING)
        )
    FROM (
    SELECT BM.SimDistance AS SimDistance,
       COUNT_IF(Score == MinScore) AS Mins,
       COUNT_IF(Score == MaxScore) AS Maxs,
       SOME(Bounds.MinScore) ?? 0 AS MinScore,
       SOME(Bounds.MaxScore) ?? 0 AS MaxScore,
       COUNT(*) AS Total
    FROM `{bm_new_table}` AS BM
    JOIN $by_sd AS Bounds ON BM.SimDistance = Bounds.SimDistance
    GROUP BY BM.SimDistance
    );

    SELECT
    ''' + null_ensures + ''',
        Ensure(ContextTypeNull, 0 == ContextTypeNull, 'ContextType null'),
        Ensure(SimDistanceNull, 0 == SimDistanceNull, 'SimDistance null'),
    ''' + too_big_ensures + ''',
        Ensure(UpdateTimeTooBig, 0 == UpdateTimeTooBig, 'UpdateTime too big')
    FROM $totals
    ;

    select Ensure(count(*), count(*) == 0, 'bmw null')
    from `{bmw_new_table}`
    where BannerID is null or Data is null;

    select Ensure(count(*), count(*) == 0, 'bp null')
    from `{bp_new_table}`
    where Data is null or NormType is null or BroadPhraseID is null;

    '''

    # The new-vs-previous row count comparison only makes sense when a previous table exists.
    if src_client.exists(bm_old_table):
        query += '''
        select ContextType, SimDistance, OldTotal, NewTotal, Ensure(
            NewTotal/OldTotal,
            ((ContextType, SimDistance) in ({skip_check_for})) or --YQL injection here
            (NewTotal/OldTotal between 0.85 and 1.3),
            'New to prev ratio CT=' || cast(ContextType ?? 0 as string) || ' SD=' || cast(SimDistance ?? 0 as string) ||
            ' is ' || cast((NewTotal ?? 0)/(OldTotal ?? 0) as string)
        )
        from (
            select
                OLD.SimDistance ?? NEW.SimDistance AS SimDistance,
                OLD.ContextType ?? NEW.ContextType AS ContextType,
                cast(OLD.Total ?? 1 AS double) AS OldTotal,
                cast(NEW.Total ?? 0 AS double) AS NewTotal
            from $by_sd_ct AS NEW
            full join $by_sd_ct_old AS OLD
                on NEW.SimDistance == OLD.SimDistance and NEW.ContextType == OLD.ContextType
        )
        '''

    query = query.format(
        skip_check_for=skip_check_for, skip_pretty_check_for=skip_pretty_check_for,
        bm_new_table=bm_new_table, bmw_new_table=bmw_new_table, bp_new_table=bp_new_table,
        bm_old_table=bm_old_table, bmw_old_table=bmw_old_table, bp_old_table=bp_old_table
    )

    results = ads.libs.yql.run_yql_query(query=query, db=src_clus, pool=POOL_BROADMATCHING).get_results()
    return bool(results.is_success)


def are_inputs_updated(src_client, bm_tables, bmw_tables, bp_tables, bm_dst_table, bmw_dst_table, bp_dst_table):
    """Tell whether any destination table is missing or was built from older inputs.

    For each (source tables, destination table) pair the current ``modification_time`` of
    every source is compared with the ``timestamps`` attribute stored on the destination.

    Args:
        src_client: YT client used for all lookups.
        bm_tables, bmw_tables, bp_tables: lists of source table paths.
        bm_dst_table, bmw_dst_table, bp_dst_table: corresponding destination table paths.

    Returns:
        True if some destination does not exist or its stored timestamps differ from the
        current ones (a destination without a readable ``timestamps`` attribute is treated
        as having an empty mapping); False if everything is up to date.

    Raises:
        YtHttpResponseError: for lookup failures other than an unresolved path or a
            missing attribute.
    """
    pairs = [(bm_tables, bm_dst_table), (bmw_tables, bmw_dst_table), (bp_tables, bp_dst_table)]
    for srcs, dst in pairs:
        cur_timestamps = {}
        try:
            if not src_client.exists(dst):
                return True
            cur_timestamps = src_client.get_attribute(dst, 'timestamps')
        except YtHttpResponseError as exc:
            # A missing path / attribute just means "no timestamps recorded yet".
            if 'Error resolving path' not in str(exc) and not re.match("Attribute .* is not found", str(exc)):
                raise
        timestamps = {x: src_client.get_attribute(x, 'modification_time') for x in srcs}
        # Plain inequality instead of cmp(): same result, and it also exists in Python 3.
        if timestamps != cur_timestamps:
            return True
    return False


def calc_next_from_publish_input(src_clus, publish_inputs, next_dir, preprod_dir, force):
    """Rebuild the tables in ``next_dir`` from the publish inputs when the inputs have changed.

    Unless ``force`` is set, nothing is done when ``are_inputs_updated`` reports no change.
    After the YQL run, every output table gets the attributes ``timestamps`` (modification
    times of its source tables), ``recalc_stamp`` and ``suppress_nightly_merge``; the query
    and the attribute updates are issued inside one transaction of the source client.

    Args:
        src_clus: source cluster name.
        publish_inputs: sequence of input settings dicts (each with a "path" key).
        next_dir: directory receiving the freshly built tables.
        preprod_dir: directory whose BroadMatch table is passed to the query as the previous one.
        force: rebuild even when the inputs look unchanged.
    """
    # Identifies this particular recalculation: uname fields plus the current time.
    recalc_stamp = ":".join(os.uname() + (str(time.time()),))
    src_client = YtClient(proxy=src_clus, token=os.environ.get('YT_TOKEN'))

    def input_tables(table_name):
        return ['{}/{}'.format(x["path"], table_name) for x in publish_inputs]

    def next_table(table_name):
        return '{}/{}'.format(next_dir, table_name)

    bm_tables = input_tables(BROAD_MATCH_TABLE_NAME)
    bmw_tables = input_tables(BROAD_MINUSWORD_TABLE_NAME)
    bp_tables = input_tables(BROAD_PHRASE_TABLE_NAME)

    next_bm_table = next_table(BROAD_MATCH_TABLE_NAME)
    next_bmwp_table = next_table(BROAD_MATCH_WITH_PHRASE_TABLE_NAME)
    next_bmw_table = next_table(BROAD_MINUSWORD_TABLE_NAME)
    next_bp_table = next_table(BROAD_PHRASE_TABLE_NAME)

    preprod_bm_table = '{}/{}'.format(preprod_dir, BROAD_MATCH_TABLE_NAME)

    # Skip the rebuild when there is no fresh data in the inputs.
    if not (force or are_inputs_updated(
            src_client=src_client,
            bm_tables=bm_tables, bmw_tables=bmw_tables, bp_tables=bp_tables,
            bm_dst_table=next_bm_table, bmw_dst_table=next_bmw_table, bp_dst_table=next_bp_table
    )):
        print >> sys.stderr, 'Nothing to do'
        return

    # Which source tables each output table was built from.
    outputs_with_sources = [
        (next_bm_table, bm_tables),
        (next_bmwp_table, bm_tables),
        (next_bmw_table, bmw_tables),
        (next_bp_table, bp_tables),
    ]

    with src_client.Transaction() as tx:
        # Build the final data from the inputs.
        run_yql(
            src_clus=src_clus, src_client=src_client,
            publish_inputs=publish_inputs,
            bm_temp_table=next_bm_table,
            bmwp_temp_table=next_bmwp_table,
            bmw_temp_table=next_bmw_table,
            bp_temp_table=next_bp_table,
            bm_prev_table=preprod_bm_table,
            transaction_id=tx.transaction_id
        )

        # Record the source timestamps so that data freshness can be checked later;
        # the attributes travel with the tables when they are copied further.
        for dst, srcs in outputs_with_sources:
            timestamps = dict((x, src_client.get_attribute(x, 'modification_time')) for x in srcs)
            src_client.set_attribute(dst, 'timestamps', timestamps)
            src_client.set_attribute(dst, 'recalc_stamp', recalc_stamp)
            src_client.set_attribute(dst, 'suppress_nightly_merge', True)


def srcs_exist_and_attr_ne(client, srcs, dsts, attr):
    """Tell whether all ``srcs`` exist and at least one differs from its ``dsts`` counterpart.

    A (src, dst) pair differs when ``dst`` does not exist or the value of ``attr`` on
    ``src`` is not equal to the value on ``dst``.

    Args:
        client: YT client used for all lookups.
        srcs: source table paths.
        dsts: destination table paths, matched with ``srcs`` by position.
        attr: name of the attribute to compare.

    Returns:
        False when some source is missing; otherwise True if any pair differs. If the
        attribute lookup fails with "Attribute ... is not found", the result is whether
        all sources exist.

    Raises:
        YtHttpResponseError: for any other lookup failure.
    """
    # Generators let all()/any() stop issuing requests as soon as the answer is known.
    exists = all(client.exists(src) for src in srcs)
    try:
        return exists and any(
            # Plain inequality instead of cmp(): same result, and it also exists in Python 3.
            not client.exists(dst) or client.get_attribute(src, attr) != client.get_attribute(dst, attr)
            for (src, dst) in zip(srcs, dsts)
        )
    except YtHttpResponseError as exc:
        if not re.match("Attribute .* is not found", str(exc)):
            raise
    return exists


def wait_transfer_manager_tasks(tm_client, task_ids):
    """Block until every Transfer Manager task in ``task_ids`` reports a ``finish_time``.

    Tasks are polled in rounds with a 10 second pause between rounds; the function
    returns right after the round in which the last task finished. The final state of a
    task is only logged, it is not inspected.

    Args:
        tm_client: client providing ``get_task_info(task_id)``.
        task_ids: iterable of task identifiers; it is not modified.
    """
    pending = list(task_ids)
    while pending:
        logging.info('======')
        still_running = []
        for task_id in pending:
            info = tm_client.get_task_info(task_id)
            if 'finish_time' in info:
                logging.info('task %s done: %s', task_id, info['state'])
            else:
                logging.info('task %s runs: %s', task_id, info['state'])
                still_running.append(task_id)
        pending = still_running
        # Sleep only when there is still something to wait for.
        if pending:
            time.sleep(10)


def _require_unchanged(expected, actual, descr):
    """Raise AssertionError(descr) unless ``expected == actual``.

    Unlike a bare ``assert`` statement this guard is not stripped under ``python -O``.
    """
    if not (expected == actual):
        raise AssertionError(descr)


# ToDo: fix case of first execution without --noprod option
# Run checks on src cluster and copy src to dst with preserving dst in dst_old
# Moves next -> preprod -> prod
def move_next_preprod_prod(
        src_clus, dst_clusters, next_dir, preprod_dir, prod_dir, bp_deltas_dir, skip_check_for, skip_schema_check,
        skip_time_check, force, no_next_preprod, no_preprod_prod, skip_pretty_check_for):
    """Promote tables one stage: preprod -> prod and next -> preprod, on all clusters.

    Steps:
      1. remember the ``revision`` of every existing next/preprod/prod table on every cluster
         and the ``recalc_stamp`` of the next BroadMatch table on the source cluster;
      2. unless ``skip_time_check``: return if any preprod table was created less than an hour ago;
      3. decide which moves are needed (``timestamps`` attributes differ, or ``force``);
      4. for next -> preprod: start asynchronous transfers of the next tables to the
         destination clusters and validate the new tables meanwhile;
      5. for preprod -> prod: build the BroadPhrase delta table and transfer it;
      6. wait for the transfers, then, inside transactions on all clusters, verify nothing
         changed since step 1 and copy the directories (keeping the previous content in
         ``<dir>_old``).

    Args:
        src_clus: source cluster name.
        dst_clusters: list of destination cluster names.
        next_dir, preprod_dir, prod_dir: stage directories.
        bp_deltas_dir: directory for BroadPhrase delta tables.
        skip_check_for, skip_schema_check, skip_pretty_check_for: forwarded to
            ``are_new_tables_valid``.
        skip_time_check: skip the "preprod is too fresh" check.
        force: move even when the ``timestamps`` attributes are equal.
        no_next_preprod: never move next -> preprod.
        no_preprod_prod: never move preprod -> prod.

    Raises:
        AssertionError: if a table revision or recalc stamp changed while the move was prepared.
        Exception: ``'unexpected table contents'`` if the new tables failed validation
            (raised after the preprod -> prod move, if any, has been done).
    """
    src_client = YtClient(proxy=src_clus, token=os.environ.get('YT_TOKEN'))
    clients = {src_clus: src_client}
    for dst_clus in dst_clusters:
        clients[dst_clus] = YtClient(proxy=dst_clus, token=os.environ.get('YT_TOKEN'))
    all_clusters = [src_clus] + dst_clusters

    next_bm_table = '{}/{}'.format(next_dir, BROAD_MATCH_TABLE_NAME)
    next_bmwp_table = '{}/{}'.format(next_dir, BROAD_MATCH_WITH_PHRASE_TABLE_NAME)
    next_bmw_table = '{}/{}'.format(next_dir, BROAD_MINUSWORD_TABLE_NAME)
    next_bp_table = '{}/{}'.format(next_dir, BROAD_PHRASE_TABLE_NAME)
    next_tables = [next_bm_table, next_bmwp_table, next_bmw_table, next_bp_table]

    preprod_bm_table = '{}/{}'.format(preprod_dir, BROAD_MATCH_TABLE_NAME)
    preprod_bmwp_table = '{}/{}'.format(preprod_dir, BROAD_MATCH_WITH_PHRASE_TABLE_NAME)
    preprod_bmw_table = '{}/{}'.format(preprod_dir, BROAD_MINUSWORD_TABLE_NAME)
    preprod_bp_table = '{}/{}'.format(preprod_dir, BROAD_PHRASE_TABLE_NAME)
    preprod_tables = [preprod_bm_table, preprod_bmwp_table, preprod_bmw_table, preprod_bp_table]

    prod_bm_table = '{}/{}'.format(prod_dir, BROAD_MATCH_TABLE_NAME)
    prod_bmwp_table = '{}/{}'.format(prod_dir, BROAD_MATCH_WITH_PHRASE_TABLE_NAME)
    prod_bmw_table = '{}/{}'.format(prod_dir, BROAD_MINUSWORD_TABLE_NAME)
    prod_bp_table = '{}/{}'.format(prod_dir, BROAD_PHRASE_TABLE_NAME)
    prod_tables = [prod_bm_table, prod_bmwp_table, prod_bmw_table, prod_bp_table]

    # Collect revisions of all existing tables; they are re-checked right before copying.
    revisions = {src_clus: {}}
    for dst_clus in dst_clusters:
        revisions[dst_clus] = {}
    for clus in revisions:
        for table in next_tables + preprod_tables + prod_tables:
            if clients[clus].exists(table):
                revisions[clus][table] = clients[clus].get_attribute(table, 'revision')
    next_bm_table_recalc_stamp = src_client.get_attribute(next_bm_table, 'recalc_stamp')

    # Check preprod data is not too fresh
    if not skip_time_check:
        max_creation_time = max([
            datetime.strptime(clients[clus].get_attribute(table, 'creation_time'), '%Y-%m-%dT%H:%M:%S.%fZ')
            for table in preprod_tables
            for clus in all_clusters
        ])
        if datetime.utcnow() - timedelta(hours=1) < max_creation_time:
            sys.stderr.write('Too early\n')
            return

    # Check what to move
    unexpected_table_contents = False
    move_next_preprod = (force or srcs_exist_and_attr_ne(src_client, next_tables, preprod_tables, 'timestamps')) and not no_next_preprod
    move_preprod_prod = (force or srcs_exist_and_attr_ne(src_client, preprod_tables, prod_tables, 'timestamps')) and not no_preprod_prod
    sys.stderr.write('move_next_preprod {}, move_preprod_prod {}'.format(move_next_preprod, move_preprod_prod) + '\n')
    if not (move_next_preprod or move_preprod_prod):
        return

    # check, build delta, transfer
    tm_client = TransferManager(http_request_timeout=100000, retry_count=12)  # Increase timeout 10x, retry_count 2x
    task_ids = []

    if move_next_preprod:
        # Let's be optimistic and transfer asynchronously during check
        task_ids.extend([
            tm_client.add_task(src_clus, table, dst_clus, table, sync=False, params=TM_PARAMS)
            for table in next_tables
            for dst_clus in dst_clusters
        ])
        # Check new data does not differ too much
        if not are_new_tables_valid(
                src_clus=src_clus, src_client=src_client,
                bm_new_table=next_bm_table, bmw_new_table=next_bmw_table, bp_new_table=next_bp_table,
                bm_old_table=preprod_bm_table, bmw_old_table=preprod_bmw_table, bp_old_table=preprod_bp_table,
                skip_check_for=skip_check_for, skip_schema_check=skip_schema_check, skip_pretty_check_for=skip_pretty_check_for,
        ):
            # Keep going so that preprod -> prod can still happen; the failure is raised at the very end.
            move_next_preprod = False
            unexpected_table_contents = True

    if move_preprod_prod:
        bp_delta_table = build_broad_phrase_delta(
            yt_client=src_client, yt_cluster=src_clus,
            bp_prev_table=prod_bp_table, bp_new_table=preprod_bp_table,
            target_dir=bp_deltas_dir
        )
        tm_params = TM_PARAMS.copy()
        tm_params.update({
            'additional_attributes': ['expiration_time']
        })
        for dst_clus in dst_clusters:
            task_ids.append(tm_client.add_task(src_clus, bp_delta_table, dst_clus, bp_delta_table, sync=False, params=tm_params))

    wait_transfer_manager_tasks(tm_client, task_ids)

    # Copy data on both production clusters
    with nested(*[client.Transaction() for client in clients.values()]):
        if move_preprod_prod:
            # check preprod tables
            for clus in all_clusters:
                for table in preprod_tables + prod_tables:
                    if table in revisions[clus]:
                        current_revision = clients[clus].get_attribute(table, 'revision')
                        descr = '{} {} revision: was: {}; now: {}'.format(clus, table, revisions[clus][table], current_revision)
                        _require_unchanged(revisions[clus][table], current_revision, descr)
                if clients[clus].exists(prod_dir):
                    prod_dir_old = prod_dir + '_old'
                    print("copy on {} from {} to {}".format(clus, prod_dir, prod_dir_old))
                    clients[clus].copy(prod_dir, prod_dir_old, force=True)
                print("copy on {} from {} to {}".format(clus, preprod_dir, prod_dir))
                clients[clus].copy(preprod_dir, prod_dir, force=True)

        if move_next_preprod:
            # check next tables src cluster
            for table in next_tables:
                current_revision = src_client.get_attribute(table, 'revision')
                descr = '{} revision: was: {}, now: {}'.format(table, revisions[src_clus][table], current_revision)
                # Compare the very value that is reported in the message (no second lookup).
                _require_unchanged(revisions[src_clus][table], current_revision, descr)
            # check tables on both clusters
            for clus in all_clusters:
                # Transferred copies must carry the stamp of the recalculation that was validated.
                for table in next_tables:
                    current_stamp = clients[clus].get_attribute(table, 'recalc_stamp')
                    descr = '{} {} recalc_stamp: was {}, now: {}'.format(clus, table, next_bm_table_recalc_stamp, current_stamp)
                    _require_unchanged(next_bm_table_recalc_stamp, current_stamp, descr)
                for table in preprod_tables:
                    if table in revisions[clus]:
                        current_revision = clients[clus].get_attribute(table, 'revision')
                        descr = '{} {} revision: was: {}, now {}'.format(clus, table, revisions[clus][table], current_revision)
                        _require_unchanged(revisions[clus][table], current_revision, descr)
                if clients[clus].exists(preprod_dir):
                    preprod_dir_old = preprod_dir + '_old'
                    print("copy on {} from {} to {}".format(clus, preprod_dir, preprod_dir_old))
                    clients[clus].copy(preprod_dir, preprod_dir_old, force=True)
                print("copy on {} from {} to {}".format(clus, next_dir, preprod_dir))
                clients[clus].copy(next_dir, preprod_dir, force=True)

    if unexpected_table_contents:
        raise Exception('unexpected table contents')


def input2settings(input_name):
    """Return the settings dict for a publish input.

    A name registered in PUBLISH_INPUTS_SETTINGS yields its registered settings; any other
    value is taken to be the input path itself, with ``sd_written`` set to True.
    """
    try:
        return PUBLISH_INPUTS_SETTINGS[input_name]
    except KeyError:
        return {
            'path': input_name,
            'sd_written': True,
        }


def main():
    """Command line entry point: optionally promote existing tables, then optionally rebuild ``next``.

    The move (next -> preprod -> prod) runs first unless ``--no-move`` is given; the
    recalculation of the ``next`` directory runs afterwards unless ``--no-calc`` is given.
    Invalid option values terminate the program through ``argparser.error``.
    """
    argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argparser.add_argument('--src', type=str, required=True)  # src cluster
    argparser.add_argument('--dsts', type=str, action='append', required=True)  # dst clusters
    argparser.add_argument('--next-dir', type=str, default=home.bs.broad_match.next.write())
    argparser.add_argument('--preprod-dir', type=str, default=home.bs.broad_match.v8.write())
    argparser.add_argument('--prod-dir', type=str, default=home.bs.broad_match.v7.write())
    argparser.add_argument('--bp-deltas-dir', type=str)
    argparser.add_argument('--override-publish-inputs', type=str)
    argparser.add_argument('--skip-check-for', type=str)
    argparser.add_argument('--skip-pretty-check-for', type=str)
    argparser.add_argument('--skip-schema-check', action='store_true')
    argparser.add_argument('--skip-time-check', action='store_true')
    argparser.add_argument('--force', action='store_true')
    argparser.add_argument('--no-move', action='store_true')
    argparser.add_argument('--no-calc', action='store_true')
    argparser.add_argument('--no-preprod-prod', action='store_true')
    argparser.add_argument('--no-next-preprod', action='store_true')
    args = argparser.parse_args()

    # ToDo: Fix top formula
    default_skip_pairs = [
        '(7,700330)',
        '(7,700333)',
        '(8,800333)',
        '(1,1202)',
        '(11,100314)',
        '(11,100317)',
        '(11,100417)',
    ]
    default_skip_check_for = ','.join(sorted(set(default_skip_pairs)))
    if args.skip_check_for:
        skip_check_for = default_skip_check_for + ',' + args.skip_check_for
    else:
        skip_check_for = default_skip_check_for
    print("Going to skip check for '{}'".format(skip_check_for))
    # Both skip lists are pasted verbatim into YQL, so this validation must be an explicit
    # check: an `assert` would be stripped under `python -O`.
    if not re.match(r'^[()\d, ]+$', skip_check_for):
        argparser.error('skip-check-for format: [{}]'.format(skip_check_for))

    skip_pretty_check_for = args.skip_pretty_check_for if args.skip_pretty_check_for else str(NON_EXISTING_SD)
    print("Going to skip pretty check for '{}'".format(skip_pretty_check_for))
    if not re.match(r'^[\d, ]+$', skip_pretty_check_for):
        argparser.error('skip-pretty-check-for format: [{}]'.format(skip_pretty_check_for))

    # Either the explicitly requested inputs (names or raw paths) or every enabled known input.
    if args.override_publish_inputs:
        input_names = args.override_publish_inputs.split(',')
    else:
        input_names = [name for name, settings in PUBLISH_INPUTS_SETTINGS.items() if settings['enabled']]
    publish_inputs = [input2settings(name) for name in input_names]

    next_dir = args.next_dir
    preprod_dir = args.preprod_dir
    prod_dir = args.prod_dir

    if not args.no_move:
        print("move_next_preprod_prod to {} and {}".format(preprod_dir, prod_dir))
        if args.bp_deltas_dir is None:
            argparser.error('--bp-deltas-dir is required unless --no-move is given')
        if prod_dir is None:
            argparser.error('--prod-dir is required unless --no-move is given')
        move_next_preprod_prod(
            src_clus=args.src, dst_clusters=args.dsts,
            next_dir=next_dir, prod_dir=prod_dir, preprod_dir=preprod_dir,
            bp_deltas_dir=args.bp_deltas_dir,
            skip_check_for=skip_check_for, skip_schema_check=args.skip_schema_check,
            skip_time_check=args.skip_time_check, force=args.force,
            no_next_preprod=args.no_next_preprod, no_preprod_prod=args.no_preprod_prod,
            skip_pretty_check_for=skip_pretty_check_for,
        )
    if not args.no_calc:
        print("calc_next_from_publish_input to {}".format(next_dir))
        calc_next_from_publish_input(
            src_clus=args.src,
            publish_inputs=publish_inputs,
            next_dir=next_dir, preprod_dir=preprod_dir,
            force=args.force,
        )


# Run the pipeline only when the file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
