# -*- coding: utf-8 -*-
from datetime import datetime
import calendar

from yt.wrapper import ypath_join
import yt.wrapper as yt_wrapper

from datacloud.dev_utils.yt.yt_utils import create_folders, get_yt_client
from datacloud.dev_utils.yt.yt_ops import compress_table
from datacloud.dev_utils.yt import features
from datacloud.dev_utils.time.patterns import FMT_DATE
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.id_value.id_value_lib import encode_hexhash_as_uint64, encode_as_uint64
from datacloud.dev_utils.time.utils import now_str
from datacloud.features.geo.build_config import GeoBuildConfig
from datacloud.features.geo.helpers import PointsToGeohashReducerBase
from datacloud.dev_utils.status_db.task import Task, Status

logger = get_basic_logger(__name__)

# Appended to a log name to build the name of the companion table (stored next
# to the prod geo aggregate) whose rows are (hashed_id, hashed_cid) pairs.
INTERESTING_CRYPTA_SUFFIX = '_interesting_crypta'


class GeoLogsMapper:
    """YT mapper: expand one raw geo-log row into one row per cluster center.

    Every output row carries the same ``timestamp`` given at construction.
    """

    def __init__(self, timestamp):
        # Value written to the ``timestamp`` column of every output row.
        self.timestamp = timestamp

    def __call__(self, rec):
        """Yield ``{yuid, lon, lat, timestamp}`` for each entry of ``rec['result']['cluster_list']``."""
        clusters = rec['result']['cluster_list']
        for cluster in clusters:
            center_point = cluster['center']
            # 'yandexuid' is read per cluster, so a row with no clusters never touches it.
            yield {
                'yuid': rec['yandexuid'],
                'lon': center_point['lon'],
                'lat': center_point['lat'],
                'timestamp': self.timestamp,
            }


@yt_wrapper.with_context
def calc_geo_by_cid(_, recs, context):
    """YT reducer (by yuid): attach the yuid's points to every cid it maps to.

    Input table 0 holds point rows (``lon``/``lat``); input table 1 holds
    yuid -> cid rows. For every table-1 row of a key that had at least one
    point, yields ``{'cid': ..., 'points': [(lon, lat), ...]}``. The same
    points list object is yielded for every cid of the key.
    """
    collected = []
    for row in recs:
        if context.table_index == 0:
            collected.append((row['lon'], row['lat']))
            continue
        # A table-1 row with no points collected means nothing to emit for this key.
        if not collected:
            break
        yield {
            'cid': row['cid'],
            'points': collected
        }


@yt_wrapper.with_context
def filter_cids(_, recs, context):
    """YT reducer: pass through table-1 rows only for keys present in table 0.

    Table 0 acts as a filter (its rows are never emitted); table-1 rows are
    yielded unchanged when at least one table-0 row preceded them for the key.
    """
    seen_in_filter = False
    for row in recs:
        if context.table_index == 0:
            seen_in_filter = True
            continue
        if not seen_in_filter:
            break
        yield row


@yt_wrapper.with_context
class PointsToGeohashReducerProd(PointsToGeohashReducerBase):
    """YT reducer (by cid): turn a cid's points into a space-separated geohash string.

    Point acceptance is delegated to ``check_point`` from the base class; at
    most ``self.max_points`` accepted geohashes are kept in the output.
    """

    def __call__(self, key, recs, context):
        accepted = []
        seen_prefixes = set()

        for rec in recs:
            for point in rec['points']:
                # check_point sees the geohashes/prefixes accepted so far.
                verdict = self.check_point(point, accepted, seen_prefixes)
                if not verdict.use_point:
                    continue
                accepted.append(verdict.geohash)
                seen_prefixes.add(verdict.prefix)

        yield {
            'cid': key['cid'],
            'points': ' '.join(accepted[:self.max_points]),
        }


@yt_wrapper.with_context
def reduce_interesting_crypta(key, recs, context):
    """YT reducer (by cid): emit (hashed_id, hashed_cid) pairs for cids present in table 0.

    For such cids, every table-1 row with ``id_type`` ``phone_md5`` or
    ``email_md5`` and a non-empty ``id_value`` yields the hex hash encoded as
    uint64. If the cid itself is truthy, the cid is also emitted as its own id.
    """
    cid = key['cid']
    hashed_cid = encode_as_uint64(cid)
    wanted_id_types = ('phone_md5', 'email_md5')
    matched = False

    for row in recs:
        if context.table_index == 0:
            matched = True
            continue
        if not matched:
            break
        if row['id_type'] not in wanted_id_types or not row['id_value']:
            continue
        yield {
            'hashed_id': encode_hexhash_as_uint64(row['id_value']),
            'hashed_cid': hashed_cid
        }

    if matched and cid:
        yield {
            'hashed_id': encode_as_uint64(cid),
            'hashed_cid': hashed_cid
        }


def hash_cids_mapper(rec):
    """YT mapper: replace the ``cid`` column by its uint64 encoding ``hashed_cid``.

    The input row is modified in place and yielded back.
    """
    raw_cid = rec.pop('cid')
    rec['hashed_cid'] = encode_as_uint64(raw_cid)
    yield rec


def assert_log_name(log_name):
    """Check that ``log_name`` is a bare table name rather than a YT path.

    Callers join the name onto a fixed directory with ``ypath_join``. A name
    containing ``/`` would not address a direct child of that directory.

    :raises AssertionError: if ``log_name`` contains ``/``. The error is
        raised explicitly, not via an ``assert`` statement, so the check
        still runs when Python is started with ``-O``.
    """
    if '/' in log_name:
        raise AssertionError('Waiting for pure table name, got {}'.format(log_name))


class GeoLogsAggregator:
    """Build local geo logs and per-cid geo aggregates on YT.

    * :meth:`aggregate_table` copies one external geo log into
      ``LOCAL_LOGS_DIR`` as ``(yuid, lon, lat, timestamp)`` rows sorted by
      ``yuid``.
    * :meth:`build_prod_aggregate` joins a local log with crypta, keeps only
      cids present in the last ``USER2CLUST`` table and in ``DSSM_READY``,
      converts their points to geohashes and writes the production aggregate
      plus the matching "interesting crypta" table.

    Every operation spec is titled ``[<tag>] ...`` and extended with
    :attr:`cloud_nodes_spec`.
    """
    EXTERNAL_LOGS_DIR = GeoBuildConfig.EXTERNAL_LOGS_DIR
    LOCAL_LOGS_DIR = GeoBuildConfig.LOCAL_LOGS_DIR

    TMP_FOLDER = '//tmp'
    PRODUCTION_PATH = '//home/x-products/production'
    AGGREGATES_ROOT = ypath_join(PRODUCTION_PATH, 'datacloud', 'aggregates')

    CRYPTA_DB = ypath_join(PRODUCTION_PATH, 'crypta_v2', 'crypta_db_last')
    PROD_AGGREGATES = ypath_join(AGGREGATES_ROOT, 'geo')
    USER2CLUST = ypath_join(AGGREGATES_ROOT, 'cluster', 'user2clust')
    DSSM_READY = ypath_join(AGGREGATES_ROOT, 'dssm', 'ready', 'features')

    # Schema of the local copy of an external log (one row per cluster center).
    LOCAL_LOG_SCHEMA = [
        {'name': 'yuid', 'type': 'string'},
        {'name': 'lon', 'type': 'double'},
        {'name': 'lat', 'type': 'double'},
        {'name': 'timestamp', 'type': 'int64'},
    ]

    # Schema of the production aggregate: space-separated geohashes per hashed cid.
    PROD_AGGREGATE_SCHEMA = [
        {'name': 'hashed_cid', 'type': 'uint64'},
        {'name': 'points', 'type': 'string'},
    ]

    # Schema of the companion (hashed_id -> hashed_cid) table.
    INTERESTING_CRYPTA_SCHEMA = [
        {'name': 'hashed_id', 'type': 'uint64'},
        {'name': 'hashed_cid', 'type': 'uint64'},
    ]

    def __init__(self, yt_client, tag='GEO LOGS', use_cloud_nodes=False):
        """
        :param yt_client: YT client used for every Cypress call and operation.
        :param tag: label placed in square brackets at the start of operation titles.
        :param use_cloud_nodes: passed to ``features.cloud_nodes_spec`` to build
            the extra operation spec.
        """
        self.yt_client = yt_client
        # Side effect: ensure the local logs directory exists.
        create_folders((self.LOCAL_LOGS_DIR, ), self.yt_client)

        self.use_cloud_nodes = use_cloud_nodes
        self.tag = tag

    def _spec(self, title):
        """Return an operation spec titled ``[<tag>] <title>`` merged with cloud-node settings."""
        return dict(
            title='[{}] {}'.format(self.tag, title),
            **self.cloud_nodes_spec
        )

    def get_tables_to_aggregate(self):
        """Return, sorted, the external log names not yet present in ``LOCAL_LOGS_DIR``."""
        ext_tables = self.yt_client.list(self.EXTERNAL_LOGS_DIR)
        local_tables = self.yt_client.list(self.LOCAL_LOGS_DIR)

        return sorted(set(ext_tables) - set(local_tables))

    def get_local_log_table(self, log_name):
        """Return the schematized path of the local copy of external log ``log_name``."""
        assert_log_name(log_name)
        return self.yt_client.TablePath(
            ypath_join(self.LOCAL_LOGS_DIR, log_name),
            schema=self.LOCAL_LOG_SCHEMA
        )

    def _prod_table(self, table_name, schema):
        """Return a ``PROD_AGGREGATES`` table path carrying ``schema`` and brotli_3/scan storage attributes."""
        return self.yt_client.TablePath(
            ypath_join(self.PROD_AGGREGATES, table_name),
            attributes={
                'schema': schema,
                'compression_codec': 'brotli_3',
                'optimize_for': 'scan'
            }
        )

    def get_prod_aggregate_table(self, log_name):
        """Return the production aggregate table path for ``log_name``."""
        assert_log_name(log_name)
        return self._prod_table(log_name, self.PROD_AGGREGATE_SCHEMA)

    def get_interesting_crypta_table(self, log_name):
        """Return the (hashed_id, hashed_cid) companion table path for ``log_name``."""
        assert_log_name(log_name)
        return self._prod_table(log_name + INTERESTING_CRYPTA_SUFFIX, self.INTERESTING_CRYPTA_SCHEMA)

    def aggregate_table(self, log_name):
        """Copy external log ``log_name`` into the local logs directory.

        ``GeoLogsMapper`` emits one row per cluster center. The result is then
        passed through ``compress_table`` and sorted by ``yuid``. All three
        steps run inside one transaction.

        :param log_name: bare table name whose part before the first ``:`` is
            a date in ``FMT_DATE`` format. That datetime, interpreted as UTC,
            becomes the ``timestamp`` of every output row.
        """
        dt = datetime.strptime(log_name.split(':')[0], FMT_DATE)
        timestamp = calendar.timegm(dt.utctimetuple())

        local_log = self.get_local_log_table(log_name)
        with self.yt_client.Transaction():
            self.yt_client.run_map(
                GeoLogsMapper(timestamp),
                ypath_join(self.EXTERNAL_LOGS_DIR, log_name),
                local_log,
                spec=self._spec(log_name)
            )

            compress_table(
                local_log,
                yt_client=self.yt_client,
                title_suffix=self.tag
            )

            self.yt_client.run_sort(
                local_log,
                sort_by='yuid',
                spec=self._spec('{} / sort after'.format(log_name))
            )

    def aggregate_by_yuid(self, log_name, res_table):
        """Join the local log with crypta ``yuid_to_cid`` and write ``(cid, points)`` rows.

        ``res_table`` ends up sorted by ``cid``.
        """
        local_log = self.get_local_log_table(log_name)
        yuid_to_cid = ypath_join(self.CRYPTA_DB, 'yuid_to_cid')

        self.yt_client.run_reduce(
            calc_geo_by_cid,
            [
                local_log,
                yuid_to_cid
            ],
            res_table,
            reduce_by='yuid',
            spec=self._spec('Prod {} aggregate by yuid'.format(log_name))
        )
        self.yt_client.run_sort(
            res_table,
            sort_by='cid',
            spec=self._spec('Prod {} aggregate by yuid / sort after'.format(log_name))
        )

    def filter_interesting_cids(self, by_cid_table, filter_table, output_table=None):
        """Keep only rows of ``by_cid_table`` whose ``cid`` also occurs in ``filter_table``.

        :param output_table: destination, sorted by ``cid`` afterwards.
            Defaults to ``by_cid_table`` (filter in place).
        """
        output_table = output_table or by_cid_table

        self.yt_client.run_reduce(
            filter_cids,
            [
                filter_table,
                by_cid_table
            ],
            output_table,
            reduce_by='cid',
            spec=self._spec('Filter cids')
        )
        self.yt_client.run_sort(
            output_table,
            sort_by='cid',
            spec=self._spec('Filter cids / sort after')
        )

    def points_to_geohash(self, points_table, output_table=None):
        """Replace each cid's point list by a geohash string; result sorted by ``cid``.

        :param output_table: destination. Defaults to ``points_table`` (in place).
        """
        output_table = output_table or points_table
        self.yt_client.run_reduce(
            PointsToGeohashReducerProd(),
            points_table,
            output_table,
            reduce_by='cid',
            spec=self._spec('Points to geohash')
        )
        self.yt_client.run_sort(
            output_table,
            sort_by='cid',
            spec=self._spec('Points to geohash / sort after')
        )

    def cid_to_idvalue(self, cids_table, output_table=None):
        """Join cids with crypta ``cid_to_all`` into ``(hashed_id, hashed_cid)`` rows.

        The output is not sorted.

        :param output_table: destination. Defaults to ``cids_table``.
        """
        output_table = output_table or cids_table

        cid_to_all = ypath_join(self.CRYPTA_DB, 'cid_to_all')
        self.yt_client.run_reduce(
            reduce_interesting_crypta,
            [
                cids_table,
                cid_to_all
            ],
            output_table,
            reduce_by='cid',
            spec=self._spec('Prod interesting crypta')
        )

    def cid_to_hashed_cid(self, input_table, output_table=None):
        """Map ``cid`` to its uint64 ``hashed_cid`` for every row.

        :param output_table: destination. Defaults to ``input_table``.
        """
        output_table = output_table or input_table
        self.yt_client.run_map(
            hash_cids_mapper,
            input_table,
            output_table,
            spec=self._spec('Hash cids')
        )

    def _get_last_table(self, folder):
        """Return the absolute path of the lexicographically last child of ``folder``.

        The listing is sorted explicitly, as in :meth:`get_tables_to_aggregate`,
        instead of relying on the order in which ``list`` returns children.

        :raises RuntimeError: if ``folder`` has no children.
        """
        tables = sorted(self.yt_client.list(folder, absolute=True))
        if not tables:
            raise RuntimeError('No tables found in {}'.format(folder))
        return tables[-1]

    def build_prod_aggregate(self, log_name):
        """Build the production geo aggregate and its interesting-crypta table for ``log_name``.

        Everything runs in one transaction using a temporary table under
        ``TMP_FOLDER``.
        """
        with self.yt_client.Transaction(), self.yt_client.TempTable(self.TMP_FOLDER) as temp_table:
            # temp_table: (cid, points) rows sorted by cid.
            self.aggregate_by_yuid(log_name, temp_table)

            # Keep only cids present in the last user2clust table and in DSSM features.
            # NOTE(review): assumes user2clust table names sort chronologically — confirm.
            user2clust_last = self._get_last_table(self.USER2CLUST)
            self.filter_interesting_cids(temp_table, filter_table=user2clust_last)
            self.filter_interesting_cids(temp_table, filter_table=self.DSSM_READY)

            self.points_to_geohash(temp_table)
            prod_aggregate_table = self.get_prod_aggregate_table(log_name)
            # Hashed output goes to the prod table; temp_table keeps raw, sorted cids.
            self.cid_to_hashed_cid(temp_table, prod_aggregate_table)

            interesting_crypta_table = self.get_interesting_crypta_table(log_name)
            self.cid_to_idvalue(temp_table, interesting_crypta_table)

    @property
    def cloud_nodes_spec(self):
        """Extra operation spec returned by ``features.cloud_nodes_spec(self.use_cloud_nodes)``."""
        return features.cloud_nodes_spec(self.use_cloud_nodes)


def detect_ready_geo_logs(date_time, days=None):
    """Yield external geo logs that have not been copied locally yet.

    ``date_time`` and ``days`` are not used by this function (presumably kept
    to match the caller's detector interface).

    :return: generator of ``(table_name, {'table_path': table_name})`` pairs.
    """
    pending = GeoLogsAggregator(get_yt_client()).get_tables_to_aggregate()
    for name in pending:
        logger.info(' %s ready', name)
        yield name, {'table_path': name}


def build_geo_logs(task, use_cloud_nodes=False):
    """Copy the log named in ``task.data['table_path']`` into the local logs directory.

    :return: ``[<task marked done>, <READY 'build_geo_aggregates' task for the same table>]``.
    """
    logger.info(' Building %s', task)
    table_name = task.data['table_path']
    aggregator = GeoLogsAggregator(get_yt_client(), use_cloud_nodes=use_cloud_nodes)
    aggregator.aggregate_table(table_name)
    logger.info(' %s geo log built!', table_name)

    stamp = now_str()
    finished = task.make_done()
    follow_up = Task('build_geo_aggregates', table_name, Status.READY, {'table_path': table_name}, stamp, stamp)
    return [finished, follow_up]


def build_geo_aggregates(task, use_cloud_nodes=False):
    """Build the production geo aggregate for the log named in ``task.data['table_path']``.

    :return: ``[<task marked done>, <READY 'transfer_log_to_ydb' task for the same table>]``.
    """
    logger.info(' Building %s aggregates', task)
    table_name = task.data['table_path']
    aggregator = GeoLogsAggregator(get_yt_client(), use_cloud_nodes=use_cloud_nodes)
    aggregator.build_prod_aggregate(table_name)
    logger.info(' %s geo aggregate built!', table_name)

    stamp = now_str()
    finished = task.make_done()
    follow_up = Task('transfer_log_to_ydb', table_name, Status.READY, {'table_path': table_name}, stamp, stamp)
    return [finished, follow_up]
