# -*- coding: utf-8 -*-

import json
import os
import logging

from sandbox import sdk2

from enum import Enum


class SchemaMode(Enum):
    """How a table's schema is handled when dumping to / restoring from files.

    WEEK_SCHEMA (alias WEAK_SCHEMA): the schema is not written to the
        ``.attrs`` file and any stored schema is dropped on unpack.
    STRONG_SCHEMA: the table's own ``schema`` attribute is saved and restored.
    YAMR_SCHEMA: the fixed key/subkey/value ``yamr_schema`` is used instead
        of the table's own schema.
    """
    WEEK_SCHEMA = 'weak'
    # Correctly spelled alias. Enum turns a duplicate value into an alias, so
    # SchemaMode.WEAK_SCHEMA is SchemaMode.WEEK_SCHEMA; the misspelled name is
    # kept for existing callers.
    WEAK_SCHEMA = 'weak'
    STRONG_SCHEMA = 'strong'
    YAMR_SCHEMA = 'yamr'

    def __str__(self):
        # Render as the raw value (e.g. 'weak'), handy for CLI/parameter use.
        return self.value


# Fixed YAMR-style table schema: optional "any"-typed columns, with key and
# subkey sorted ascending and value left unsorted.
yamr_schema = [
    dict(name=column, required=False, type="any", sort_order="ascending")
    for column in ("key", "subkey")
] + [
    dict(name="value", required=False, type="any"),
]


def get_default_format():
    """Return the format used when a caller does not pass one: YSON in text mode."""
    import yt.wrapper as yt

    return yt.YsonFormat(attributes={"format": "text"})


def get_all_attrs(yt_client, yt_path, schema_mode=SchemaMode.STRONG_SCHEMA):
    """Collect the attributes of ``yt_path`` that should be saved with a dump.

    All user attributes are fetched; in STRONG_SCHEMA mode the ``schema``
    attribute is fetched as well. Each attribute costs one ``get_attribute``
    request.

    Returns a dict mapping attribute name to its value.
    """
    names = yt_client.get_attribute(yt_path, 'user_attribute_keys')
    if schema_mode == SchemaMode.STRONG_SCHEMA:
        names.append('schema')
    return {name: yt_client.get_attribute(yt_path, name) for name in names}


def create_attr_file(yt_client, yt_path, filepath, schema_mode=SchemaMode.STRONG_SCHEMA):
    """Write the attributes of ``yt_path`` as JSON to ``<filepath>.attrs``.

    The attribute set comes from ``get_all_attrs``. In YAMR_SCHEMA mode the
    fixed ``yamr_schema`` is stored as ``schema``.
    """
    att_dict = get_all_attrs(yt_client, yt_path, schema_mode)
    if schema_mode == SchemaMode.YAMR_SCHEMA:
        att_dict['schema'] = yamr_schema
    # json.dumps returns text, so open in text mode: writing a str to a 'wb'
    # handle raises TypeError on Python 3. With the default ensure_ascii=True the
    # output is pure ASCII, so the bytes on disk are the same as before.
    with open('{}.attrs'.format(filepath), 'w') as fobj:
        fobj.write(json.dumps(att_dict))


def unpack_table(yt_client, yt_path, filepath, format=None, with_attrs=True, schema_mode=SchemaMode.STRONG_SCHEMA):
    """Create table ``yt_path`` and fill it with the raw dump stored in ``filepath``.

    :param format: format of the dump; defaults to ``get_default_format()``.
    :param with_attrs: if true, create the table with the attributes read
        from the first line of ``<filepath>.attrs`` (JSON).
    :param schema_mode: WEEK_SCHEMA drops the stored schema, YAMR_SCHEMA
        replaces it with ``yamr_schema``, STRONG_SCHEMA keeps it as stored.
    """
    format = format or get_default_format()

    create_kwargs = {}
    if with_attrs:
        with open('{}.attrs'.format(filepath), 'rb') as attrs_stream:
            attributes = json.loads(attrs_stream.readline())

        if schema_mode == SchemaMode.YAMR_SCHEMA:
            attributes.update({'schema': yamr_schema})
        elif schema_mode == SchemaMode.WEEK_SCHEMA:
            attributes.pop('schema', None)

        create_kwargs['attributes'] = attributes

    yt_client.create('table', yt_path, **create_kwargs)

    logging.debug(yt_path)
    logging.debug(filepath)

    # raw=True: the file already holds serialized rows in ``format``.
    with open(filepath, 'rb') as data_stream:
        yt_client.write_table(yt_path, data_stream, format=format, raw=True)


def unpack_links(yt_client, links, dst_yt_path, src_root_path):
    """Recreate YT links from ``.link`` files produced by ``pack_links``.

    :param links: local paths of ``.link`` files located under ``src_root_path``.
        Each file's first line is the target path relative to the dump root.
    :param dst_yt_path: YT directory that the dump is restored into; both link
        and target paths are resolved relative to it.
    """
    import yt.wrapper as yt

    for link_file in links:
        with open(link_file, 'r') as handle:
            target = yt.ypath_join(dst_yt_path, handle.readline())
        # Strip "<src_root_path>/" in front and the ".link" suffix at the end.
        relative_link = link_file[len(src_root_path) + 1:-len('.link')]
        yt_client.link(target, yt.ypath_join(dst_yt_path, relative_link))


def unpack_tables_to_yt(yt_client, dst_yt_path, src_tables_path, format=None, with_attrs=True, schema_mode=SchemaMode.STRONG_SCHEMA):
    """Restore a local directory tree of table dumps under ``dst_yt_path``.

    Each local sub-directory becomes a ``map_node``. Each data file becomes a
    table via ``unpack_table``. ``*.attrs`` side files are consumed by
    ``unpack_table`` rather than uploaded. ``*.link`` files are collected and
    turned into YT links after all tables exist.

    :param format: dump format; defaults to ``get_default_format()``.
    :param with_attrs: forwarded to ``unpack_table``.
    :param schema_mode: forwarded to ``unpack_table``.
    """
    import sys
    import yt.wrapper as yt

    format = format or get_default_format()

    # Debug purposes: report which YT / YSON bindings are in use. ``vars()``
    # inside a function only sees locals, so the YSON bindings module has to
    # be looked up among already-imported modules instead.
    logging.info('YT bindings path: ' + yt.__file__)
    yson_bindings = sys.modules.get('yt_yson_bindings')
    if yson_bindings is not None:
        logging.info('YSON bindings path: ' + getattr(yson_bindings, '__file__', '<built-in>'))
    else:
        logging.info('YSON bindings not found!')

    links = []

    src_root = str(src_tables_path)
    path_offset = len(src_root) + 1
    for dir_path, _, filenames in os.walk(src_root):
        dir_yt_path = yt.ypath_join(dst_yt_path, dir_path[path_offset:])
        if dir_yt_path.endswith('/'):
            # Presumably produced by ypath_join when the relative part is empty
            # (the root directory itself); drop the single trailing slash.
            dir_yt_path = dir_yt_path[:-1]
        if not yt_client.exists(dir_yt_path):
            yt_client.create('map_node', dir_yt_path, recursive=True)

        for filepath in filenames:
            full_filepath = os.path.join(dir_path, filepath)

            logging.info('Filepath ' + filepath)

            # Attribute side files are named "<table>.attrs". Match the full
            # suffix so tables whose names merely end in "attrs" are unpacked.
            if filepath.endswith('.attrs'):
                logging.info('Attrs')
                continue

            if filepath.endswith('.link'):
                logging.info('Link')
                links.append(full_filepath)
                continue

            full_yt_path = yt.ypath_join(dir_yt_path, filepath)
            logging.info('Unpack file ' + str(full_filepath) + ' to table ' + str(full_yt_path) + ' with format ' + str(format))
            unpack_table(yt_client, full_yt_path, full_filepath, format, with_attrs, schema_mode)

    # Links go last so their targets have already been created.
    unpack_links(yt_client, links, dst_yt_path, src_root)


def upload_files_to_yt(yt_client, dst_yt_path, src_files_path):
    """Upload every file under ``src_files_path`` to ``dst_yt_path``, keeping the layout.

    Local sub-directories become ``map_node``s. ``.attrs`` and ``.link`` side
    files (the kind written by ``pack_table`` / ``pack_links``) are skipped.
    """
    import yt.wrapper as yt

    path_offset = len(str(src_files_path)) + 1
    for dir_path, _, filenames in os.walk(str(src_files_path)):
        dir_yt_path = yt.ypath_join(dst_yt_path, dir_path[path_offset:])
        if dir_yt_path.endswith('/'):
            # Presumably produced by ypath_join when the relative part is empty
            # (the root directory itself); drop the single trailing slash.
            dir_yt_path = dir_yt_path[:-1]
        if not yt_client.exists(dir_yt_path):
            yt_client.create('map_node', dir_yt_path, recursive=True)

        for filepath in filenames:
            # Match the full ".attrs" suffix so ordinary files whose names merely
            # end in "attrs" are still uploaded.
            if filepath.endswith(('.attrs', '.link')):
                continue

            full_filepath = os.path.join(dir_path, filepath)
            full_yt_path = yt.ypath_join(dir_yt_path, filepath)
            upload_file_to_yt(yt_client, full_yt_path, full_filepath)


def unpack_resource_to_yt(yt_client, resource, format=None, with_attrs=True, schema_mode=SchemaMode.STRONG_SCHEMA):
    """Restore the table dump stored in a Sandbox resource to ``resource.default_path``.

    Presumably ``resource.default_path`` is a YT path attribute defined on the
    resource type, and ``sdk2.ResourceData(resource).path`` is the local
    directory with the dump (TODO confirm against the resource classes).
    ``format``, ``with_attrs`` and ``schema_mode`` are forwarded to
    ``unpack_tables_to_yt``.
    """
    dst_yt_path = resource.default_path
    src_tables_path = sdk2.ResourceData(resource).path
    # Forward the caller's schema_mode as given rather than a hard-coded mode.
    unpack_tables_to_yt(yt_client, dst_yt_path, src_tables_path, format, with_attrs, schema_mode=schema_mode)


def pack_table(yt_client, yt_path, dst_tables_path, format=None, with_attrs=True, schema_mode=SchemaMode.STRONG_SCHEMA):
    """Dump table ``yt_path`` to the local file ``dst_tables_path``.

    :param format: format used for the raw dump. Defaults to
        ``get_default_format()``, like the other helpers in this module
        (previously the argument was mandatory).
    :param with_attrs: if true, also write ``<dst_tables_path>.attrs`` via
        ``create_attr_file``.
    :param schema_mode: forwarded to ``create_attr_file``.
    """
    format = format or get_default_format()
    with open(dst_tables_path, 'wb') as fobj:
        # raw=True streams serialized chunks, which are written verbatim.
        for chunk in yt_client.read_table(yt_path, format=format, raw=True):
            fobj.write(chunk)
    if with_attrs:
        create_attr_file(yt_client, yt_path, dst_tables_path, schema_mode=schema_mode)


def _pack_tables_from_yt(yt_client, yt_path, dst_tables_path, links, format=None, with_attrs=True, schema_mode=SchemaMode.STRONG_SCHEMA):
    """Recursively dump the YT node ``yt_path`` into ``dst_tables_path``.

    - tables are written with ``pack_table``;
    - map_nodes become local directories and are descended into;
    - links are not followed: their paths are appended to ``links`` so the
      caller can save them afterwards;
    - any other node type is ignored.
    """
    import yt.wrapper as yt

    format = format or get_default_format()

    # The '&' suffix addresses the node itself, so a link reports type 'link'
    # instead of its target's type.
    node_type = yt_client.get_attribute(yt_path + '&', 'type')

    if node_type == 'link':
        links.append(yt_path)
        return
    if node_type == 'table':
        pack_table(yt_client, yt_path, dst_tables_path, format, with_attrs, schema_mode)
        return
    if node_type != 'map_node':
        # Nothing to do for other node types.
        return

    if not os.path.exists(dst_tables_path):
        os.makedirs(dst_tables_path)

    for child in yt_client.list(yt_path, absolute=False):
        child_name = str(child)
        _pack_tables_from_yt(
            yt_client,
            yt.ypath_join(yt_path, child_name),
            os.path.join(dst_tables_path, child_name),
            links,
            format,
            with_attrs,
            schema_mode,
        )


def pack_links(yt_client, links, yt_root_path, dst_root_path):
    """Save YT links found under ``yt_root_path`` as ``.link`` files.

    For every link path in ``links``, writes
    ``<dst_root_path>/<link path relative to root>.link`` containing the
    target path relative to ``yt_root_path``. Links whose target lies outside
    ``yt_root_path`` cannot be restored relative to the unpack destination,
    so they are skipped with a warning. Links to missing objects are saved
    anyway, with a warning.
    """
    root_prefix = yt_root_path + '/'
    for link_path in links:
        # 'path' is read without the '&' suffix, so presumably the link is
        # followed and this is the path of the object it points to.
        target_path = yt_client.get_attribute(link_path, 'path')

        # Require the root itself or "<root>/...": a bare prefix test would
        # treat e.g. //home/ab/x as lying inside //home/a.
        if target_path != yt_root_path and not target_path.startswith(root_prefix):
            logging.warning('There was link {} on table {} outside of saved path! Link was skipped'.format(link_path, target_path))
            continue

        relative_target_path = target_path[len(yt_root_path) + 1:]
        relative_link_path = link_path[len(yt_root_path) + 1:]
        if not yt_client.exists(target_path):
            logging.warning('There is link {} on non-existing object {}'.format(link_path, target_path))
        dst_path = os.path.join(dst_root_path, relative_link_path + '.link')
        with open(dst_path, 'w') as f:
            f.write(relative_target_path)


def pack_tables_from_yt(yt_client, yt_path, dst_tables_path, format=None, with_attrs=True, schema_mode=SchemaMode.STRONG_SCHEMA):
    """Dump a YT table or directory tree rooted at ``yt_path`` into ``dst_tables_path``.

    ``yt_path`` must be absolute (start with ``//``). A single table is saved
    as ``<dst_tables_path>/<table name>``. A directory is mirrored into
    ``dst_tables_path``, with links inside it stored as ``.link`` files.
    """
    assert yt_path.startswith('//'), 'Sorry, it\'s hard to work with relative paths!'

    format = format or get_default_format()

    if yt_client.get_attribute(yt_path, 'type') == 'table':
        dst_tables_path = os.path.join(dst_tables_path, yt_path.rsplit('/', 1)[-1])

    collected_links = []
    _pack_tables_from_yt(yt_client, yt_path, dst_tables_path, collected_links, format, with_attrs, schema_mode)
    # Links are saved last, relative to the dumped root.
    pack_links(yt_client, collected_links, yt_path, dst_tables_path)


def _traverse_dfs(yt_client, yt_path):
    """Log the JSON listing of ``yt_path``, then recurse into each child map_node."""
    import yt.wrapper as yt

    listing = yt_client.list(yt_path, attributes=['type'], format='json')
    logging.info('Content of directory {}: {}'.format(yt_path, listing))
    for entry in json.loads(listing):
        if entry['$attributes']['type'] != 'map_node':
            continue
        _traverse_dfs(yt_client, yt.ypath_join(yt_path, entry['$value']))


def _get_parent_folder(yt_path):
    """Return everything before the last '/' of ``yt_path`` ('' if there is none)."""
    parent, _, _ = yt_path.rpartition('/')
    return parent


def upload_file_to_yt(yt_client, yt_path, filepath):
    """Upload the local file ``filepath`` to the YT file node ``yt_path``.

    The parent directory is created if it is missing. Creation and upload run
    inside one YT transaction bound to ``yt_client``.
    """
    import yt.wrapper as yt
    with yt.Transaction(client=yt_client):
        parent_folder = _get_parent_folder(yt_path)
        if not yt_client.exists(parent_folder):
            yt_client.create('map_node', parent_folder, recursive=True)

        # Binary mode: upload the bytes verbatim. Text mode would decode the
        # content with the locale encoding on Python 3, which can fail on (or
        # alter) non-text files.
        with open(filepath, 'rb') as f:
            yt_client.write_file(yt_path, f)


def traverse_yt_path(yt_client, yt_path):
    """
        Log (via logging.info) the listing of ``yt_path`` and of every nested
        map_node, depth-first. Nothing is printed to stdout.
        For debug purposes.
    """
    _traverse_dfs(yt_client, yt_path)
