#!/usr/bin/env python
# -*- coding: utf8 -*-

import logging

from yql.api.v1.client import YqlClient

import util as u
from itertools import izip
import json

# Number of converted rows accumulated in memory before they are flushed
# with a single insert_rows call by the script body below.
BATCH_SIZE = 20000

# Configure the root logger: INFO level, "[LEVEL timestamp]: message" lines.
# The timestamp uses the 24-hour clock (%H); a 12-hour %I without an AM/PM
# marker makes morning and evening entries indistinguishable.
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s %(asctime)s]: %(message)s',
    datefmt='%m-%d-%Y %H:%M:%S'
)


# Maps a schema type name (the "type" value of a schema entry) to the Python
# callable used to convert a result value for a column of that type.
# get_column2function looks types up here with F_MAPPING[...], so a type name
# that is not listed raises KeyError.
# `long` is the Python 2 arbitrary-precision integer builtin; this module
# therefore targets Python 2.
# NOTE(review): bool() of any non-empty string is True, so a textual value
# such as "false" would convert to True - confirm the query never yields
# booleans as strings.
F_MAPPING = {
    "string": str,
    "int": long,
    "uint64": long,
    "int64": long,
    "boolean": bool
}


def get_fields_type_dict(cluster, token, table):
    """Return a ``{column name: type name}`` dict read from a table's schema.

    The schema is fetched with ``yt.get(table + "/@schema")``; every schema
    entry contributes ``entry["name"] -> entry["type"]``.

    Side effect: the module-global ``yt.config`` is pointed at ``cluster``
    and ``token`` and is left that way after the call.

    This is best-effort: if reading or parsing the schema fails, the error
    is logged with its traceback and whatever was collected so far (possibly
    an empty dict) is returned instead of raising.

    :param cluster: proxy URL assigned to ``yt.config["proxy"]["url"]``.
    :param token: token assigned to ``yt.config["token"]``.
    :param table: table path; ``"/@schema"`` is appended to it.
    :return: dict mapping column names to schema type names.
    """
    import yt.wrapper as yt

    field_type = {}
    try:
        yt.config["proxy"]["url"] = cluster
        yt.config["token"] = token

        schema = yt.get(table + "/@schema")
        for item in schema:
            field_type[item["name"]] = item["type"]
    except Exception:
        # Keep the original best-effort contract, but leave a trace: an empty
        # or partial result here otherwise surfaces later as an unexplained
        # lookup failure when converters are chosen for the columns.
        logging.exception("Failed to read schema of table %s", table)
    return field_type


def get_column2function(column_names, fields_type_dict):
    """Pair every column name with the converter for its schema type.

    :param column_names: iterable of column names, in result order.
    :param fields_type_dict: dict mapping column names to schema type names.
    :return: list of ``(column_name, converter)`` tuples in the order of
        ``column_names``; converters are taken from ``F_MAPPING``.
    :raises KeyError: if a column is missing from ``fields_type_dict`` (its
        type is then ``None``) or its type has no entry in ``F_MAPPING``.
        The message names the offending column and type.
    """
    result = []
    for column_name in column_names:
        fields_type = fields_type_dict.get(column_name)
        try:
            type_function = F_MAPPING[fields_type]
        except KeyError:
            # Same exception type as the bare lookup, but say which column
            # failed instead of reporting just "KeyError: None".
            raise KeyError(
                "no converter for column %r with schema type %r"
                % (column_name, fields_type)
            )
        result.append((column_name, type_function))
    return result


def apply_transformation(in_param):
    """Convert one cell of a result row with its column's converter.

    :param in_param: a ``((name, f), y)`` pair, where ``name`` is the column
        name, ``f`` the converter callable and ``y`` the raw value.
    :return: ``(name, f(y))``. When ``y`` is ``None``, or when ``f(y)``
        raises, the converter's zero value ``f(0)`` is used instead.

    The zero value is built from the integer ``0`` rather than the string
    ``"0"``: ``str`` and the integer converters give the same result for
    both, but ``bool("0")`` is ``True`` while the intended default for a
    missing boolean is ``False``.

    Only ``Exception`` subclasses are treated as conversion failures, so
    ``KeyboardInterrupt`` and ``SystemExit`` propagate instead of being
    turned into a default value.
    """
    (name, f), y = in_param
    if y is not None:
        try:
            return (name, f(y))
        except Exception as err:
            logging.warning(
                "Error converting value %r for column %r (%s); using default",
                y, name, err
            )
    return (name, f(0))


def prepare_tale(ytc, yt_table_name, yt_table_def):
    """Make sure the result table exists, creating it when it does not.

    Nothing happens when ``ytc.exists(yt_table_name)`` is true. Otherwise the
    table is created through ``u.yt_create_table`` with ``yt_table_def``
    passed as a ``"yson"``-format definition.
    """
    already_there = ytc.exists(yt_table_name)
    if already_there:
        return
    u.yt_create_table(
        ytc,
        yt_table_name,
        yt_table_def,
        table_def_format="yson"
    )


if __name__ == "__main__":
    # Script flow: read job parameters, make sure the result table exists,
    # run the YQL query, convert every result row to the table's column
    # types and insert the rows into the table in batches of BATCH_SIZE.

    # Job parameters come from a JSON file named "params" in the current
    # working directory.
    with open("params") as f:
        params = json.load(f)

    # The parameters contain credentials; mask them so the tokens never end
    # up in the log output.
    loggable_params = dict(params)
    for secret_key in ("yt_token", "yql_token"):
        if secret_key in loggable_params:
            loggable_params[secret_key] = "***"
    logging.info(loggable_params)
    logging.info(params["yt_cluster"])

    yt_cluster = params["yt_cluster"]
    yt_token = params["yt_token"]
    yql_token = params["yql_token"]
    yql_request = params["yql_request"]
    date_time = params["date_time"]
    # json.load yields unicode text under Python 2; the table definition is
    # handed on as UTF-8 bytes.
    yt_table_def = params["yt_table_def"].encode("utf-8")

    yt_result_table_name = params["yt_result_table_name"]

    logging.info(yt_table_def)

    logging.info(yql_request)

    # The query text is a template: every "<date_time>" placeholder is
    # replaced with the date_time parameter.
    yql_query = yql_request.replace("<date_time>", date_time)

    logging.info("yql_query")
    logging.info(yql_query)

    logging.info("create YT client")
    ytc = u.yt_connect(yt_cluster, yt_token)

    logging.info("calling prepare_tale")
    prepare_tale(ytc, yt_result_table_name, yt_table_def)

    logging.info("making fields_type_dict")
    fields_type_dict = get_fields_type_dict(yt_cluster, yt_token, yt_result_table_name)
    logging.info(fields_type_dict)

    logging.info("create yql client")
    client = YqlClient(
        db=yt_cluster,
        token=yql_token,
    )

    request = client.query(yql_query)
    request.run()

    # Abort with the joined error texts if the query did not succeed.
    if not request.get_results().is_success:
        error_description = '\n'.join([str(err) for err in request.get_results().errors])
        logging.error(error_description)
        raise RuntimeError(error_description)

    column_names = request.get_results().table.column_names
    logging.info(column_names)

    # One (column name, converter) pair per result column, in result order.
    column2func = get_column2function(column_names, fields_type_dict)

    rows = []
    for row in request.get_results().table.get_iterator():
        # Pair each cell with its column's converter and build a
        # {column name: converted value} dict for the row.
        rows.append(dict(map(apply_transformation, izip(column2func, row))))
        if len(rows) == BATCH_SIZE:
            logging.info("write batch")
            ytc.insert_rows(
                yt_result_table_name,
                rows,
                update=True
            )
            logging.info("finish writing batch")
            rows = []

    # Flush the last, partially filled batch. Skip the call when nothing is
    # left (no rows at all, or the row count was a multiple of BATCH_SIZE).
    if rows:
        ytc.insert_rows(
            yt_result_table_name,
            rows,
            update=True
        )
