#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Site2vec builder

Usage:
  vectors_learning.py download [--last-n-days=<int>] [--with_apps=<bool>]
  vectors_learning.py train [--last-n-days=<int>] [--min-site-occurrence-count=<int>] [--cbow=<int>] [--sample=<float>] [--hs=<int>] [--negative=<int>] [--window=<int>] [--vector_size=<int>] [--alpha=<float>] [--use-start-vectors=<bool>] [--word2vec_debug=<int>]
  vectors_learning.py upload [--last-n-days=<int>] [--replace-old-vectors=<bool>]
  vectors_learning.py download_train_upload [--last-n-days=<int>] [--with_apps=<bool>] [--min-site-occurrence-count=<int>] [--cbow=<int>] [--sample=<float>] [--hs=<int>] [--negative=<int>] [--window=<int>] [--vector-size=<int>] [--alpha=<float>] [--use-start-vectors=<bool>] [--word2vec_debug=<int>] [--replace-old-vectors=<bool>]
  vectors_learning.py -h | --help

Options:
  download                             Download user sessions
  train                                Train site2vec app2vec embeddings
  upload                               Upload site2vec app2vec embeddings
  download_train_upload                Perform full cycle of site vectors building
  --last-n-days=<int>                  Download only last n days [default: 21]
  --with_apps=<bool>                    Learning with apps? True or False [default: False]
  --min-site-occurrence-count=<int>    Min site occurrence among all sessions [default: 50]
  --cbow=<int>                         Use the continuous bag of words model; default is skip-gram model [default: 0]
  --sample=<float>                     Threshold for high frequency sites down-sampling; useful value is 1e-5 [default: 0]
  --hs=<int>                           Use Hierarchical Softmax; 0 = not used [default: 0]
  --negative=<int>                     Number of negative examples; common values are 5 - 10 [default: 10]
  --window=<int>                       Context window size [default: 10]
  --vector-size=<int>                  Size of site vectors [default: 512]
  --alpha=<float>                      Training learning rate [default: 0.025]
  --use-start-vectors=<bool>           Is there start vectors? True or False [default: True]
  --word2vec_debug=<int>               Level of logging for word2vec [default: 2]
  --replace-old-vectors=<bool>         Replace old vectors with new ones? True or False [default: False]
"""  # noqa

from collections import Counter
import datetime
from functools import partial
import logging
import multiprocessing
import os
import random
import time

from docopt import docopt
import numpy as np
from retry import retry

from yt.wrapper.errors import YtHttpResponseError
from yt.wrapper import with_context, create_table_switch

from crypta.profile.lib import vector_helpers

from crypta.profile.utils.config import config
from crypta.profile.utils.yt_utils import get_yt_client
from crypta.profile.utils.yql_utils import query as yql_query


# Upper bound on `1 - dot(new_vector, old_vector)` for a host/app to be counted
# as "good" by check_host_app_distances.
GOOD_DISTANCE_BETWEEN_NEW_AND_OLD_VECTORS = 0.25
# Minimum share of "good" host/apps that replace_old_vectors_with_new_ones
# requires before it overwrites the old vectors.
ACCEPTABLE_RATE_OF_GOOD_HOSTS = 0.8


logger = logging.getLogger(__name__)

# YQL template filled in with str.format (hence the doubled braces around the
# lambda body). It joins per-device app lists (1 to 200 apps) to crypta ids and
# then to yandexuids, and writes one row per yandexuid with the unique,
# lower-cased list of apps, ordered by yandexuid.
# Placeholders: app_metrica_month, devid_cryptaid, cryptaid_yandexuid (inputs)
# and apps_yandexuid (output table, truncated before the insert).
get_yandexuid_apps_query_template = """
$apps_cryptaid = (
    SELECT app_metrica_month.apps AS apps, devid_cryptaid.crypta_id AS crypta_id
    FROM `{app_metrica_month}` AS app_metrica_month
    INNER JOIN `{devid_cryptaid}` AS devid_cryptaid
    ON app_metrica_month.id == devid_cryptaid.devid
    WHERE ListLength(app_metrica_month.apps) > 0 AND ListLength(app_metrica_month.apps) <= 200
);

$apps_yandexuid_raw = (
    SELECT apps_cryptaid.apps AS apps, cryptaid_yandexuid.yandexuid AS yandexuid
    FROM $apps_cryptaid AS apps_cryptaid
    INNER JOIN `{cryptaid_yandexuid}` AS cryptaid_yandexuid
    USING(crypta_id)
);

INSERT INTO `{apps_yandexuid}` WITH TRUNCATE
SELECT
    yandexuid,
    ListUniq(ListFlatMap(AGGREGATE_LIST(apps), ($x) -> {{RETURN ListFlatMap($x, String::ToLower);}})) AS apps
FROM $apps_yandexuid_raw
GROUP BY yandexuid
ORDER BY yandexuid;
"""


def get_recent_date_tables(sessions_tables_directory, number):
    """Return full paths of the `number` most recent date-named tables.

    Every node name listed under `sessions_tables_directory` is parsed as a
    `%Y-%m-%d` date; a name that does not match makes `strptime` raise
    ValueError. The result is ordered from the newest date to the oldest.
    """
    logger.info('Getting {number} last tables from {directory}'.format(number=number,
                                                                       directory=sessions_tables_directory))
    yt = get_yt_client()
    table_dates = [
        datetime.datetime.strptime(table_name, '%Y-%m-%d').date()
        for table_name in yt.list(sessions_tables_directory, absolute=False)
    ]
    table_dates.sort(reverse=True)
    # str(date) renders the ISO form, i.e. the original table name.
    return [os.path.join(sessions_tables_directory, str(day)) for day in table_dates[:number]]


@with_context
def add_users_apps_to_session_reducer(key, rows, context):
    """Reducer: interleave a user's apps into each of the user's sessions.

    Input table 0 carries an `apps` list, the other input carries `session`
    strings. When the first row of the group comes from table 0, every
    following session gets a randomly chosen app inserted between each pair
    of neighbouring tokens. Otherwise the sessions are emitted unchanged.
    """
    head = next(rows)
    if context.table_index == 0:
        apps_list = head['apps']
    else:
        # No apps row for this key: the first row is already a session.
        apps_list = None
        yield {
            'session': head['session']
        }

    for row in rows:
        if not apps_list:
            yield {
                'session': row['session']
            }
            continue

        tokens = row['session'].split()
        mixed = []
        # One random app after every token except the last one.
        for token in tokens[:-1]:
            mixed.append(token)
            mixed.append(random.choice(apps_list))
        mixed.extend(tokens[-1:])
        yield {
            'session': ' '.join(mixed),
        }


def mix_hosts_with_apps(data_source_name, data_source):
    """Rebuild the data source's session tables with user apps mixed in.

    Runs the yandexuid -> apps YQL query into a temporary table, removes every
    node under `<dir>/host_app_sessions`, and for each table listed in
    `data_source['session_tables']` reduces the apps table with
    `<dir>/sessions/<date>` into `<dir>/host_app_sessions/<date>`.
    `<dir>` is config.BAR_DIR for the 'bar' source and config.METRICS_DIR
    for any other name.

    Returns the same `data_source` dict with 'session_tables' replaced by the
    newly written tables.
    """
    logger.info('Mixing hosts with apps')

    yt = get_yt_client()

    if data_source_name == 'bar':
        directory = config.BAR_DIR
    else:
        directory = config.METRICS_DIR
    sessions_path = os.path.join(directory, 'sessions')
    sessions_with_user_apps_path = os.path.join(directory, 'host_app_sessions')
    mixed_session_tables = []

    with yt.TempTable() as apps_yandexuid:
        logger.info('Getting yandexuid apps')

        apps_query = get_yandexuid_apps_query_template.format(
            app_metrica_month=config.APP_METRICA_MONTH,
            devid_cryptaid=config.DEVID_CRYPTAID_TABLE,
            cryptaid_yandexuid=config.CRYPTAID_YANDEXUID_TABLE,
            apps_yandexuid=apps_yandexuid,
        )
        yql_query(query_string=apps_query, yt=yt, logger=logger)

        logger.info('Removing old hosts with apps session')
        for stale_table in yt.list(sessions_with_user_apps_path, absolute=True):
            yt.remove(stale_table)

        logger.info('Adding user apps to {data_source_name} hosts sessions'.format(data_source_name=data_source_name))
        for session_table in data_source['session_tables']:
            date_str = os.path.basename(session_table)
            mixed_table = os.path.join(sessions_with_user_apps_path, date_str)

            yt.create_empty_table(
                path=mixed_table,
                schema={
                    'session': 'string',
                },
            )

            # The apps table goes first: the reducer tells the inputs apart
            # by table index.
            yt.run_reduce(
                add_users_apps_to_session_reducer,
                [
                    apps_yandexuid,
                    os.path.join(sessions_path, date_str),
                ],
                mixed_table,
                reduce_by='yandexuid',
                spec={
                    'data_size_per_job': 2 * 1024 * 1024 * 1024,
                    'auto_merge': {'mode': 'relaxed'},
                },
            )

            mixed_session_tables.append(mixed_table)

        data_source['session_tables'] = mixed_session_tables
        return data_source


@retry(YtHttpResponseError, tries=10, delay=60, logger=logger)
def download_sessions_for_date(table, sessions_for_date_file_path):
    """Write the `session` column of a YT table to a local file, one per line.

    The target file is opened with mode 'w', so each retry attempt starts
    from an empty file. Returns the row count reported by YT for `table`.
    """
    yt = get_yt_client()
    row_count = yt.row_count(table)
    message = 'Downloading {count} sessions from {table} to {path}'.format(
        count=row_count,
        table=table,
        path=sessions_for_date_file_path,
    )
    logger.info(message)
    with open(sessions_for_date_file_path, 'w') as outfile:
        outfile.writelines(
            '{session}\n'.format(session=row['session'])
            for row in yt.read_table(table)
        )
    return row_count


def download_sessions(sessions_tables, local_sessions_file):
    """Download several session tables into one local file.

    Each table is first written to `<local_sessions_file>.<table basename>`
    and only then appended to the combined file; the per-table file is removed
    afterwards, whether or not the download succeeded. Returns the total
    number of rows reported for all tables.
    """
    target_directory = os.path.dirname(local_sessions_file)
    if not os.path.exists(target_directory):
        os.makedirs(target_directory)

    started_at = time.time()
    total_row_count = 0

    with open(local_sessions_file, 'w') as combined_file:
        for source_table in sessions_tables:
            per_date_path = '{path}.{date}'.format(
                path=local_sessions_file,
                date=os.path.basename(source_table),
            )
            try:
                total_row_count += download_sessions_for_date(source_table, per_date_path)

                # double FS write to avoid exceptions during yt table read like
                # Chunk cf7-23736c-1b6e0066-87c35dd4 is unavailable
                # and sessions duplication in case of such exceptions
                with open(per_date_path) as per_date_file:
                    combined_file.writelines(per_date_file)
            finally:
                if os.path.exists(per_date_path):
                    os.remove(per_date_path)

    elapsed_seconds = time.time() - started_at
    logger.info('Downloading {rows} rows to {path} was finished in {seconds:.2f} seconds'.format(
        rows=total_row_count,
        path=local_sessions_file,
        seconds=elapsed_seconds))
    return total_row_count


def train_site2vec_app2vec(
    word2vec_bin, session_file, vocabulary_file, vectors_file, vector_size, min_count,
    cbow, sample, hs, negative, window, alpha, start_vectors_file_path, word2vec_debug
):
    """Run the word2vec binary on a local sessions file.

    Creates the directory of `vectors_file` when it is missing, builds the
    command line (binary output, one thread per CPU reported by
    multiprocessing) and runs it with subprocess.check_call, which raises
    CalledProcessError on a non-zero exit code.
    """
    import subprocess

    vectors_directory = os.path.dirname(vectors_file)
    if not os.path.exists(vectors_directory):
        os.makedirs(vectors_directory)

    thread_count = multiprocessing.cpu_count()
    flags = (
        ('-train', session_file),
        ('-output', vectors_file),
        ('-binary', '1'),
        ('-size', str(vector_size)),
        ('-min-count', str(min_count)),
        ('-threads', str(thread_count)),
        ('-cbow', str(cbow)),
        ('-hs', str(hs)),
        ('-sample', str(sample)),
        ('-negative', str(negative)),
        ('-window', str(window)),
        ('-verbose', '1'),
        ('-alpha', str(alpha)),
        ('-read-vocab', vocabulary_file),
        ('-start-vectors', start_vectors_file_path),
        ('-debug', str(word2vec_debug)),
    )
    args = [word2vec_bin]
    for flag, value in flags:
        args.append(flag)
        args.append(value)

    logger.info('Training started. Session file = {path}. Size = {size}. Arguments = {args}'.format(
        path=session_file,
        size=os.path.getsize(session_file),
        args=args))

    started_at = time.time()
    subprocess.check_call(args)
    logger.info('Training finished in %.1f seconds' % (time.time() - started_at))


def get_site_counter(record):
    """Mapper: emit the occurrence count of every token of one session.

    `record['session']` is split on single spaces; for each distinct token a
    row `{'host_app': token, 'count': occurrences}` is yielded.
    """
    token_counts = Counter(record['session'].split(' '))
    # .items() instead of the Python-2-only .iteritems(): same rows, and the
    # mapper keeps working under Python 3.
    for site, count in token_counts.items():
        yield {'host_app': site, 'count': count}


def merge_site_counters(key, records):
    """Reducer/combiner: sum the `count` of all records sharing one `host_app`."""
    total = sum(record['count'] for record in records)
    yield {'host_app': key['host_app'], 'count': total}


def start_vectors_random_initialization(record, vector_size):
    """Mapper: attach a random start vector to a vocabulary record.

    The vector has `vector_size` float32 components drawn uniformly from
    [-0.5, 0.5) and divided by `vector_size`. It is stored under
    'start_vector' as raw bytes. The incoming record is modified in place
    and yielded.
    """
    # The same initialization as in origin word2vec
    start_vector = np.array(
        np.random.uniform(-0.5, 0.5, vector_size) / vector_size, dtype=np.float32)
    # tobytes() replaces the deprecated tostring() alias (removed from recent
    # NumPy releases); the produced bytes are the same.
    record['start_vector'] = start_vector.tobytes()
    yield record


@with_context
class GetStartVectors(object):
    """Reducer: give every vocabulary record a start vector.

    Input table 0 holds existing vectors, the other input holds vocabulary
    records. For each key, a vocabulary record receives the vector seen in
    table 0 for that key, or a random vector when there was none.
    """

    def __init__(self, vector_size, vector_column_name):
        # Number of float32 components of a randomly initialized vector.
        self.vector_size = vector_size
        # Column of table 0 that stores the serialized vector.
        self.vector_column_name = vector_column_name

    def _random_start_vector(self):
        """Return a random float32 vector serialized to raw bytes."""
        # The same initialization as in origin word2vec
        vector = np.array(
            np.random.uniform(-0.5, 0.5, self.vector_size) / self.vector_size, dtype=np.float32)
        # tobytes() replaces the deprecated tostring() alias (removed from
        # recent NumPy releases); the produced bytes are the same.
        return vector.tobytes()

    def __call__(self, key, records, context):
        start_vector = None
        for record in records:
            if context.table_index == 0:
                # Remember the known vector; rows of table 0 are not emitted.
                start_vector = record[self.vector_column_name]
                continue

            if start_vector is None:
                record['start_vector'] = self._random_start_vector()
            else:
                record['start_vector'] = start_vector
            yield record


def create_sessions_vocabulary_table(source_tables, destination_table, use_start_vectors,
                                     vectors_table, vector_column_name, vector_size):
    """Build the vocabulary table (token, count, start vector) on YT.

    Token counts over `source_tables` are computed with a map-reduce into a
    temporary table. `destination_table` is recreated and filled either with
    start vectors taken from `vectors_table` (when `use_start_vectors` is the
    string 'True') or with random start vectors ('False'), and is finally
    sorted by 'count'. Any other value of `use_start_vectors` fails the
    assertion.
    """
    logger.info('Creating sessions vocabulary from {source} to {target}'.format(source=', '.join(source_tables),
                                                                                target=destination_table))

    assert use_start_vectors in ('True', 'False')
    reuse_existing_vectors = use_start_vectors == 'True'

    yt = get_yt_client()

    with yt.TempTable(prefix='vectors_learning_get_raw_site_counter') as raw_site_counter_table:
        yt.run_map_reduce(
            mapper=get_site_counter,
            reduce_combiner=merge_site_counters,
            reducer=merge_site_counters,
            source_table=source_tables,
            destination_table=raw_site_counter_table,
            reduce_by='host_app',
        )

        vocabulary_schema = {
            'count': 'uint64',
            'host_app': 'string',
            'start_vector': 'string',
        }
        yt.create_empty_table(path=destination_table, schema=vocabulary_schema)

        if not reuse_existing_vectors:
            logger.info('Initializing start vectors with random values')

            random_initializer = partial(start_vectors_random_initialization, vector_size=vector_size)
            yt.run_map(
                random_initializer,
                source_table=raw_site_counter_table,
                destination_table=destination_table,
            )
        else:
            logger.info('Getting start vectors from {source}'.format(source=vectors_table))
            # The reduce below needs the counters sorted by the join key.
            yt.run_sort(raw_site_counter_table, sort_by='host_app')

            start_vectors_reducer = GetStartVectors(vector_size=vector_size, vector_column_name=vector_column_name)
            yt.run_reduce(
                start_vectors_reducer,
                source_table=[
                    vectors_table,
                    raw_site_counter_table,
                ],
                destination_table=destination_table,
                reduce_by='host_app',
            )

        yt.run_sort(destination_table, sort_by='count')


def download_sessions_vocabulary(sessions_vocabulary_table, vocabulary_file_path,
                                 start_vectors_file_path, min_site_occurance):
    """Dump the vocabulary table into a vocabulary file and a start-vectors file.

    Only records with `count >= min_site_occurance` are kept. Both files are
    written in the order of Counter.most_common() (highest count first): the
    vocabulary file gets `<host_app> <count>` lines and the start-vectors file
    gets the matching `start_vector` value followed by a newline.
    """
    yt = get_yt_client()

    logger.info('Downloading from {source} vocabulary to {vocab_path} and start vectors to {vectors_path}'.format(
        source=sessions_vocabulary_table, vocab_path=vocabulary_file_path, vectors_path=start_vectors_file_path))

    occurrence_counts = Counter()
    vector_by_host_app = {}

    with open(vocabulary_file_path, 'w') as vocab_outfile, open(start_vectors_file_path, 'w') as vectors_outfile:
        for record in yt.read_table(sessions_vocabulary_table):
            if record['count'] < min_site_occurance:
                continue
            host_app = record['host_app']
            occurrence_counts[host_app] = record['count']
            vector_by_host_app[host_app] = record['start_vector']

        for host_app, count in occurrence_counts.most_common():
            vocab_outfile.write('{host_app} {count}\n'.format(host_app=host_app, count=count))
            vectors_outfile.write('{vector}\n'.format(vector=vector_by_host_app[host_app]))


def upload_vectors(vectors_file, yt_path):
    """Upload vectors from a word2vec binary output file into a sorted YT table.

    The file starts with a header line `<vocab_size> <size>`; each entry that
    follows is a word terminated by a space and then `size` float32 values
    (4 bytes each). Newline bytes met while reading a word are skipped.
    `yt_path` is recreated with columns `host_app` and `vector`, filled, and
    sorted by `host_app`.

    Raises:
        EOFError: the file ends in the middle of a word.
    """
    yt = get_yt_client()

    yt.create_empty_table(
        path=yt_path,
        schema={
            'host_app': 'string',
            'vector': 'string',
        },
    )

    # Binary mode: the payload is raw float32 data, so no newline translation
    # or text decoding may be applied to it.
    with open(vectors_file, 'rb') as fin:
        vocab_size, size = map(int, fin.readline().strip().split())
        vector_byte_length = size * 4  # 4 bytes for float32

        def read_word():
            """Read bytes up to the next space and return them joined."""
            chars = []
            while True:
                ch = fin.read(1)
                if not ch:
                    # read() returns an empty value at EOF; without this check
                    # a truncated file would make the loop spin forever.
                    raise EOFError('Unexpected end of vectors file {path}'.format(path=vectors_file))
                if ch == b' ':
                    return b''.join(chars)
                if ch != b'\n':
                    chars.append(ch)

        def vectors_generator():
            # Emits vocab_size - 1 entries, as the original loop did. The
            # original comment says this skips the first vector of </s>.
            # NOTE(review): the code reads the FIRST vocab_size - 1 entries of
            # the file; confirm this matches the word2vec binary's output.
            entries_left = vocab_size - 1
            while entries_left > 0:
                entries_left -= 1
                word = read_word()
                yield {
                    'host_app': word,
                    'vector': fin.read(vector_byte_length),
                }

        logger.info('site2vec app2vec vectors uploading started')
        yt.write_table(yt_path, vectors_generator())

        logger.info('site2vec app2vec vectors sorting started')
        yt.run_sort(yt_path, sort_by='host_app')


@retry(YtHttpResponseError, tries=10, delay=60, logger=logger)
def download_site2vec_app2vec(yt, site2vec_app2vec_table):
    """Read a vectors table into a dict `host_app -> features`.

    Features are produced by vector_helpers.vector_row_to_features for each
    row. The whole read is retried on YtHttpResponseError.
    """
    return {
        record['host_app']: vector_helpers.vector_row_to_features(record)
        for record in yt.read_table(site2vec_app2vec_table)
    }


def make_site2vec_app2vec_matrix(site_vectors):
    """Stack the vectors of a `host_app -> vector` dict into a numpy array.

    Rows follow the sorted order of the dict keys, so two dicts with the same
    key set produce row-aligned matrices.
    """
    ordered_rows = [site_vectors[host_app] for host_app in sorted(site_vectors)]
    return np.array(ordered_rows)


def check_host_app_distances(yt, new_vectors_table, old_vectors_table):
    """Return the share of host/apps whose new vector stays close to the old one.

    Only host/apps present in both tables are compared. The distance of a
    host/app is `1 - dot(new_vector, old_vector)`; it is "good" when it does
    not exceed GOOD_DISTANCE_BETWEEN_NEW_AND_OLD_VECTORS.
    NOTE(review): this equals the cosine distance only if
    vector_helpers.vector_row_to_features returns unit-length vectors -
    confirm.

    Returns:
        A float in [0, 1]; 0.0 when the two tables share no host/apps, so
        that nothing can be verified.
    """
    new_host_app_vectors = download_site2vec_app2vec(yt, new_vectors_table)
    old_host_app_vectors = download_site2vec_app2vec(yt, old_vectors_table)

    # Intersect the key sets instead of popping from the dicts while iterating
    # over .keys(), which only works where keys() returns a list (Python 2).
    common_host_apps = set(new_host_app_vectors) & set(old_host_app_vectors)
    if not common_host_apps:
        # Empty matrices would fail in np.sum(axis=1) or divide by zero.
        return 0.0

    new_matr = make_site2vec_app2vec_matrix(
        {host_app: new_host_app_vectors[host_app] for host_app in common_host_apps})
    old_matr = make_site2vec_app2vec_matrix(
        {host_app: old_host_app_vectors[host_app] for host_app in common_host_apps})

    # Both matrices are ordered by host_app, so rows are aligned.
    distances_vec = 1 - np.sum(new_matr * old_matr, axis=1)
    good_count = np.count_nonzero(distances_vec <= GOOD_DISTANCE_BETWEEN_NEW_AND_OLD_VECTORS)
    return float(good_count) / len(distances_vec)


@with_context
def divide_host_app_vectors(key, rows, context):
    """Reducer: label each vector as app or host and route it to output tables.

    Inputs by table index: 0 - rows whose `count` is stored as the app count,
    1 - rows whose `count` is stored as the host count, any other index -
    vector rows. A vector row is treated as an app when the app count is set
    (> -1) and not smaller than the host count, as a host when the host count
    is set and larger than the app count, and is dropped otherwise.

    Outputs: table 0 gets `host_app`/`vector`/`type` rows, table 1 gets
    `app`/`vector` rows, table 2 gets `host`/`vector` rows.
    """
    combined_output_index, app_output_index, host_output_index = 0, 1, 2
    app_count = host_count = -1

    for row in rows:
        input_index = context.table_index
        if input_index == 0:
            app_count = row['count']
            continue
        if input_index == 1:
            host_count = row['count']
            continue

        if app_count >= host_count and app_count > -1:
            kind, typed_output_index = 'app', app_output_index
        elif host_count > app_count and host_count > -1:
            kind, typed_output_index = 'host', host_output_index
        else:
            # Neither count is known for this key: emit nothing.
            continue

        yield create_table_switch(combined_output_index)
        yield {
            'host_app': key['host_app'],
            'vector': row['vector'],
            'type': kind,
        }

        # The per-type table names its key column after the type itself.
        yield create_table_switch(typed_output_index)
        yield {
            kind: key['host_app'],
            'vector': row['vector'],
        }


def rename_mapper(row, old_field_name, new_field_name):
    """Mapper: move `row[old_field_name]` to `row[new_field_name]` in place.

    The row is yielded untouched when both names are equal.
    """
    if old_field_name != new_field_name:
        row[new_field_name] = row.pop(old_field_name)
    yield row


def replace_old_vectors_with_new_ones(bar_vectors_yt_path):
    """Publish freshly trained vectors after checking them against the old ones.

    The share of host/apps whose new vector is close to the old one is
    computed first; if it is below ACCEPTABLE_RATE_OF_GOOD_HOSTS a ValueError
    is raised and nothing is overwritten. Otherwise the app and bar IDF tables
    are renamed to a common `host_app` key, the three production vector tables
    are recreated and filled by divide_host_app_vectors, sorted, stamped with
    a `generate_date` attribute (the 15th of the current month), and the
    combined table is also copied into its dated folder.
    """
    yt = get_yt_client()

    logger.info('Start replacing old vectors with new ones')

    with yt.TempTable() as renamed_app_idf, yt.TempTable() as renamed_bar_idf:
        logger.info('Checking that new vectors are not so far from old ones')

        percentage_of_good_host_apps = check_host_app_distances(
            yt=yt,
            new_vectors_table=bar_vectors_yt_path,
            old_vectors_table=config.SITE2VEC_APP2VEC_VECTORS_TABLE,
        )

        verification_passed = percentage_of_good_host_apps >= ACCEPTABLE_RATE_OF_GOOD_HOSTS
        if not verification_passed:
            logger.info('The verification has FAILED, the percentage of good host_apps is {score}'.format(
                score=percentage_of_good_host_apps))
            raise ValueError

        logger.info('The verification is OK, the percentage of good host_apps is {score}'.format(
            score=percentage_of_good_host_apps))

        # Bring both IDF tables to the `host_app` key used by the vectors.
        idf_renames = (
            ('app', config.APP_IDF_TABLE, renamed_app_idf),
            ('host', config.YANDEXUID_BAR_IDF_TABLE, renamed_bar_idf),
        )
        for old_field_name, idf_table, renamed_table in idf_renames:
            yt.run_map(
                partial(rename_mapper, old_field_name=old_field_name, new_field_name='host_app'),
                idf_table,
                renamed_table,
            )
            yt.run_sort(renamed_table, sort_by='host_app')

        yt.create_empty_table(
            path=config.SITE2VEC_APP2VEC_VECTORS_TABLE,
            schema={
                'host_app': 'string',
                'vector': 'string',
                'type': 'string',
            },
        )

        # Per-type tables: (key column, table path).
        typed_tables = (
            ('app', config.APP2VEC_VECTORS_TABLE),
            ('host', config.SITE2VEC_VECTORS_TABLE),
        )
        for key_column, typed_table in typed_tables:
            yt.create_empty_table(
                path=typed_table,
                schema={
                    key_column: 'string',
                    'vector': 'string',
                },
                compression=None,
                erasure=False,
            )

        # Input and output order must match the table indices used by
        # divide_host_app_vectors.
        yt.run_reduce(
            divide_host_app_vectors,
            [
                renamed_app_idf,
                renamed_bar_idf,
                bar_vectors_yt_path,
            ],
            [
                config.SITE2VEC_APP2VEC_VECTORS_TABLE,
                config.APP2VEC_VECTORS_TABLE,
                config.SITE2VEC_VECTORS_TABLE,
            ],
            reduce_by='host_app',
        )

        date_to_set = str(datetime.date.today().replace(day=15))

        yt.run_sort(config.SITE2VEC_APP2VEC_VECTORS_TABLE, sort_by='host_app')
        yt.set_attribute(config.SITE2VEC_APP2VEC_VECTORS_TABLE, 'generate_date', date_to_set)
        yt.copy(
            config.SITE2VEC_APP2VEC_VECTORS_TABLE,
            os.path.join(config.SITE2VEC_APP2VEC_VECTORS_FOLDER, date_to_set),
            force=True,
        )

        for key_column, typed_table in typed_tables:
            yt.run_sort(typed_table, sort_by=key_column)
            yt.set_attribute(typed_table, 'generate_date', date_to_set)


def get_session_tables(data_sources):
    """Fill 'session_tables' of every per-source dict in `data_sources`.

    Non-dict values (global settings) are skipped. The same `data_sources`
    object is updated in place and returned.
    """
    for data_source in data_sources.values():
        if not isinstance(data_source, dict):
            continue
        data_source['session_tables'] = get_recent_date_tables(
            data_source['sessions_yt_directory'],
            data_source['last_n_days'],
        )
    return data_sources


def download(data_sources):
    """Download the sessions of every per-source dict to its local file.

    When `data_sources['with_apps']` is the string 'True', the session tables
    of a source are first replaced by host/app mixed tables. Any value other
    than 'True' or 'False' fails the assertion.
    """
    for data_source_name, data_source in data_sources.items():
        if not isinstance(data_source, dict):
            continue

        with_apps = data_sources['with_apps']
        assert with_apps in ('True', 'False')
        if with_apps == 'True':
            data_source.update(mix_hosts_with_apps(data_source_name, data_source))

        download_sessions(
            sessions_tables=data_source['session_tables'],
            local_sessions_file=data_source['sessions_file'],
        )


def train(data_sources, not_nirvana_launch=True):
    """Prepare the vocabulary and train vectors for every per-source dict.

    For each source the vocabulary table is built on YT and downloaded with
    its start vectors. The word2vec binary is run locally only when
    `not_nirvana_launch` is true. Start vectors come from the combined
    site2vec/app2vec table when `data_sources['with_apps']` is the string
    'True', otherwise from the site2vec table.
    """
    for data_source in data_sources.values():
        if not isinstance(data_source, dict):
            continue

        if data_sources['with_apps'] == 'True':
            start_vectors_table = config.SITE2VEC_APP2VEC_VECTORS_TABLE
        else:
            start_vectors_table = config.SITE2VEC_VECTORS_TABLE

        create_sessions_vocabulary_table(
            source_tables=data_source['session_tables'],
            destination_table=data_source['site_counter_table'],
            use_start_vectors=data_source['use_start_vectors'],
            vectors_table=start_vectors_table,
            vector_column_name='vector',
            vector_size=data_source['vector_size'],
        )
        download_sessions_vocabulary(
            sessions_vocabulary_table=data_source['site_counter_table'],
            vocabulary_file_path=data_source['vocabulary_file'],
            start_vectors_file_path=data_source['start_vectors_file'],
            min_site_occurance=data_source['min_site_size'],
        )

        if not not_nirvana_launch:
            continue

        train_site2vec_app2vec(
            word2vec_bin=data_sources['word2vec'],
            session_file=data_source['sessions_file'],
            vocabulary_file=data_source['vocabulary_file'],
            vectors_file=data_source['vectors_file'],
            vector_size=data_source['vector_size'],
            min_count=data_source['min_site_size'],
            cbow=data_source['cbow'],
            sample=data_source['sample'],
            hs=data_source['hs'],
            negative=data_source['negative'],
            window=data_source['window'],
            alpha=data_source['alpha'],
            start_vectors_file_path=data_source['start_vectors_file'],
            word2vec_debug=data_sources['word2vec_debug'],
        )


def upload(data_sources):
    """Upload the trained vectors of every per-source dict to YT.

    Afterwards, when `data_sources['replace_old_vectors']` is the string
    'True', the production vectors are replaced using the 'bar' source's
    vectors table.
    """
    per_source_settings = [value for value in data_sources.values() if isinstance(value, dict)]
    for data_source in per_source_settings:
        upload_vectors(
            vectors_file=data_source['vectors_file'],
            yt_path=data_source['vectors_yt_path'],
        )

    should_replace = data_sources['replace_old_vectors'] == 'True'
    if should_replace:
        replace_old_vectors_with_new_ones(data_sources['bar']['vectors_yt_path'])


def download_train_upload(data_sources):
    """Run the full cycle: download sessions, train vectors, upload them."""
    for stage in (download, train, upload):
        stage(data_sources)


def _build_data_source(name, sessions_yt_directory, vectors_yt_path, site_counter_table,
                       min_site_size, training_settings):
    """Assemble the settings dict of one data source ('metrics' or 'bar').

    Local file names are derived from `name`; `training_settings` holds the
    options shared by all data sources and is copied into the result.
    """
    data_source = {
        'sessions_yt_directory': sessions_yt_directory,
        'sessions_file': os.path.join(config.SESSIONS_LOCAL_DIRECTORY, name),
        'vocabulary_file': os.path.join(config.VOCABULARY_LOCAL_DIRECTORY, name),
        'start_vectors_file': os.path.join(config.VECTORS_LOCAL_DIRECTORY, 'start_{name}.bin'.format(name=name)),
        'vectors_file': os.path.join(config.VECTORS_LOCAL_DIRECTORY, '{name}.bin'.format(name=name)),
        'vectors_yt_path': vectors_yt_path,
        'site_counter_table': site_counter_table,
        'min_site_size': min_site_size,
    }
    data_source.update(training_settings)
    return data_source


def main():
    """Parse the command line and run the requested command.

    Builds the `data_sources` dict: one settings dict per data source
    ('metrics', 'bar') plus global non-dict settings ('with_apps', 'word2vec',
    'word2vec_debug', 'replace_old_vectors'), resolves the session tables of
    each source and dispatches to download / train / upload /
    download_train_upload.
    """
    arguments = docopt(__doc__)

    # The `train` usage line spells the option `--vector_size`, while the
    # Options section and the other commands use `--vector-size`. Accept
    # either spelling so a value given to `train` is not silently replaced
    # by the `--vector-size` default.
    vector_size = int(arguments.get('--vector_size') or arguments['--vector-size'])
    min_site_occurrence_count = int(arguments['--min-site-occurrence-count'])

    training_settings = {
        'last_n_days': int(arguments['--last-n-days']),
        'use_start_vectors': arguments['--use-start-vectors'],
        'vector_size': vector_size,
        'cbow': int(arguments['--cbow']),
        'sample': float(arguments['--sample']),
        'hs': int(arguments['--hs']),
        'negative': int(arguments['--negative']),
        'window': int(arguments['--window']),
        'alpha': float(arguments['--alpha']),
    }

    data_sources = {
        'metrics': _build_data_source(
            name='metrics',
            sessions_yt_directory=config.METRICS_SESSIONS_DIRECTORY,
            vectors_yt_path=config.METRICS_VECTORS_TABLE,
            site_counter_table=config.METRICS_SITE_COUNTER_TABLE,
            # The metrics source uses a six times higher occurrence threshold.
            min_site_size=6 * min_site_occurrence_count,
            training_settings=training_settings,
        ),
        'bar': _build_data_source(
            name='bar',
            sessions_yt_directory=config.BAR_SESSIONS_DIRECTORY,
            vectors_yt_path=config.BAR_VECTORS_TABLE,
            site_counter_table=config.BAR_SITE_COUNTER_TABLE,
            min_site_size=min_site_occurrence_count,
            training_settings=training_settings,
        ),
        # Kept as the 'True'/'False' string: download() and train() compare it
        # with those strings, and int('False') would raise ValueError.
        'with_apps': arguments['--with_apps'],
        'word2vec': config.LOCAL_WORD2VEC_BIN,
        'word2vec_debug': int(arguments['--word2vec_debug']),
        'replace_old_vectors': arguments['--replace-old-vectors'],
    }

    # Fills 'session_tables' of every data source in place.
    data_sources = get_session_tables(data_sources)

    if arguments['download']:
        download(data_sources)
    elif arguments['train']:
        train(data_sources)
    elif arguments['upload']:
        upload(data_sources)
    elif arguments['download_train_upload']:
        download_train_upload(data_sources)


# Run the command-line interface only when the file is executed as a script.
if __name__ == "__main__":
    main()
