import argparse
import json

import itertools
import logging

import yt.wrapper as yt

from crypta.dmp.common.data.python import (
    meta,
    segment_fields,
)
from crypta.dmp.crypta.services.upload_meta.proto import config_pb2
from crypta.lib.python import yaml_config
from crypta.lib.python.logging import logging_helpers
from crypta.lib.python.yt import yt_helpers

logger = logging.getLogger(__name__)


def get_key_from_row(row):
    """Return the ``(parent_keyword, parent_segment)`` pair of a meta row."""
    keyword = row["parent_keyword"]
    segment = row["parent_segment"]
    return keyword, segment


class SegmentToIdMapping(object):
    """Mapping from ``(parent_keyword, parent_segment)`` keys to segment ids.

    Keys present in ``rows`` keep the id stored in their row. A key that is
    looked up for the first time and is not known yet gets the next unused id,
    which is then remembered for subsequent lookups.
    """

    def __init__(self, rows):
        """Build the mapping from existing rows.

        Args:
            rows: iterable of dict-like rows with "parent_keyword",
                "parent_segment" and "id" items. May be empty.
        """
        self.mapping = {
            get_key_from_row(row): row["id"]
            for row in rows
        }
        # max() raises ValueError on an empty sequence, which used to crash the
        # very first run against an empty table. With no known ids the counter
        # starts from 1.
        last_known_id = max(self.mapping.values()) if self.mapping else 0
        self.id_gen = itertools.count(last_known_id + 1)

    def __getitem__(self, key):
        """Return the id for ``key``, allocating a fresh one if it is unknown."""
        if key not in self.mapping:
            # Allocate lazily so that ids are consumed only for keys actually used.
            self.mapping[key] = next(self.id_gen)

        return self.mapping[key]


def tanker_to_dmp_text(tanker_response):
    """Collect the ``text`` of the "en" and "ru" attributes of a response.

    A language is included in the resulting dict only when the corresponding
    attribute of ``tanker_response`` is truthy.
    """
    texts = {}
    for lang in ("en", "ru"):
        translation = getattr(tanker_response, lang)
        if not translation:
            continue
        texts[lang] = translation.text
    return texts


def segments_to_rows(exports, segments, dmp_segments, old_rows, descriptions):
    """Generate meta rows for the configured segments and for stale old rows.

    For every entry of ``dmp_segments`` the export and its expanded children
    are turned into "enabled" rows. Afterwards every row of ``old_rows`` whose
    ``(parent_keyword, parent_segment)`` key was not produced in the first
    phase is yielded again, modified in place to be "disabled" with an empty
    ACL.

    Args:
        exports: mapping of export id to export row.
        segments: mapping of export id to the segment row owning that export.
        dmp_segments: iterable of objects accepted by ``expand_with_children``.
        old_rows: previously stored meta rows; used to keep ids stable.
        descriptions: mapping of export id to an object with a ``Description``
            attribute; when present it replaces the segment's own description.

    Yields:
        dict rows keyed by ``segment_fields`` constants.
    """
    seen_keys = set()

    logger.info("Building mapping")
    dmp_ids = SegmentToIdMapping(old_rows)

    for dmp_segment in dmp_segments:
        expanded = expand_with_children(dmp_segment, segments, exports)

        for export_id, ancestor_ids in expanded.items():
            export = exports[export_id]
            bigb_coverage = export["coverages"]["bigb"]
            size = bigb_coverage["value"]
            timestamp = bigb_coverage["timestamp"]

            # Per-language list of names along the hierarchy path.
            hierarchy = {}
            for ancestor_id in ancestor_ids:
                localized_names = json.loads(segments[ancestor_id]["name"])
                for lang, ancestor_name in localized_names.items():
                    hierarchy.setdefault(lang, []).append(ancestor_name)

            segment = segments[export_id]

            if export_id in descriptions:
                description = descriptions[export_id].Description
            else:
                description = json.loads(segment["description"])
            title = json.loads(segment["name"])

            key = (export["keyword_id"], export["segment_id"])
            seen_keys.add(key)
            dmp_id = dmp_ids[key]

            row = {
                segment_fields.ACL: None,
                segment_fields.DESCRIPTION: description,
                segment_fields.EXT_ID_SIZE: size,
                segment_fields.HIERARCHY: hierarchy,
                segment_fields.ID: dmp_id,
                segment_fields.PARENT_KEYWORD: export["keyword_id"],
                segment_fields.PARENT_SEGMENT: export["segment_id"],
                segment_fields.STATUS: "enabled",
                segment_fields.TARIFF: 3,
                segment_fields.TIMESTAMP: timestamp,
                segment_fields.TITLE: title,
                segment_fields.YANDEXUID_SIZE: size,
            }

            logger.info("Updating row: %s", row)
            yield row

    # Rows that were not regenerated above are kept, but switched off.
    for old_row in old_rows:
        if get_key_from_row(old_row) in seen_keys:
            continue

        old_row[segment_fields.ACL] = []
        old_row[segment_fields.STATUS] = "disabled"

        logger.info("Writing old row as disabled: %s", old_row)
        yield old_row


def expand_with_children(dmp_segment, segments, exports):
    """Expand a configured segment into itself plus its descendants.

    Starting from ``dmp_segment.ExportId``, children are followed for
    ``dmp_segment.Depth`` levels. Only children whose export has the same
    "keyword_id" as their parent's export are kept.

    Each returned export id is mapped to its hierarchy path: the ids of its
    ancestors, followed by the export itself unless it sits on the last
    expanded level. So with ``Depth == 0`` the root maps to ``[]``, and with
    ``Depth == 2`` the result for a chain root -> child -> grandchild is
    ``{root: [root], child: [root, child], grandchild: [root, child]}``.

    Args:
        dmp_segment: object with ``ExportId`` and ``Depth`` attributes.
        segments: mapping of export id to a segment row; a row may carry a
            "children" list of export ids.
        exports: mapping of export id to an export row with "keyword_id".

    Returns:
        dict mapping export id to the list of export ids forming its hierarchy.
    """
    root_id = dmp_segment.ExportId
    depth = dmp_segment.Depth

    all_exports = {root_id: [root_id] if depth > 0 else []}
    last_exports = [root_id]

    for level in range(depth):
        is_last_level = level == depth - 1
        new_exports = {}
        for export_id in last_exports:
            for new_export_id in segments[export_id].get("children", []):
                if exports[new_export_id]["keyword_id"] != exports[export_id]["keyword_id"]:
                    continue
                # The parent's path already ends with the parent itself, so an
                # intermediate child has to append its own id. Appending the
                # parent id here would duplicate it for Depth >= 2.
                own_part = [] if is_last_level else [new_export_id]
                new_exports[new_export_id] = all_exports[export_id] + own_part

        last_exports = list(new_exports)
        all_exports.update(new_exports)

    return all_exports


def parse_args():
    """Parse command-line arguments; ``--config`` is mandatory."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config", required=True)
    parsed = arg_parser.parse_args()
    return parsed


def main():
    """Rebuild the DMP meta table from the Lab export and segment tables.

    Reads the config given by ``--config``, loads exports, segments and the
    current meta rows inside a single YT transaction, and overwrites the meta
    table with the regenerated rows sorted by id.
    """
    logging_helpers.configure_stdout_logger(logging.getLogger())

    config = yaml_config.parse_config(config_pb2.TConfig, parse_args().config)

    logger.info("Creating yt client")
    yt_client = yt_helpers.get_yt_client(config.Yt.Proxy)

    with yt_client.Transaction():
        exports = {}
        for export_row in yt_client.read_table(config.LabExport):
            exports[export_row["id"]] = export_row

        raw_segments = {}
        for segment_row in yt_client.read_table(config.LabSegments):
            raw_segments[segment_row["id"]] = segment_row

        # Attach to every known parent the export ids of its child segments.
        for segment in raw_segments.values():
            parent_id = segment["parent_id"]
            if parent_id not in raw_segments:
                continue
            children = raw_segments[parent_id].setdefault("children", [])
            children.extend(export["id"] for export in segment["exports"]["exports"])

        # Index segment rows by the id of each export they own.
        segments = {}
        for segment_row in raw_segments.values():
            for export in segment_row["exports"]["exports"]:
                segments[export["id"]] = segment_row

        logger.info("Reading old data")
        old_rows = list(yt_client.read_table(config.DmpMetaPath))

        logger.info("Updating meta")
        meta_table = yt.TablePath(config.DmpMetaPath, schema=meta.get_schema_internal())
        new_rows = list(
            segments_to_rows(exports, segments, config.DmpSegments, old_rows, config.CustomDescriptions)
        )
        new_rows.sort(key=lambda row: row[segment_fields.ID])
        yt_client.write_table(meta_table, new_rows)
