from yt.wrapper.client import Yt
import yt
from numpy import random
import itertools
import os

def compute_query_features(query):
    """Return the text-derived features of a search query.

    The result is a list ``[query_length, query_words,
    query_mean_word_length, query_short_words]`` where:

    * ``query_length`` is ``len(query)`` (whitespace included);
    * ``query_words`` is the number of whitespace-separated tokens;
    * ``query_mean_word_length`` is ``query_length / query_words`` -- note
      that the numerator is the full query length, separators included;
    * ``query_short_words`` is the number of tokens shorter than 3 characters.
    """
    words = query.split()
    query_length = len(query)
    query_words = len(words)
    # An empty or whitespace-only query has no tokens; report a mean of 0.0
    # rather than raising ZeroDivisionError and aborting the whole export.
    if query_words:
        query_mean_word_length = query_length * 1.0 / query_words
    else:
        query_mean_word_length = 0.0
    query_short_words = sum(1 for word in words if len(word) < 3)
    return [query_length, query_words, query_mean_word_length, query_short_words]


def build_row(line, base_features):
    """Build one dataset row from a table record.

    ``line`` is a mapping that must contain ``'ForecastedClicks'``,
    ``'Query'`` and every key listed in ``base_features``.

    The returned list is laid out as: target, the ``base_features`` values in
    order, then the values produced by ``compute_query_features``.
    """
    # Binary label: 0 when fewer than 0.8 clicks are forecasted, 1 otherwise.
    # NOTE(review): the positive-class value was missing in the original
    # expression; 1 is assumed -- confirm against the training setup.
    target = 0 if float(line['ForecastedClicks']) < 0.8 else 1
    row = [target]
    row.extend(line[key] for key in base_features)
    row.extend(compute_query_features(line['Query']))
    return row


if __name__ == "__main__":

    # Upper bound on the number of table rows to export.
    max_elements = 2*(10**6)

    client = Yt(proxy='hahn.yt.yandex.net', token=os.getenv("YT_TOKEN"))

    test_forecasted_table = "//home/webmaster/test/searchqueries/direct/forecasted_groups"

    base_header_features = ['UniqRivalsByClicks', 'UniqRivalsByShows', 'RivalsShows', 'RivalsClicks', 'RegionId', 'Position']
    # Column layout of every written row. It is kept for reference only: no
    # header line is written to the TSV files.
    header = ['target'] + base_header_features + ['query_length', 'query_words', 'query_mean_word_length', 'query_short_words']

    # Column description: column 0 is the target and column 5 (RegionId in the
    # layout above) is categorical.
    # NOTE(review): presumably the CatBoost column-description format -- confirm.
    with open("column_description_file", 'w') as cd:
        cd.write("0" + "\t" + "Target" + "\n")
        cd.write("5" + "\t" + "Categ" + "\n")

    table_rows = itertools.islice(
        client.read_table(test_forecasted_table, raw=False, format=yt.wrapper.DsvFormat()),
        0,
        max_elements,
    )

    # Context managers guarantee both datasets are flushed and closed even if
    # reading the table or building a row fails part-way through.
    with open("forecasted_data_test.tsv", "w") as dataset_local_file_test, \
            open("forecasted_data_train.tsv", "w") as dataset_local_file_train:
        for line in table_rows:
            row_acc = build_row(line, base_header_features)
            serialized_row = "\t".join([str(value) for value in row_acc]) + "\n"
            # Random split: roughly 75% of the rows go to train, the rest to test.
            if random.rand() < 0.75:
                dataset_local_file_train.write(serialized_row)
            else:
                dataset_local_file_test.write(serialized_row)
